Skip to content

Commit 0df6759

Browse files
Jacksunweicopybara-github
authored andcommitted
chore: Checks gemini version for 2 and above for gemini-builtin tools
PiperOrigin-RevId: 820848897
1 parent 8b3ed05 commit 0df6759

File tree

5 files changed

+76
-48
lines changed

5 files changed

+76
-48
lines changed

src/google/adk/code_executors/built_in_code_executor.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919

2020
from ..agents.invocation_context import InvocationContext
2121
from ..models import LlmRequest
22-
from ..utils.model_name_utils import is_gemini_2_model
22+
from ..utils.model_name_utils import is_gemini_2_or_above
2323
from .base_code_executor import BaseCodeExecutor
2424
from .code_execution_utils import CodeExecutionInput
2525
from .code_execution_utils import CodeExecutionResult
@@ -42,7 +42,7 @@ def execute_code(
4242

4343
def process_llm_request(self, llm_request: LlmRequest) -> None:
4444
"""Pre-process the LLM request for Gemini 2.0+ models to use the code execution tool."""
45-
if is_gemini_2_model(llm_request.model):
45+
if is_gemini_2_or_above(llm_request.model):
4646
llm_request.config = llm_request.config or types.GenerateContentConfig()
4747
llm_request.config.tools = llm_request.config.tools or []
4848
llm_request.config.tools.append(

src/google/adk/tools/retrieval/vertex_ai_rag_retrieval.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,7 @@
2424
from typing_extensions import override
2525
from vertexai.preview import rag
2626

27-
from ...utils.model_name_utils import is_gemini_2_model
27+
from ...utils.model_name_utils import is_gemini_2_or_above
2828
from ..tool_context import ToolContext
2929
from .base_retrieval_tool import BaseRetrievalTool
3030

@@ -63,7 +63,7 @@ async def process_llm_request(
6363
llm_request: LlmRequest,
6464
) -> None:
6565
# Use Gemini built-in Vertex AI RAG tool for Gemini 2 models.
66-
if is_gemini_2_model(llm_request.model):
66+
if is_gemini_2_or_above(llm_request.model):
6767
llm_request.config = (
6868
types.GenerateContentConfig()
6969
if not llm_request.config

src/google/adk/tools/url_context_tool.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
from typing_extensions import override
2121

2222
from ..utils.model_name_utils import is_gemini_1_model
23-
from ..utils.model_name_utils import is_gemini_2_model
23+
from ..utils.model_name_utils import is_gemini_2_or_above
2424
from .base_tool import BaseTool
2525
from .tool_context import ToolContext
2626

@@ -50,7 +50,7 @@ async def process_llm_request(
5050
llm_request.config.tools = llm_request.config.tools or []
5151
if is_gemini_1_model(llm_request.model):
5252
raise ValueError('Url context tool can not be used in Gemini 1.x.')
53-
elif is_gemini_2_model(llm_request.model):
53+
elif is_gemini_2_or_above(llm_request.model):
5454
llm_request.config.tools.append(
5555
types.Tool(url_context=types.UrlContext())
5656
)

src/google/adk/utils/model_name_utils.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,9 @@
1919
import re
2020
from typing import Optional
2121

22+
from packaging.version import InvalidVersion
23+
from packaging.version import Version
24+
2225

2326
def extract_model_name(model_string: str) -> str:
2427
"""Extract the actual model name from either simple or path-based format.
@@ -74,17 +77,29 @@ def is_gemini_1_model(model_string: Optional[str]) -> bool:
7477
return re.match(r'^gemini-1\.\d+', model_name) is not None
7578

7679

77-
def is_gemini_2_model(model_string: Optional[str]) -> bool:
78-
"""Check if the model is a Gemini 2.x model using regex patterns.
80+
def is_gemini_2_or_above(model_string: Optional[str]) -> bool:
81+
"""Check if the model is a Gemini 2.0 or newer model using semantic versions.
7982
8083
Args:
8184
model_string: Either a simple model name or path-based model name
8285
8386
Returns:
84-
True if it's a Gemini 2.x model, False otherwise
87+
True if it's a Gemini 2.0+ model, False otherwise
8588
"""
8689
if not model_string:
8790
return False
8891

8992
model_name = extract_model_name(model_string)
90-
return re.match(r'^gemini-2\.\d+', model_name) is not None
93+
if not model_name.startswith('gemini-'):
94+
return False
95+
96+
version_string = model_name[len('gemini-') :].split('-', 1)[0]
97+
if not version_string:
98+
return False
99+
100+
try:
101+
parsed_version = Version(version_string)
102+
except InvalidVersion:
103+
return False
104+
105+
return parsed_version.major >= 2

tests/unittests/utils/test_model_name_utils.py

Lines changed: 51 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from google.adk.utils.model_name_utils import extract_model_name
1818
from google.adk.utils.model_name_utils import is_gemini_1_model
19-
from google.adk.utils.model_name_utils import is_gemini_2_model
19+
from google.adk.utils.model_name_utils import is_gemini_2_or_above
2020
from google.adk.utils.model_name_utils import is_gemini_model
2121

2222

@@ -165,46 +165,51 @@ def test_is_gemini_1_model_edge_cases(self):
165165

166166

167167
class TestIsGemini2Model:
168-
"""Test the is_gemini_2_model function."""
169-
170-
def test_is_gemini_2_model_simple_names(self):
171-
"""Test Gemini 2.x model detection with simple model names."""
172-
assert is_gemini_2_model('gemini-2.0-flash') is True
173-
assert is_gemini_2_model('gemini-2.5-pro') is True
174-
assert is_gemini_2_model('gemini-2.0-flash-001') is True
175-
assert is_gemini_2_model('gemini-2.9-experimental') is True
176-
assert is_gemini_2_model('gemini-1.5-flash') is False
177-
assert is_gemini_2_model('gemini-1.0-pro') is False
178-
assert is_gemini_2_model('gemini-3.0-pro') is False # Only 2.x versions
179-
assert is_gemini_2_model('claude-3-sonnet') is False
180-
181-
def test_is_gemini_2_model_path_based_names(self):
182-
"""Test Gemini 2.x model detection with path-based model names."""
168+
"""Test the is_gemini_2_or_above function."""
169+
170+
def test_is_gemini_2_or_above_simple_names(self):
171+
"""Test Gemini 2.0+ model detection with simple model names."""
172+
assert is_gemini_2_or_above('gemini-2.0-flash') is True
173+
assert is_gemini_2_or_above('gemini-2.5-pro') is True
174+
assert is_gemini_2_or_above('gemini-2.0-flash-001') is True
175+
assert is_gemini_2_or_above('gemini-2.9-experimental') is True
176+
assert is_gemini_2_or_above('gemini-2-pro') is True
177+
assert is_gemini_2_or_above('gemini-2') is True
178+
assert is_gemini_2_or_above('gemini-3.0-pro') is True
179+
assert is_gemini_2_or_above('gemini-1.5-flash') is False
180+
assert is_gemini_2_or_above('gemini-1.0-pro') is False
181+
assert is_gemini_2_or_above('claude-3-sonnet') is False
182+
183+
def test_is_gemini_2_or_above_path_based_names(self):
184+
"""Test Gemini 2.0+ model detection with path-based model names."""
183185
gemini_2_path = 'projects/265104255505/locations/us-central1/publishers/google/models/gemini-2.0-flash-001'
184-
assert is_gemini_2_model(gemini_2_path) is True
186+
assert is_gemini_2_or_above(gemini_2_path) is True
185187

186188
gemini_2_path_2 = 'projects/12345/locations/us-east1/publishers/google/models/gemini-2.5-pro-preview'
187-
assert is_gemini_2_model(gemini_2_path_2) is True
189+
assert is_gemini_2_or_above(gemini_2_path_2) is True
188190

189191
gemini_1_path = 'projects/265104255505/locations/us-central1/publishers/google/models/gemini-1.5-flash-001'
190-
assert is_gemini_2_model(gemini_1_path) is False
192+
assert is_gemini_2_or_above(gemini_1_path) is False
191193

192-
def test_is_gemini_2_model_edge_cases(self):
193-
"""Test edge cases for Gemini 2.x model detection."""
194+
gemini_3_path = 'projects/12345/locations/us-east1/publishers/google/models/gemini-3.0-pro'
195+
assert is_gemini_2_or_above(gemini_3_path) is True
196+
197+
def test_is_gemini_2_or_above_edge_cases(self):
198+
"""Test edge cases for Gemini 2.0+ model detection."""
194199
# Test with None
195-
assert is_gemini_2_model(None) is False
200+
assert is_gemini_2_or_above(None) is False
196201

197202
# Test with empty string
198-
assert is_gemini_2_model('') is False
203+
assert is_gemini_2_or_above('') is False
199204

200205
# Test with model names containing gemini-2 but not starting with it
201-
assert is_gemini_2_model('my-gemini-2.5-model') is False
202-
assert is_gemini_2_model('custom-gemini-2.0-flash') is False
206+
assert is_gemini_2_or_above('my-gemini-2.5-model') is False
207+
assert is_gemini_2_or_above('custom-gemini-2.0-flash') is False
203208

204209
# Test with invalid versions
205-
assert is_gemini_2_model('gemini-2') is False # Missing dot
206-
assert is_gemini_2_model('gemini-2-pro') is False # Missing dot
207-
assert is_gemini_2_model('gemini-2.') is False # Missing version number
210+
assert is_gemini_2_or_above('gemini-2.') is False # Missing version number
211+
assert is_gemini_2_or_above('gemini-0.9-test') is False
212+
assert is_gemini_2_or_above('gemini-one') is False
208213

209214

210215
class TestModelNameUtilsIntegration:
@@ -216,32 +221,34 @@ def test_model_classification_consistency(self):
216221
'gemini-1.5-flash',
217222
'gemini-2.0-flash',
218223
'gemini-2.5-pro',
224+
'gemini-3.0-pro',
219225
'projects/123/locations/us-central1/publishers/google/models/gemini-1.5-pro',
220226
'projects/123/locations/us-central1/publishers/google/models/gemini-2.0-flash',
227+
'projects/123/locations/us-central1/publishers/google/models/gemini-3.0-pro',
221228
'claude-3-sonnet',
222229
'gpt-4',
223230
]
224231

225232
for model in test_models:
226-
# A model can only be either Gemini 1.x or Gemini 2.x, not both
233+
# A model can only be either Gemini 1.x or Gemini 2.0+, not both
227234
if is_gemini_1_model(model):
228-
assert not is_gemini_2_model(
235+
assert not is_gemini_2_or_above(
229236
model
230-
), f'Model {model} classified as both Gemini 1.x and 2.x'
237+
), f'Model {model} classified as both Gemini 1.x and 2.0+'
231238
assert is_gemini_model(
232239
model
233240
), f'Model {model} is Gemini 1.x but not classified as Gemini'
234241

235-
if is_gemini_2_model(model):
242+
if is_gemini_2_or_above(model):
236243
assert not is_gemini_1_model(
237244
model
238-
), f'Model {model} classified as both Gemini 1.x and 2.x'
245+
), f'Model {model} classified as both Gemini 1.x and 2.0+'
239246
assert is_gemini_model(
240247
model
241-
), f'Model {model} is Gemini 2.x but not classified as Gemini'
248+
), f'Model {model} is Gemini 2.0+ but not classified as Gemini'
242249

243-
# If it's neither Gemini 1.x nor 2.x, it should not be classified as Gemini
244-
if not is_gemini_1_model(model) and not is_gemini_2_model(model):
250+
# If it's neither Gemini 1.x nor 2.0+, it should not be classified as Gemini
251+
if not is_gemini_1_model(model) and not is_gemini_2_or_above(model):
245252
if model and 'gemini-' not in extract_model_name(model):
246253
assert not is_gemini_model(
247254
model
@@ -262,6 +269,10 @@ def test_path_vs_simple_model_consistency(self):
262269
'gemini-2.5-pro',
263270
'projects/123/locations/us-central1/publishers/google/models/gemini-2.5-pro',
264271
),
272+
(
273+
'gemini-3.0-pro',
274+
'projects/123/locations/us-central1/publishers/google/models/gemini-3.0-pro',
275+
),
265276
(
266277
'claude-3-sonnet',
267278
'projects/123/locations/us-central1/publishers/google/models/claude-3-sonnet',
@@ -278,7 +289,9 @@ def test_path_vs_simple_model_consistency(self):
278289
f'Inconsistent Gemini 1.x classification for {simple_model} vs'
279290
f' {path_model}'
280291
)
281-
assert is_gemini_2_model(simple_model) == is_gemini_2_model(path_model), (
282-
f'Inconsistent Gemini 2.x classification for {simple_model} vs'
292+
assert is_gemini_2_or_above(simple_model) == is_gemini_2_or_above(
293+
path_model
294+
), (
295+
f'Inconsistent Gemini 2.0+ classification for {simple_model} vs'
283296
f' {path_model}'
284297
)

0 commit comments

Comments
 (0)