Skip to content

Commit faaff9f

Browse files
committed
2 parents bc13d78 + d114462 commit faaff9f

File tree

4 files changed

+251
-8
lines changed

4 files changed

+251
-8
lines changed

src/praisonai-agents/praisonaiagents/llm/llm.py

Lines changed: 75 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22
import os
33
import warnings
44
import re
5+
import inspect
56
from typing import Any, Dict, List, Optional, Union, Literal, Callable
67
from pydantic import BaseModel
78
import time
@@ -16,7 +17,6 @@
1617
ReflectionOutput,
1718
execute_sync_callback,
1819
)
19-
from .model_capabilities import is_gemini_internal_tool
2020
from rich.console import Console
2121
from rich.live import Live
2222

@@ -380,6 +380,65 @@ def _parse_tool_call_arguments(self, tool_call: Dict, is_ollama: bool = False) -
380380

381381
return function_name, arguments, tool_call_id
382382

383+
def _validate_and_filter_ollama_arguments(self, function_name: str, arguments: Dict[str, Any], available_tools: List) -> Dict[str, Any]:
384+
"""
385+
Validate and filter tool call arguments for Ollama provider.
386+
387+
Ollama sometimes generates tool calls with mixed parameters where arguments
388+
from different functions are combined. This method validates arguments against
389+
the actual function signature and removes invalid parameters.
390+
391+
Args:
392+
function_name: Name of the function to call
393+
arguments: Arguments provided in the tool call
394+
available_tools: List of available tool functions
395+
396+
Returns:
397+
Filtered arguments dictionary with only valid parameters
398+
"""
399+
if not available_tools:
400+
logging.debug(f"[OLLAMA_FIX] No available tools provided for validation")
401+
return arguments
402+
403+
# Find the target function
404+
target_function = None
405+
for tool in available_tools:
406+
tool_name = getattr(tool, '__name__', str(tool))
407+
if tool_name == function_name:
408+
target_function = tool
409+
break
410+
411+
if not target_function:
412+
logging.debug(f"[OLLAMA_FIX] Function {function_name} not found in available tools")
413+
return arguments
414+
415+
try:
416+
# Get function signature
417+
sig = inspect.signature(target_function)
418+
valid_params = set(sig.parameters.keys())
419+
420+
# Filter arguments to only include valid parameters
421+
filtered_args = {}
422+
invalid_params = []
423+
424+
for param_name, param_value in arguments.items():
425+
if param_name in valid_params:
426+
filtered_args[param_name] = param_value
427+
else:
428+
invalid_params.append(param_name)
429+
430+
if invalid_params:
431+
logging.debug(f"[OLLAMA_FIX] Function {function_name} received invalid parameters: {invalid_params}")
432+
logging.debug(f"[OLLAMA_FIX] Valid parameters for {function_name}: {list(valid_params)}")
433+
logging.debug(f"[OLLAMA_FIX] Original arguments: {arguments}")
434+
logging.debug(f"[OLLAMA_FIX] Filtered arguments: {filtered_args}")
435+
436+
return filtered_args
437+
438+
except Exception as e:
439+
logging.debug(f"[OLLAMA_FIX] Error validating arguments for {function_name}: {e}")
440+
return arguments
441+
383442
def _needs_system_message_skip(self) -> bool:
384443
"""Check if this model requires skipping system messages"""
385444
if not self.model:
@@ -591,10 +650,14 @@ def _format_tools_for_litellm(self, tools: Optional[List[Any]]) -> Optional[List
591650
if tool_def:
592651
formatted_tools.append(tool_def)
593652
# Handle Gemini internal tools (e.g., {"googleSearch": {}}, {"urlContext": {}}, {"codeExecution": {}})
594-
elif is_gemini_internal_tool(tool):
653+
elif isinstance(tool, dict) and len(tool) == 1:
595654
tool_name = next(iter(tool.keys()))
596-
logging.debug(f"Using Gemini internal tool: {tool_name}")
597-
formatted_tools.append(tool)
655+
gemini_internal_tools = {'googleSearch', 'urlContext', 'codeExecution'}
656+
if tool_name in gemini_internal_tools:
657+
logging.debug(f"Using Gemini internal tool: {tool_name}")
658+
formatted_tools.append(tool)
659+
else:
660+
logging.debug(f"Skipping unknown tool: {tool_name}")
598661
else:
599662
logging.debug(f"Skipping tool of unsupported type: {type(tool)}")
600663

@@ -959,6 +1022,10 @@ def get_response(
9591022
is_ollama = self._is_ollama_provider()
9601023
function_name, arguments, tool_call_id = self._extract_tool_call_info(tool_call, is_ollama)
9611024

1025+
# Validate and filter arguments for Ollama provider
1026+
if is_ollama and tools:
1027+
arguments = self._validate_and_filter_ollama_arguments(function_name, arguments, tools)
1028+
9621029
logging.debug(f"[TOOL_EXEC_DEBUG] About to execute tool {function_name} with args: {arguments}")
9631030
tool_result = execute_tool_fn(function_name, arguments)
9641031
logging.debug(f"[TOOL_EXEC_DEBUG] Tool execution result: {tool_result}")
@@ -1610,6 +1677,10 @@ async def get_response_async(
16101677
is_ollama = self._is_ollama_provider()
16111678
function_name, arguments, tool_call_id = self._extract_tool_call_info(tool_call, is_ollama)
16121679

1680+
# Validate and filter arguments for Ollama provider
1681+
if is_ollama and tools:
1682+
arguments = self._validate_and_filter_ollama_arguments(function_name, arguments, tools)
1683+
16131684
tool_result = await execute_tool_fn(function_name, arguments)
16141685
tool_results.append(tool_result) # Store the result
16151686

src/praisonai-agents/praisonaiagents/llm/openai_client.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,6 @@
1818
from rich.console import Console
1919
from rich.live import Live
2020
import inspect
21-
from .model_capabilities import is_gemini_internal_tool
2221

2322
# Constants
2423
LOCAL_SERVER_API_KEY_PLACEHOLDER = "not-needed"
@@ -407,10 +406,14 @@ def format_tools(self, tools: Optional[List[Any]]) -> Optional[List[Dict]]:
407406
if tool_def:
408407
formatted_tools.append(tool_def)
409408
# Handle Gemini internal tools (e.g., {"googleSearch": {}}, {"urlContext": {}}, {"codeExecution": {}})
410-
elif is_gemini_internal_tool(tool):
409+
elif isinstance(tool, dict) and len(tool) == 1:
411410
tool_name = next(iter(tool.keys()))
412-
logging.debug(f"Using Gemini internal tool: {tool_name}")
413-
formatted_tools.append(tool)
411+
gemini_internal_tools = {'googleSearch', 'urlContext', 'codeExecution'}
412+
if tool_name in gemini_internal_tools:
413+
logging.debug(f"Using Gemini internal tool: {tool_name}")
414+
formatted_tools.append(tool)
415+
else:
416+
logging.debug(f"Skipping unknown tool: {tool_name}")
414417
else:
415418
logging.debug(f"Skipping tool of unsupported type: {type(tool)}")
416419

Lines changed: 129 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,129 @@
1+
#!/usr/bin/env python3
2+
"""
3+
Test script to verify Ollama sequential tool calling argument mixing fix.
4+
5+
This test validates that the parameter validation and filtering fix correctly handles
6+
the case where Ollama generates tool calls with mixed parameters from different functions.
7+
"""
8+
9+
import logging
10+
from praisonaiagents.llm.llm import LLM
11+
12+
# Enable debug logging
13+
logging.basicConfig(level=logging.DEBUG)
14+
15+
# Test functions matching the issue description
16+
def get_stock_price(company_name: str) -> str:
    """
    Get the stock price of a company

    Args:
        company_name (str): The name of the company

    Returns:
        str: The stock price of the company
    """
    # Dummy quote used by the tool-calling tests; price is always 100.
    return "The stock price of {} is 100".format(company_name)
27+
28+
def multiply(a: int, b: int) -> int:
    """
    Multiply two numbers
    """
    product = a * b
    return product
33+
34+
def test_ollama_argument_validation():
    """
    Test the Ollama argument validation and filtering functionality.
    """
    print("Testing Ollama argument validation and filtering...")

    llm = LLM(model="ollama/llama3.2")
    tools = [get_stock_price, multiply]

    # Case 1: every argument matches the signature, so nothing is removed.
    print("\n1. Testing valid arguments:")
    original = {"a": 100, "b": 2}
    result = llm._validate_and_filter_ollama_arguments("multiply", original, tools)
    print("Original: {}".format(original))
    print("Filtered: {}".format(result))
    assert result == original, "Valid arguments should pass through unchanged"
    print("✅ Valid arguments test passed")

    # Case 2: parameters from two different tools mixed together (issue #918).
    print("\n2. Testing mixed arguments (the main issue):")
    original = {"a": "get_stock_price", "company_name": "Google", "b": "2"}
    expected = {"a": "get_stock_price", "b": "2"}  # Should remove 'company_name'
    result = llm._validate_and_filter_ollama_arguments("multiply", original, tools)
    print("Original: {}".format(original))
    print("Filtered: {}".format(result))
    print("Expected: {}".format(expected))
    assert result == expected, f"Expected {expected}, got {result}"
    print("✅ Mixed arguments filtering test passed")

    # Case 3: nothing matches the signature, so everything is filtered out.
    print("\n3. Testing all invalid arguments:")
    original = {"invalid_param1": "value1", "invalid_param2": "value2"}
    result = llm._validate_and_filter_ollama_arguments("multiply", original, tools)
    print("Original: {}".format(original))
    print("Filtered: {}".format(result))
    assert result == {}, "All invalid arguments should be filtered out"
    print("✅ Invalid arguments filtering test passed")

    # Case 4: unknown function name — arguments are returned untouched.
    print("\n4. Testing function not found:")
    original = {"param": "value"}
    result = llm._validate_and_filter_ollama_arguments("nonexistent_function", original, tools)
    print("Original: {}".format(original))
    print("Filtered: {}".format(result))
    assert result == original, "Arguments should pass through if function not found"
    print("✅ Function not found test passed")

    # Case 5: no tools available — arguments are returned untouched.
    print("\n5. Testing empty tools list:")
    original = {"param": "value"}
    result = llm._validate_and_filter_ollama_arguments("multiply", original, [])
    print("Original: {}".format(original))
    print("Filtered: {}".format(result))
    assert result == original, "Arguments should pass through if no tools provided"
    print("✅ Empty tools test passed")

    print("\n🎉 All Ollama argument validation tests passed!")
    return True
94+
def test_provider_detection():
    """
    Test the Ollama provider detection functionality.
    """
    print("\nTesting Ollama provider detection...")

    # A model name with the "ollama/" prefix must be routed to Ollama.
    assert LLM(model="ollama/llama3.2")._is_ollama_provider(), "Should detect ollama/ prefix"
    print("✅ Ollama prefix detection works")

    # A plain OpenAI model name must not be mistaken for Ollama.
    assert not LLM(model="gpt-4o-mini")._is_ollama_provider(), "Should not detect OpenAI as Ollama"
    print("✅ Non-Ollama provider detection works")

    print("✅ Provider detection tests passed!")
    return True
112+
113+
if __name__ == "__main__":
    banner = "=" * 60
    print("Running Ollama sequential tool calling fix tests...")
    print(banner)

    try:
        # Run both suites; any AssertionError is reported below.
        test_provider_detection()
        test_ollama_argument_validation()

        print("\n" + banner)
        print("🎉 ALL TESTS PASSED!")
        print("The Ollama sequential tool calling argument mixing issue has been fixed!")

    except Exception as e:
        import traceback
        print(f"\n❌ TEST FAILED: {e}")
        traceback.print_exc()

test_gemini_tools.py

Lines changed: 40 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,40 @@
1+
#!/usr/bin/env python3
"""Test script to verify the refactored Gemini tools logic."""

import logging
import sys
import os

# Setup logging to see debug messages
logging.basicConfig(level=logging.DEBUG, format='%(levelname)s: %(message)s')

# Make the in-repo package importable without installation.
sys.path.insert(0, os.path.join(os.path.dirname(__file__), 'src', 'praisonai-agents'))

try:
    from praisonaiagents.llm.llm import LLM
    from praisonaiagents.llm.openai_client import OpenAIClient

    # Three valid Gemini internal tools plus one unknown that must be skipped.
    gemini_tools = [
        {'googleSearch': {}},   # Valid Gemini tool
        {'urlContext': {}},     # Valid Gemini tool
        {'codeExecution': {}},  # Valid Gemini tool
        {'unknown': {}}         # Invalid tool - should be skipped
    ]

    print("Testing LLM tool formatting...")
    litellm_result = LLM(model="gpt-4o-mini")._format_tools_for_litellm(gemini_tools)
    print(f"LLM formatted tools ({len(litellm_result)} tools):", litellm_result)

    print("\nTesting OpenAI client tool formatting...")
    openai_result = OpenAIClient(api_key="not-needed").format_tools(gemini_tools)
    print(f"OpenAI client formatted tools ({len(openai_result)} tools):", openai_result)

    print("\nTest completed successfully!")

except Exception as e:
    import traceback
    print(f"Error: {e}")
    traceback.print_exc()

0 commit comments

Comments
 (0)