|
2 | 2 | import os |
3 | 3 | import warnings |
4 | 4 | import re |
| 5 | +import inspect |
5 | 6 | from typing import Any, Dict, List, Optional, Union, Literal, Callable |
6 | 7 | from pydantic import BaseModel |
7 | 8 | import time |
|
16 | 17 | ReflectionOutput, |
17 | 18 | execute_sync_callback, |
18 | 19 | ) |
19 | | -from .model_capabilities import is_gemini_internal_tool |
20 | 20 | from rich.console import Console |
21 | 21 | from rich.live import Live |
22 | 22 |
|
@@ -380,6 +380,65 @@ def _parse_tool_call_arguments(self, tool_call: Dict, is_ollama: bool = False) - |
380 | 380 |
|
381 | 381 | return function_name, arguments, tool_call_id |
382 | 382 |
|
| 383 | + def _validate_and_filter_ollama_arguments(self, function_name: str, arguments: Dict[str, Any], available_tools: List) -> Dict[str, Any]: |
| 384 | + """ |
| 385 | + Validate and filter tool call arguments for Ollama provider. |
| 386 | + |
| 387 | + Ollama sometimes generates tool calls with mixed parameters where arguments |
| 388 | + from different functions are combined. This method validates arguments against |
| 389 | + the actual function signature and removes invalid parameters. |
| 390 | + |
| 391 | + Args: |
| 392 | + function_name: Name of the function to call |
| 393 | + arguments: Arguments provided in the tool call |
| 394 | + available_tools: List of available tool functions |
| 395 | + |
| 396 | + Returns: |
| 397 | + Filtered arguments dictionary with only valid parameters |
| 398 | + """ |
| 399 | + if not available_tools: |
| 400 | + logging.debug(f"[OLLAMA_FIX] No available tools provided for validation") |
| 401 | + return arguments |
| 402 | + |
| 403 | + # Find the target function |
| 404 | + target_function = None |
| 405 | + for tool in available_tools: |
| 406 | + tool_name = getattr(tool, '__name__', str(tool)) |
| 407 | + if tool_name == function_name: |
| 408 | + target_function = tool |
| 409 | + break |
| 410 | + |
| 411 | + if not target_function: |
| 412 | + logging.debug(f"[OLLAMA_FIX] Function {function_name} not found in available tools") |
| 413 | + return arguments |
| 414 | + |
| 415 | + try: |
| 416 | + # Get function signature |
| 417 | + sig = inspect.signature(target_function) |
| 418 | + valid_params = set(sig.parameters.keys()) |
| 419 | + |
| 420 | + # Filter arguments to only include valid parameters |
| 421 | + filtered_args = {} |
| 422 | + invalid_params = [] |
| 423 | + |
| 424 | + for param_name, param_value in arguments.items(): |
| 425 | + if param_name in valid_params: |
| 426 | + filtered_args[param_name] = param_value |
| 427 | + else: |
| 428 | + invalid_params.append(param_name) |
| 429 | + |
| 430 | + if invalid_params: |
| 431 | + logging.debug(f"[OLLAMA_FIX] Function {function_name} received invalid parameters: {invalid_params}") |
| 432 | + logging.debug(f"[OLLAMA_FIX] Valid parameters for {function_name}: {list(valid_params)}") |
| 433 | + logging.debug(f"[OLLAMA_FIX] Original arguments: {arguments}") |
| 434 | + logging.debug(f"[OLLAMA_FIX] Filtered arguments: {filtered_args}") |
| 435 | + |
| 436 | + return filtered_args |
| 437 | + |
| 438 | + except Exception as e: |
| 439 | + logging.debug(f"[OLLAMA_FIX] Error validating arguments for {function_name}: {e}") |
| 440 | + return arguments |
| 441 | + |
383 | 442 | def _needs_system_message_skip(self) -> bool: |
384 | 443 | """Check if this model requires skipping system messages""" |
385 | 444 | if not self.model: |
@@ -591,10 +650,14 @@ def _format_tools_for_litellm(self, tools: Optional[List[Any]]) -> Optional[List |
591 | 650 | if tool_def: |
592 | 651 | formatted_tools.append(tool_def) |
593 | 652 | # Handle Gemini internal tools (e.g., {"googleSearch": {}}, {"urlContext": {}}, {"codeExecution": {}}) |
594 | | - elif is_gemini_internal_tool(tool): |
| 653 | + elif isinstance(tool, dict) and len(tool) == 1: |
595 | 654 | tool_name = next(iter(tool.keys())) |
596 | | - logging.debug(f"Using Gemini internal tool: {tool_name}") |
597 | | - formatted_tools.append(tool) |
| 655 | + gemini_internal_tools = {'googleSearch', 'urlContext', 'codeExecution'} |
| 656 | + if tool_name in gemini_internal_tools: |
| 657 | + logging.debug(f"Using Gemini internal tool: {tool_name}") |
| 658 | + formatted_tools.append(tool) |
| 659 | + else: |
| 660 | + logging.debug(f"Skipping unknown tool: {tool_name}") |
598 | 661 | else: |
599 | 662 | logging.debug(f"Skipping tool of unsupported type: {type(tool)}") |
600 | 663 |
|
@@ -959,6 +1022,10 @@ def get_response( |
959 | 1022 | is_ollama = self._is_ollama_provider() |
960 | 1023 | function_name, arguments, tool_call_id = self._extract_tool_call_info(tool_call, is_ollama) |
961 | 1024 |
|
| 1025 | + # Validate and filter arguments for Ollama provider |
| 1026 | + if is_ollama and tools: |
| 1027 | + arguments = self._validate_and_filter_ollama_arguments(function_name, arguments, tools) |
| 1028 | + |
962 | 1029 | logging.debug(f"[TOOL_EXEC_DEBUG] About to execute tool {function_name} with args: {arguments}") |
963 | 1030 | tool_result = execute_tool_fn(function_name, arguments) |
964 | 1031 | logging.debug(f"[TOOL_EXEC_DEBUG] Tool execution result: {tool_result}") |
@@ -1610,6 +1677,10 @@ async def get_response_async( |
1610 | 1677 | is_ollama = self._is_ollama_provider() |
1611 | 1678 | function_name, arguments, tool_call_id = self._extract_tool_call_info(tool_call, is_ollama) |
1612 | 1679 |
|
| 1680 | + # Validate and filter arguments for Ollama provider |
| 1681 | + if is_ollama and tools: |
| 1682 | + arguments = self._validate_and_filter_ollama_arguments(function_name, arguments, tools) |
| 1683 | + |
1613 | 1684 | tool_result = await execute_tool_fn(function_name, arguments) |
1614 | 1685 | tool_results.append(tool_result) # Store the result |
1615 | 1686 |
|
|
0 commit comments