4040 Part ,
4141 ResponseValidationError ,
4242 Tool as VertexAITool ,
43+ ToolConfig ,
4344 )
4445except ImportError :
4546 GenerativeModel = None
@@ -137,20 +138,17 @@ def invoke(
137138 Returns:
138139 LLMResponse: The response from the LLM.
139140 """
140- system_message = [system_instruction ] if system_instruction is not None else []
141- self .model = GenerativeModel (
142- model_name = self .model_name ,
143- system_instruction = system_message ,
144- ** self .options ,
141+ model = self ._get_model (
142+ system_instruction = system_instruction ,
145143 )
146144 try :
147145 if isinstance (message_history , MessageHistory ):
148146 message_history = message_history .messages
149- messages = self .get_messages (input , message_history )
150- response = self . model .generate_content (messages , ** self . model_params )
151- return LLMResponse ( content = response . text )
147+ options = self ._get_call_params (input , message_history , tools = None )
148+ response = model .generate_content (** options )
149+ return self . _parse_content_response ( response )
152150 except ResponseValidationError as e :
153- raise LLMGenerationError (e )
151+ raise LLMGenerationError ("Error calling VertexAILLM" ) from e
154152
155153 async def ainvoke (
156154 self ,
@@ -172,65 +170,81 @@ async def ainvoke(
172170 try :
173171 if isinstance (message_history , MessageHistory ):
174172 message_history = message_history .messages
175- system_message = (
176- [system_instruction ] if system_instruction is not None else []
177- )
178- self .model = GenerativeModel (
179- model_name = self .model_name ,
180- system_instruction = system_message ,
181- ** self .options ,
173+ model = self ._get_model (
174+ system_instruction = system_instruction ,
182175 )
183- messages = self .get_messages (input , message_history )
184- response = await self .model .generate_content_async (
185- messages , ** self .model_params
186- )
187- return LLMResponse (content = response .text )
176+ options = self ._get_call_params (input , message_history , tools = None )
177+ response = await model .generate_content_async (** options )
178+ return self ._parse_content_response (response )
188179 except ResponseValidationError as e :
189- raise LLMGenerationError (e )
190-
191- def _to_vertexai_tool (self , tool : Tool ) -> VertexAITool :
192- return VertexAITool (
193- function_declarations = [
194- FunctionDeclaration (
195- name = tool .get_name (),
196- description = tool .get_description (),
197- parameters = tool .get_parameters (exclude = ["additional_properties" ]),
198- )
199- ]
180+ raise LLMGenerationError ("Error calling VertexAILLM" ) from e
181+
def _to_vertexai_function_declaration(self, tool: Tool) -> FunctionDeclaration:
    """Translate a Tool into a Vertex AI ``FunctionDeclaration``.

    Args:
        tool (Tool): The tool whose name, description and parameter schema
            are copied into the declaration.

    Returns:
        FunctionDeclaration: The Vertex AI representation of the tool.
    """
    # "additional_properties" is excluded because it is not part of the
    # parameter schema Vertex AI accepts for function declarations.
    declaration_kwargs = {
        "name": tool.get_name(),
        "description": tool.get_description(),
        "parameters": tool.get_parameters(exclude=["additional_properties"]),
    }
    return FunctionDeclaration(**declaration_kwargs)
201188
202189 def _get_llm_tools (
203190 self , tools : Optional [Sequence [Tool ]]
204191 ) -> Optional [list [VertexAITool ]]:
205192 if not tools :
206193 return None
207- return [self ._to_vertexai_tool (tool ) for tool in tools ]
194+ return [
195+ VertexAITool (
196+ function_declarations = [
197+ self ._to_vertexai_function_declaration (tool ) for tool in tools
198+ ]
199+ )
200+ ]
208201
def _get_model(
    self,
    system_instruction: Optional[str] = None,
) -> GenerativeModel:
    """Build a ``GenerativeModel`` for ``self.model_name``.

    Args:
        system_instruction (Optional[str]): Optional system prompt. When
            omitted, the model is created with an empty instruction list.

    Returns:
        GenerativeModel: A freshly constructed model handle; call options
        are supplied later, at generation time.
    """
    if system_instruction is None:
        system_message: list[str] = []
    else:
        system_message = [system_instruction]
    return GenerativeModel(
        model_name=self.model_name,
        system_instruction=system_message,
    )
223212
213+ def _get_call_params (
214+ self ,
215+ input : str ,
216+ message_history : Optional [Union [List [LLMMessage ], MessageHistory ]],
217+ tools : Optional [Sequence [Tool ]],
218+ ) -> dict [str , Any ]:
219+ options = dict (self .options )
220+ if tools :
221+ # we want a tool back, remove generation_config if defined
222+ options .pop ("generation_config" , None )
223+ options ["tools" ] = self ._get_llm_tools (tools )
224+ if "tool_config" not in options :
225+ options ["tool_config" ] = ToolConfig (
226+ function_calling_config = ToolConfig .FunctionCallingConfig (
227+ mode = ToolConfig .FunctionCallingConfig .Mode .ANY ,
228+ )
229+ )
230+ else :
231+ # no tools, remove tool_config if defined
232+ options .pop ("tool_config" , None )
233+
234+ messages = self .get_messages (input , message_history )
235+ options ["contents" ] = messages
236+ return options
237+
224238 async def _acall_llm (
225239 self ,
226240 input : str ,
227241 message_history : Optional [Union [List [LLMMessage ], MessageHistory ]] = None ,
228242 system_instruction : Optional [str ] = None ,
229243 tools : Optional [Sequence [Tool ]] = None ,
230244 ) -> GenerationResponse :
231- model = self ._get_model (system_instruction = system_instruction , tools = tools )
232- messages = self .get_messages (input , message_history )
233- response = await model .generate_content_async (messages , ** self . model_params )
245+ model = self ._get_model (system_instruction = system_instruction )
246+ options = self ._get_call_params (input , message_history , tools )
247+ response = await model .generate_content_async (** options )
234248 return response
235249
236250 def _call_llm (
@@ -240,9 +254,9 @@ def _call_llm(
240254 system_instruction : Optional [str ] = None ,
241255 tools : Optional [Sequence [Tool ]] = None ,
242256 ) -> GenerationResponse :
243- model = self ._get_model (system_instruction = system_instruction , tools = tools )
244- messages = self .get_messages (input , message_history )
245- response = model .generate_content (messages , ** self . model_params )
257+ model = self ._get_model (system_instruction = system_instruction )
258+ options = self ._get_call_params (input , message_history , tools )
259+ response = model .generate_content (** options )
246260 return response
247261
248262 def _to_tool_call (self , function_call : FunctionCall ) -> ToolCall :
0 commit comments