 import {
   generateErrorResponse,
-  generateInvalidProviderResponseError
+  generateInvalidProviderResponseError,
+  getMimeType
 } from '../utils';
 import { GOOGLE } from '@/dev/data/models';
+import type { ToolCall, ToolChoice } from 'types/pipe';
 import type {
   ChatCompletionResponse,
   ContentType,
   ErrorResponse,
+  MessageRole,
   ModelParams,
   ProviderConfig,
   ProviderMessage
@@ -32,6 +35,76 @@ const transformGenerationConfig = (params: ModelParams) => {
   return generationConfig;
 };

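+// Shapes used to build the Gemini "contents" payload: each message carries a
+// Google role and a list of parts (plain text, a functionCall, or a functionResponse).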
+export type GoogleMessageRole = 'user' | 'model' | 'function';
+
+interface GoogleFunctionCallMessagePart {
+  functionCall: GoogleGenerateFunctionCall;
+}
+
+interface GoogleFunctionResponseMessagePart {
+  functionResponse: {
+    name: string;
+    response: {
+      name?: string;
+      content: string;
+    };
+  };
+}
+
+type GoogleMessagePart =
+  | GoogleFunctionCallMessagePart
+  | GoogleFunctionResponseMessagePart
+  | { text: string };
+
+export interface GoogleMessage {
+  role: GoogleMessageRole;
+  parts: GoogleMessagePart[];
+}
+
+export interface GoogleToolConfig {
+  function_calling_config: {
+    mode: GoogleToolChoiceType | undefined;
+    allowed_function_names?: string[];
+  };
+}
+
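+// Gemini only accepts the 'user', 'model', and 'function' roles, so OpenAI roles
+// are mapped as follows: assistant -> model, tool -> function, system -> user.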
+export const transformOpenAIRoleToGoogleRole = (
+  role: MessageRole
+): GoogleMessageRole => {
+  switch (role) {
+    case 'assistant':
+      return 'model';
+    case 'tool':
+      return 'function';
+    // Not all Gemini models support the system role
+    case 'system':
+      return 'user';
+    // 'user' is the default role
+    default:
+      return role;
+  }
+};
+
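+// Function-calling modes accepted by Gemini's tool_config.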
+type GoogleToolChoiceType = 'AUTO' | 'ANY' | 'NONE';
+
+export const transformToolChoiceForGemini = (
+  tool_choice: ToolChoice
+): GoogleToolChoiceType | undefined => {
+  if (typeof tool_choice === 'object' && tool_choice.type === 'function')
+    return 'ANY';
+  if (typeof tool_choice === 'string') {
+    switch (tool_choice) {
+      case 'auto':
+        return 'AUTO';
+      case 'none':
+        return 'NONE';
+      case 'required':
+        return 'ANY';
+    }
+  }
+  return undefined;
+};
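+// e.g. 'auto' -> 'AUTO', 'none' -> 'NONE', 'required' -> 'ANY'; a specific
+// { type: 'function', ... } tool_choice also maps to 'ANY'.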
+
 export const GoogleChatCompleteConfig: ProviderConfig = {
   model: {
     param: 'model',
@@ -42,36 +115,100 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
     param: 'contents',
     default: '',
     transform: (params: ModelParams) => {
-      const messages: { role: string; parts: { text: string }[] }[] = [];
+      const messages: GoogleMessage[] = [];
+      let lastRole: GoogleMessageRole | undefined;

       params.messages?.forEach((message: ProviderMessage) => {
-        const role = message.role === 'assistant' ? 'model' : 'user';
+        const role = transformOpenAIRoleToGoogleRole(message.role);
         let parts = [];
-        if (typeof message.content === 'string') {
+
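+        // Assistant tool calls become functionCall parts; OpenAI's stringified
+        // arguments are parsed back into an object for Gemini.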
+        if (message.role === 'assistant' && message.tool_calls) {
+          message.tool_calls.forEach((tool_call: ToolCall) => {
+            parts.push({
+              functionCall: {
+                name: tool_call.function.name,
+                args: JSON.parse(tool_call.function.arguments)
+              }
+            });
+          });
+        } else if (
+          message.role === 'tool' &&
+          typeof message.content === 'string'
+        ) {
           parts.push({
-            text: message.content
+            functionResponse: {
+              name: message.name ?? 'lb-random-tool-name',
+              response: {
+                content: message.content
+              }
+            }
           });
-        }
-
-        if (message.content && typeof message.content === 'object') {
+        } else if (
+          message.content &&
+          typeof message.content === 'object'
+        ) {
           message.content.forEach((c: ContentType) => {
             if (c.type === 'text') {
               parts.push({
                 text: c.text
               });
             }
             if (c.type === 'image_url') {
-              parts.push({
-                inlineData: {
-                  mimeType: 'image/jpeg',
-                  data: c.image_url?.url
-                }
-              });
+              const { url } = c.image_url || {};
+              if (!url) return;
+
+              // Handle different types of image URLs
+              if (url.startsWith('data:')) {
+                const [mimeTypeWithPrefix, base64Image] =
+                  url.split(';base64,');
+                const mimeType =
+                  mimeTypeWithPrefix.split(':')[1];
+
+                parts.push({
+                  inlineData: {
+                    mimeType: mimeType,
+                    data: base64Image
+                  }
+                });
+              } else if (
+                url.startsWith('gs://') ||
+                url.startsWith('https://') ||
+                url.startsWith('http://')
+              ) {
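+                // Remote http(s) and gs:// URLs are passed to Gemini by reference as fileData.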
+                parts.push({
+                  fileData: {
+                    mimeType: getMimeType(url),
+                    fileUri: url
+                  }
+                });
+              } else {
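+                // Fallback: pass the value through as inline data, assuming image/jpeg.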
+                parts.push({
+                  inlineData: {
+                    mimeType: 'image/jpeg',
+                    data: c.image_url?.url
+                  }
+                });
+              }
             }
           });
+        } else if (typeof message.content === 'string') {
+          parts.push({
+            text: message.content
+          });
         }

-        messages.push({ role, parts });
+        // Combine consecutive messages from the same role.
+        // This takes care of "Please ensure that multiturn requests alternate between user and model."
+        // It is also a possible fix for "Please ensure that function call turn comes immediately after a user turn or after a function response turn." with parallel tool calls.
+        const shouldCombineMessages =
+          lastRole === role && !params.model?.includes('vision');
+
+        if (shouldCombineMessages) {
+          messages[messages.length - 1].parts.push(...parts);
+        } else {
+          messages.push({ role, parts });
+        }
+        lastRole = role;
       });
       return messages;
     }
@@ -108,6 +245,36 @@ export const GoogleChatCompleteConfig: ProviderConfig = {
       });
       return [{ functionDeclarations }];
     }
+  },
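+  // OpenAI tool_choice is translated into Gemini's tool_config payload.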
+  tool_choice: {
+    param: 'tool_config',
+    default: '',
+    transform: (params: ModelParams) => {
+      if (params.tool_choice) {
+        const allowedFunctionNames: string[] = [];
+        // If tool_choice is an object with type 'function', add that function's name to allowedFunctionNames
+        if (
+          typeof params.tool_choice === 'object' &&
+          params.tool_choice.type === 'function'
+        ) {
+          allowedFunctionNames.push(params.tool_choice.function.name);
+        }
+        const toolConfig: GoogleToolConfig = {
+          function_calling_config: {
+            mode: transformToolChoiceForGemini(params.tool_choice)
+          }
+        };
+        // TODO: @msaaddev I don't think tool_choice can contain more than one function,
+        // but this also handles the case where it does.
+
+        // If tool_choice has functions, add the function names to allowedFunctionNames
+        if (allowedFunctionNames.length > 0) {
+          toolConfig.function_calling_config.allowed_function_names =
+            allowedFunctionNames;
+        }
+        return toolConfig;
+      }
+    }
   }
 };

@@ -146,6 +313,11 @@ interface GoogleGenerateContentResponse {
       probability: string;
     }[];
   };
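+  // Token counts reported by Gemini; surfaced as OpenAI-style usage below.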
+  usageMetadata: {
+    promptTokenCount: number;
+    candidatesTokenCount: number;
+    totalTokenCount: number;
+  };
 }

 export const GoogleChatCompleteResponseTransform: (
@@ -170,7 +342,6 @@ export const GoogleChatCompleteResponseTransform: (
       GOOGLE
     );
   }
-
   if ('candidates' in response) {
     return {
       id: crypto.randomUUID(),
@@ -179,7 +350,7 @@ export const GoogleChatCompleteResponseTransform: (
       model: 'Unknown',
       provider: GOOGLE,
       choices:
-        response.candidates?.map((generation, index) => {
+        response.candidates?.map(generation => {
           // In blocking mode, Google AI does not return content if the response exceeds the max output tokens param.
           // Test it by asking for a long response while keeping max tokens low (~50).
           if (
@@ -203,28 +374,34 @@ export const GoogleChatCompleteResponseTransform: (
           } else if (generation.content?.parts[0]?.functionCall) {
             message = {
               role: 'assistant',
-              tool_calls: [
-                {
-                  id: crypto.randomUUID(),
-                  type: 'function',
-                  function: {
-                    name: generation.content.parts[0]
-                      ?.functionCall.name,
-                    arguments: JSON.stringify(
-                      generation.content.parts[0]
-                        ?.functionCall.args
-                    )
-                  }
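+              // The assistant message carries no text content; each functionCall part becomes an OpenAI-style tool call.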
+              content: null,
+              tool_calls: generation.content.parts.map(part => {
+                if (part.functionCall) {
+                  return {
+                    id: crypto.randomUUID(),
+                    type: 'function',
+                    function: {
+                      name: part.functionCall.name,
+                      arguments: JSON.stringify(
+                        part.functionCall.args
+                      )
+                    }
+                  };
                 }
-              ]
+              }).filter(Boolean)
             };
           }
           return {
             message: message,
             index: generation.index,
             finish_reason: generation.finishReason
           };
-        }) ?? []
+        }) ?? [],
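+      // Surface Gemini's token counts in the OpenAI usage format.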
+      usage: {
+        prompt_tokens: response.usageMetadata.promptTokenCount,
+        completion_tokens: response.usageMetadata.candidatesTokenCount,
+        total_tokens: response.usageMetadata.totalTokenCount
+      }
     };
   }

@@ -262,7 +439,7 @@ export const GoogleChatCompleteStreamChunkTransform: (
       model: '',
       provider: 'google',
       choices:
-        parsedChunk.candidates?.map((generation, index) => {
+        parsedChunk.candidates?.map(generation => {
           let message: ProviderMessage = {
             role: 'assistant',
             content: ''
@@ -275,21 +452,23 @@ export const GoogleChatCompleteStreamChunkTransform: (
           } else if (generation.content.parts[0]?.functionCall) {
             message = {
               role: 'assistant',
-              tool_calls: [
-                {
-                  id: crypto.randomUUID(),
-                  type: 'function',
-                  index: 0,
-                  function: {
-                    name: generation.content.parts[0]
-                      ?.functionCall.name,
-                    arguments: JSON.stringify(
-                      generation.content.parts[0]
-                        ?.functionCall.args
-                    )
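+              // Stream one tool call per functionCall part, indexed by its position in parts.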
+              tool_calls: generation.content.parts.map(
+                (part, idx) => {
+                  if (part.functionCall) {
+                    return {
+                      index: idx,
+                      id: crypto.randomUUID(),
+                      type: 'function',
+                      function: {
+                        name: part.functionCall.name,
+                        arguments: JSON.stringify(
+                          part.functionCall.args
+                        )
+                      }
+                    };
                   }
                 }
-              ]
+              ).filter(Boolean)
             };
           }
           return {