Skip to content

Commit ef6366f

Browse files
authored
openai native web search tool enabled (#2410)
1 parent fd0e75a commit ef6366f

File tree

5 files changed

+105
-72
lines changed

5 files changed

+105
-72
lines changed

pkg/aiusechat/openai/openai-backend.go

Lines changed: 43 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -122,10 +122,11 @@ func (m *OpenAIChatMessage) GetUsage() *uctypes.AIUsage {
122122
return nil
123123
}
124124
return &uctypes.AIUsage{
125-
APIType: "openai",
126-
Model: m.Usage.Model,
127-
InputTokens: m.Usage.InputTokens,
128-
OutputTokens: m.Usage.OutputTokens,
125+
APIType: "openai",
126+
Model: m.Usage.Model,
127+
InputTokens: m.Usage.InputTokens,
128+
OutputTokens: m.Usage.OutputTokens,
129+
NativeWebSearchCount: m.Usage.NativeWebSearchCount,
129130
}
130131
}
131132

@@ -281,12 +282,13 @@ type openaiTextFormat struct {
281282
}
282283

283284
type OpenAIUsage struct {
284-
InputTokens int `json:"input_tokens,omitempty"`
285-
OutputTokens int `json:"output_tokens,omitempty"`
286-
TotalTokens int `json:"total_tokens,omitempty"`
287-
InputTokensDetails *openaiInputTokensDetails `json:"input_tokens_details,omitempty"`
288-
OutputTokensDetails *openaiOutputTokensDetails `json:"output_tokens_details,omitempty"`
289-
Model string `json:"model,omitempty"` // internal field (not from OpenAI API)
285+
InputTokens int `json:"input_tokens,omitempty"`
286+
OutputTokens int `json:"output_tokens,omitempty"`
287+
TotalTokens int `json:"total_tokens,omitempty"`
288+
InputTokensDetails *openaiInputTokensDetails `json:"input_tokens_details,omitempty"`
289+
OutputTokensDetails *openaiOutputTokensDetails `json:"output_tokens_details,omitempty"`
290+
Model string `json:"model,omitempty"` // internal field (not from OpenAI API)
291+
NativeWebSearchCount int `json:"nativewebsearchcount,omitempty"` // internal field (not from OpenAI API)
290292
}
291293

292294
type openaiInputTokensDetails struct {
@@ -323,12 +325,13 @@ type openaiBlockState struct {
323325
}
324326

325327
type openaiStreamingState struct {
326-
blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming
327-
toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key
328-
msgID string
329-
model string
330-
stepStarted bool
331-
chatOpts uctypes.WaveChatOpts
328+
blockMap map[string]*openaiBlockState // Use item_id as key for UI streaming
329+
toolUseData map[string]*uctypes.UIMessageDataToolUse // Use toolCallId as key
330+
msgID string
331+
model string
332+
stepStarted bool
333+
chatOpts uctypes.WaveChatOpts
334+
webSearchCount int
332335
}
333336

334337
// ---------- Public entrypoint ----------
@@ -759,7 +762,7 @@ func handleOpenAIEvent(
759762
}
760763

761764
// Extract partial message if available
762-
finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData)
765+
finalMessages, _ := extractMessageAndToolsFromResponse(ev.Response, state)
763766

764767
_ = sse.AiMsgError(errorMsg)
765768
return &uctypes.WaveStopReason{
@@ -772,7 +775,7 @@ func handleOpenAIEvent(
772775
}
773776

774777
// Extract the final message and tool calls from the response output
775-
finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state.toolUseData)
778+
finalMessages, toolCalls := extractMessageAndToolsFromResponse(ev.Response, state)
776779

777780
stopKind := uctypes.StopKindDone
778781
if len(toolCalls) > 0 {
@@ -820,6 +823,19 @@ func handleOpenAIEvent(
820823
}
821824
return nil, nil
822825

826+
case "response.web_search_call.in_progress":
827+
return nil, nil
828+
829+
case "response.web_search_call.searching":
830+
return nil, nil
831+
832+
case "response.web_search_call.completed":
833+
state.webSearchCount++
834+
return nil, nil
835+
836+
case "response.output_text.annotation.added":
837+
return nil, nil
838+
823839
default:
824840
// log unknown events for debugging
825841
log.Printf("OpenAI: unknown event: %s, data: %s", eventName, data)
@@ -857,9 +873,8 @@ func createToolUseData(toolCallID, toolName string, toolDef *uctypes.ToolDefinit
857873
return toolUseData
858874
}
859875

860-
861876
// extractMessageAndToolsFromResponse extracts the final OpenAI message and tool calls from the completed response
862-
func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[string]*uctypes.UIMessageDataToolUse) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {
877+
func extractMessageAndToolsFromResponse(resp openaiResponse, state *openaiStreamingState) ([]*OpenAIChatMessage, []uctypes.WaveToolCall) {
863878
var messageContent []OpenAIMessageContent
864879
var toolCalls []uctypes.WaveToolCall
865880
var messages []*OpenAIChatMessage
@@ -893,7 +908,7 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
893908
}
894909

895910
// Attach UIToolUseData if available
896-
if data, ok := toolUseData[outputItem.CallId]; ok {
911+
if data, ok := state.toolUseData[outputItem.CallId]; ok {
897912
toolCall.ToolUseData = data
898913
} else {
899914
log.Printf("AI no data-tooluse for %s (callid: %s)\n", outputItem.Id, outputItem.CallId)
@@ -907,7 +922,7 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
907922
argsStr = outputItem.Arguments
908923
}
909924
var toolUseDataPtr *uctypes.UIMessageDataToolUse
910-
if data, ok := toolUseData[outputItem.CallId]; ok {
925+
if data, ok := state.toolUseData[outputItem.CallId]; ok {
911926
toolUseDataPtr = data
912927
}
913928
functionCallMsg := &OpenAIChatMessage{
@@ -925,17 +940,20 @@ func extractMessageAndToolsFromResponse(resp openaiResponse, toolUseData map[str
925940
}
926941

927942
// Create OpenAIChatMessage with assistant message (first in slice)
928-
if resp.Usage != nil {
943+
usage := resp.Usage
944+
if usage != nil {
929945
resp.Usage.Model = resp.Model
946+
if state.webSearchCount > 0 {
947+
usage.NativeWebSearchCount = state.webSearchCount
948+
}
930949
}
931-
932950
assistantMessage := &OpenAIChatMessage{
933951
MessageId: uuid.New().String(),
934952
Message: &OpenAIMessage{
935953
Role: "assistant",
936954
Content: messageContent,
937955
},
938-
Usage: resp.Usage,
956+
Usage: usage,
939957
}
940958

941959
// Return assistant message first, followed by function call messages

pkg/aiusechat/openai/openai-convertmessage.go

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -75,11 +75,11 @@ type OpenAIRequest struct {
7575
}
7676

7777
type OpenAIRequestTool struct {
78-
Name string `json:"name"`
79-
Description string `json:"description,omitempty"`
80-
Parameters any `json:"parameters"`
81-
Strict bool `json:"strict"`
8278
Type string `json:"type"`
79+
Name string `json:"name,omitempty"`
80+
Description string `json:"description,omitempty"`
81+
Parameters any `json:"parameters,omitempty"`
82+
Strict bool `json:"strict,omitempty"`
8383
}
8484

8585
// ConvertToolDefinitionToOpenAI converts a generic ToolDefinition to OpenAI format
@@ -113,13 +113,13 @@ func debugPrintReq(req *OpenAIRequest, endpoint string) {
113113
// buildOpenAIHTTPRequest creates a complete HTTP request for the OpenAI API
114114
func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.WaveChatOpts, cont *uctypes.WaveContinueResponse) (*http.Request, error) {
115115
opts := chatOpts.Config
116-
116+
117117
// If continuing from premium rate limit, downgrade to default model and low thinking
118118
if cont != nil && cont.ContinueFromKind == uctypes.StopKindPremiumRateLimit {
119119
opts.Model = uctypes.DefaultOpenAIModel
120120
opts.ThinkingLevel = uctypes.ThinkingLevelLow
121121
}
122-
122+
123123
if opts.Model == "" {
124124
return nil, errors.New("opts.model is required")
125125
}
@@ -183,6 +183,14 @@ func buildOpenAIHTTPRequest(ctx context.Context, inputs []any, chatOpts uctypes.
183183
reqBody.Tools = append(reqBody.Tools, convertedTool)
184184
}
185185

186+
// Add native web search tool if enabled
187+
if chatOpts.AllowNativeWebSearch {
188+
webSearchTool := OpenAIRequestTool{
189+
Type: "web_search",
190+
}
191+
reqBody.Tools = append(reqBody.Tools, webSearchTool)
192+
}
193+
186194
// Set reasoning based on thinking level
187195
if opts.ThinkingLevel != "" {
188196
reqBody.Reasoning = &ReasoningType{

pkg/aiusechat/uctypes/usechat-types.go

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -222,10 +222,11 @@ type AIChat struct {
222222
}
223223

224224
type AIUsage struct {
225-
APIType string `json:"apitype"`
226-
Model string `json:"model"`
227-
InputTokens int `json:"inputtokens,omitempty"`
228-
OutputTokens int `json:"outputtokens,omitempty"`
225+
APIType string `json:"apitype"`
226+
Model string `json:"model"`
227+
InputTokens int `json:"inputtokens,omitempty"`
228+
OutputTokens int `json:"outputtokens,omitempty"`
229+
NativeWebSearchCount int `json:"nativewebsearchcount,omitempty"`
229230
}
230231

231232
type AIMetrics struct {
@@ -424,6 +425,7 @@ type WaveChatOpts struct {
424425
TabStateGenerator func() (string, []ToolDefinition, error)
425426
WidgetAccess bool
426427
RegisterToolApproval func(string)
428+
AllowNativeWebSearch bool
427429

428430
// emphemeral to the step
429431
TabState string

pkg/aiusechat/usechat.go

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -191,6 +191,7 @@ func getUsage(msgs []uctypes.GenAIMessage) uctypes.AIUsage {
191191
} else {
192192
rtn.InputTokens += usage.InputTokens
193193
rtn.OutputTokens += usage.OutputTokens
194+
rtn.NativeWebSearchCount += usage.NativeWebSearchCount
194195
}
195196
}
196197
}
@@ -369,9 +370,10 @@ func RunAIChat(ctx context.Context, sseHandler *sse.SSEHandlerCh, chatOpts uctyp
369370
}
370371
if len(rtnMessage) > 0 {
371372
usage := getUsage(rtnMessage)
372-
log.Printf("usage: input=%d output=%d\n", usage.InputTokens, usage.OutputTokens)
373+
log.Printf("usage: input=%d output=%d websearch=%d\n", usage.InputTokens, usage.OutputTokens, usage.NativeWebSearchCount)
373374
metrics.Usage.InputTokens += usage.InputTokens
374375
metrics.Usage.OutputTokens += usage.OutputTokens
376+
metrics.Usage.NativeWebSearchCount += usage.NativeWebSearchCount
375377
if usage.Model != "" && metrics.Usage.Model != usage.Model {
376378
metrics.Usage.Model = "mixed"
377379
}
@@ -526,24 +528,25 @@ func WaveAIPostMessageWrap(ctx context.Context, sseHandler *sse.SSEHandlerCh, me
526528

527529
func sendAIMetricsTelemetry(ctx context.Context, metrics *uctypes.AIMetrics) {
528530
event := telemetrydata.MakeTEvent("waveai:post", telemetrydata.TEventProps{
529-
WaveAIAPIType: metrics.Usage.APIType,
530-
WaveAIModel: metrics.Usage.Model,
531-
WaveAIInputTokens: metrics.Usage.InputTokens,
532-
WaveAIOutputTokens: metrics.Usage.OutputTokens,
533-
WaveAIRequestCount: metrics.RequestCount,
534-
WaveAIToolUseCount: metrics.ToolUseCount,
535-
WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
536-
WaveAIToolDetail: metrics.ToolDetail,
537-
WaveAIPremiumReq: metrics.PremiumReqCount,
538-
WaveAIProxyReq: metrics.ProxyReqCount,
539-
WaveAIHadError: metrics.HadError,
540-
WaveAIImageCount: metrics.ImageCount,
541-
WaveAIPDFCount: metrics.PDFCount,
542-
WaveAITextDocCount: metrics.TextDocCount,
543-
WaveAITextLen: metrics.TextLen,
544-
WaveAIFirstByteMs: metrics.FirstByteLatency,
545-
WaveAIRequestDurMs: metrics.RequestDuration,
546-
WaveAIWidgetAccess: metrics.WidgetAccess,
531+
WaveAIAPIType: metrics.Usage.APIType,
532+
WaveAIModel: metrics.Usage.Model,
533+
WaveAIInputTokens: metrics.Usage.InputTokens,
534+
WaveAIOutputTokens: metrics.Usage.OutputTokens,
535+
WaveAINativeWebSearchCount: metrics.Usage.NativeWebSearchCount,
536+
WaveAIRequestCount: metrics.RequestCount,
537+
WaveAIToolUseCount: metrics.ToolUseCount,
538+
WaveAIToolUseErrorCount: metrics.ToolUseErrorCount,
539+
WaveAIToolDetail: metrics.ToolDetail,
540+
WaveAIPremiumReq: metrics.PremiumReqCount,
541+
WaveAIProxyReq: metrics.ProxyReqCount,
542+
WaveAIHadError: metrics.HadError,
543+
WaveAIImageCount: metrics.ImageCount,
544+
WaveAIPDFCount: metrics.PDFCount,
545+
WaveAITextDocCount: metrics.TextDocCount,
546+
WaveAITextLen: metrics.TextLen,
547+
WaveAIFirstByteMs: metrics.FirstByteLatency,
548+
WaveAIRequestDurMs: metrics.RequestDuration,
549+
WaveAIWidgetAccess: metrics.WidgetAccess,
547550
})
548551
_ = telemetry.RecordTEvent(ctx, event)
549552
}
@@ -602,6 +605,7 @@ func WaveAIPostMessageHandler(w http.ResponseWriter, r *http.Request) {
602605
Config: *aiOpts,
603606
WidgetAccess: req.WidgetAccess,
604607
RegisterToolApproval: RegisterToolApproval,
608+
AllowNativeWebSearch: true,
605609
}
606610
if chatOpts.Config.APIType == APIType_OpenAI {
607611
chatOpts.SystemPrompt = []string{SystemPromptText_OpenAI}

pkg/telemetry/telemetrydata/telemetrydata.go

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -101,24 +101,25 @@ type TEventProps struct {
101101
CountWSLConn int `json:"count:wslconn,omitempty"`
102102
CountViews map[string]int `json:"count:views,omitempty"`
103103

104-
WaveAIAPIType string `json:"waveai:apitype,omitempty"`
105-
WaveAIModel string `json:"waveai:model,omitempty"`
106-
WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"`
107-
WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"`
108-
WaveAIRequestCount int `json:"waveai:requestcount,omitempty"`
109-
WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"`
110-
WaveAIToolUseErrorCount int `json:"waveai:tooluseerrorcount,omitempty"`
111-
WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"`
112-
WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"`
113-
WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"`
114-
WaveAIHadError bool `json:"waveai:haderror,omitempty"`
115-
WaveAIImageCount int `json:"waveai:imagecount,omitempty"`
116-
WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"`
117-
WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"`
118-
WaveAITextLen int `json:"waveai:textlen,omitempty"`
119-
WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms
120-
WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms
121-
WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"`
104+
WaveAIAPIType string `json:"waveai:apitype,omitempty"`
105+
WaveAIModel string `json:"waveai:model,omitempty"`
106+
WaveAIInputTokens int `json:"waveai:inputtokens,omitempty"`
107+
WaveAIOutputTokens int `json:"waveai:outputtokens,omitempty"`
108+
WaveAINativeWebSearchCount int `json:"waveai:nativewebsearchcount,omitempty"`
109+
WaveAIRequestCount int `json:"waveai:requestcount,omitempty"`
110+
WaveAIToolUseCount int `json:"waveai:toolusecount,omitempty"`
111+
WaveAIToolUseErrorCount int `json:"waveai:tooluseerrorcount,omitempty"`
112+
WaveAIToolDetail map[string]int `json:"waveai:tooldetail,omitempty"`
113+
WaveAIPremiumReq int `json:"waveai:premiumreq,omitempty"`
114+
WaveAIProxyReq int `json:"waveai:proxyreq,omitempty"`
115+
WaveAIHadError bool `json:"waveai:haderror,omitempty"`
116+
WaveAIImageCount int `json:"waveai:imagecount,omitempty"`
117+
WaveAIPDFCount int `json:"waveai:pdfcount,omitempty"`
118+
WaveAITextDocCount int `json:"waveai:textdoccount,omitempty"`
119+
WaveAITextLen int `json:"waveai:textlen,omitempty"`
120+
WaveAIFirstByteMs int `json:"waveai:firstbytems,omitempty"` // ms
121+
WaveAIRequestDurMs int `json:"waveai:requestdurms,omitempty"` // ms
122+
WaveAIWidgetAccess bool `json:"waveai:widgetaccess,omitempty"`
122123

123124
UserSet *TEventUserProps `json:"$set,omitempty"`
124125
UserSetOnce *TEventUserProps `json:"$set_once,omitempty"`

0 commit comments

Comments
 (0)