Skip to content

Commit a86e945

Browse files
committed
feat: add gemini_tools field to GeminiModel with validation and tests
Add support for Gemini-specific tools like GoogleSearch and CodeExecution, with validation to prevent FunctionDeclarations and comprehensive test coverage.
1 parent 7cd10b9 commit a86e945

File tree

3 files changed

+145
-1
lines changed

3 files changed

+145
-1
lines changed

src/strands/models/gemini.py

Lines changed: 42 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,30 @@
2424
T = TypeVar("T", bound=pydantic.BaseModel)
2525

2626

27+
def _validate_gemini_tools(gemini_tools: list[genai.types.Tool]) -> None:
28+
"""Validate that gemini_tools does not contain FunctionDeclarations.
29+
30+
Gemini-specific tools should only include tools that cannot be represented
31+
as FunctionDeclarations (e.g., GoogleSearch, CodeExecution, ComputerUse).
32+
Standard function calling tools should use the tools interface instead.
33+
34+
Args:
35+
gemini_tools: List of Gemini tools to validate
36+
37+
Raises:
38+
ValueError: If any tool contains function_declarations
39+
"""
40+
for tool in gemini_tools:
41+
# Check if the tool has function_declarations attribute and it's not empty
42+
if hasattr(tool, "function_declarations") and tool.function_declarations:
43+
raise ValueError(
44+
"gemini_tools should not contain FunctionDeclarations. "
45+
"Use the standard tools interface for function calling tools. "
46+
"gemini_tools is reserved for Gemini-specific tools like "
47+
"GoogleSearch, CodeExecution, ComputerUse, UrlContext, and FileSearch."
48+
)
49+
50+
2751
class GeminiModel(Model):
2852
"""Google Gemini model provider implementation.
2953
@@ -40,10 +64,16 @@ class GeminiConfig(TypedDict, total=False):
4064
params: Additional model parameters (e.g., temperature).
4165
For a complete list of supported parameters, see
4266
https://ai.google.dev/api/generate-content#generationconfig.
67+
gemini_tools: Gemini-specific tools that are not FunctionDeclarations
68+
(e.g., GoogleSearch, CodeExecution, ComputerUse, UrlContext, FileSearch).
69+
Use the standard tools interface for function calling tools.
70+
For a complete list of supported tools, see
71+
https://ai.google.dev/api/caching#Tool
4372
"""
4473

4574
model_id: Required[str]
4675
params: dict[str, Any]
76+
gemini_tools: list[genai.types.Tool]
4777

4878
def __init__(
4979
self,
@@ -61,6 +91,10 @@ def __init__(
6191
validate_config_keys(model_config, GeminiModel.GeminiConfig)
6292
self.config = GeminiModel.GeminiConfig(**model_config)
6393

94+
# Validate gemini_tools if provided
95+
if "gemini_tools" in self.config:
96+
_validate_gemini_tools(self.config["gemini_tools"])
97+
6498
logger.debug("config=<%s> | initializing", self.config)
6599

66100
self.client_args = client_args or {}
@@ -72,6 +106,10 @@ def update_config(self, **model_config: Unpack[GeminiConfig]) -> None: # type:
72106
Args:
73107
**model_config: Configuration overrides.
74108
"""
109+
# Validate gemini_tools if provided
110+
if "gemini_tools" in model_config:
111+
_validate_gemini_tools(model_config["gemini_tools"])
112+
75113
self.config.update(model_config)
76114

77115
@override
@@ -181,7 +219,7 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge
181219
Return:
182220
Gemini tool list.
183221
"""
184-
return [
222+
tools = [
185223
genai.types.Tool(
186224
function_declarations=[
187225
genai.types.FunctionDeclaration(
@@ -193,6 +231,9 @@ def _format_request_tools(self, tool_specs: Optional[list[ToolSpec]]) -> list[ge
193231
],
194232
),
195233
]
234+
if self.config.get("gemini_tools"):
235+
tools.extend(self.config.get("gemini_tools", []))
236+
return tools
196237

197238
def _format_request_config(
198239
self,

tests/strands/models/test_gemini.py

Lines changed: 83 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -621,3 +621,86 @@ async def test_structured_output(gemini_client, model, messages, model_id, weath
621621
"model": model_id,
622622
}
623623
gemini_client.aio.models.generate_content.assert_called_with(**exp_request)
624+
625+
626+
def test_gemini_tools_validation_rejects_function_declarations(model_id):
627+
tool_with_function_declarations = genai.types.Tool(
628+
function_declarations=[
629+
genai.types.FunctionDeclaration(
630+
name="test_function",
631+
description="A test function",
632+
)
633+
]
634+
)
635+
636+
with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"):
637+
GeminiModel(model_id=model_id, gemini_tools=[tool_with_function_declarations])
638+
639+
640+
def test_gemini_tools_validation_allows_non_function_tools(model_id):
641+
tool_with_google_search = genai.types.Tool(google_search=genai.types.GoogleSearch())
642+
643+
model = GeminiModel(model_id=model_id, gemini_tools=[tool_with_google_search])
644+
assert "gemini_tools" in model.config
645+
646+
647+
def test_gemini_tools_validation_on_update_config(model):
648+
tool_with_function_declarations = genai.types.Tool(
649+
function_declarations=[
650+
genai.types.FunctionDeclaration(
651+
name="test_function",
652+
description="A test function",
653+
)
654+
]
655+
)
656+
657+
with pytest.raises(ValueError, match="gemini_tools should not contain FunctionDeclarations"):
658+
model.update_config(gemini_tools=[tool_with_function_declarations])
659+
660+
661+
@pytest.mark.asyncio
662+
async def test_stream_request_with_gemini_tools(gemini_client, messages, model_id):
663+
google_search_tool = genai.types.Tool(google_search=genai.types.GoogleSearch())
664+
model = GeminiModel(model_id=model_id, gemini_tools=[google_search_tool])
665+
666+
await anext(model.stream(messages))
667+
668+
exp_request = {
669+
"config": {
670+
"tools": [
671+
{"function_declarations": []},
672+
{"google_search": {}},
673+
]
674+
},
675+
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
676+
"model": model_id,
677+
}
678+
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)
679+
680+
681+
@pytest.mark.asyncio
682+
async def test_stream_request_with_gemini_tools_and_function_tools(gemini_client, messages, tool_spec, model_id):
683+
code_execution_tool = genai.types.Tool(code_execution=genai.types.ToolCodeExecution())
684+
model = GeminiModel(model_id=model_id, gemini_tools=[code_execution_tool])
685+
686+
await anext(model.stream(messages, tool_specs=[tool_spec]))
687+
688+
exp_request = {
689+
"config": {
690+
"tools": [
691+
{
692+
"function_declarations": [
693+
{
694+
"description": tool_spec["description"],
695+
"name": tool_spec["name"],
696+
"parameters_json_schema": tool_spec["inputSchema"]["json"],
697+
}
698+
]
699+
},
700+
{"code_execution": {}},
701+
]
702+
},
703+
"contents": [{"parts": [{"text": "test"}], "role": "user"}],
704+
"model": model_id,
705+
}
706+
gemini_client.aio.models.generate_content_stream.assert_called_with(**exp_request)

tests_integ/models/test_model_gemini.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import pydantic
44
import pytest
5+
from google import genai
56

67
import strands
78
from strands import Agent
@@ -21,6 +22,16 @@ def model():
2122
)
2223

2324

25+
@pytest.fixture
26+
def gemini_tool_model():
27+
return GeminiModel(
28+
client_args={"api_key": os.getenv("GOOGLE_API_KEY")},
29+
model_id="gemini-2.5-flash",
30+
params={"temperature": 0.15}, # Lower temperature for consistent test behavior
31+
gemini_tools=[genai.types.Tool(code_execution=genai.types.ToolCodeExecution())],
32+
)
33+
34+
2435
@pytest.fixture
2536
def tools():
2637
@strands.tool
@@ -175,3 +186,12 @@ def test_agent_structured_output_image_input(assistant_agent, yellow_img, yellow
175186
tru_color = assistant_agent.structured_output(type(yellow_color), content)
176187
exp_color = yellow_color
177188
assert tru_color == exp_color
189+
190+
191+
def test_agent_with_gemini_code_execution_tool(gemini_tool_model):
192+
# FIXME: Should verify tool usage history, but currently validates by solving a complex calculation
193+
system_prompt = "Execute calculations and output only the numerical result. No explanations or units needed."
194+
agent = Agent(model=gemini_tool_model, system_prompt=system_prompt)
195+
result = agent("Calculate 931567 * 81364")
196+
text = result.message.get("content", [{}])[0].get("text", "")
197+
assert "75796017388" in text

0 commit comments

Comments
 (0)