Skip to content

Commit 240ef5b

Browse files
SOORAJTS2001Jacksunweiyyyu-google
authored
feat: Added support for enums as arguments for function tools (#3088)
* feat: Added support for enums as arguments for function tools * feat: Add default value support for function tools fix: Add more test cases inside `test_build_function_declaration.py` for passing Enums as arguments * fix: format code with pyink --------- Co-authored-by: Wei Sun (Jack) <weisun@google.com> Co-authored-by: Yvonne Yu <150068659+yyyu-google@users.noreply.github.com>
1 parent b17c8f1 commit 240ef5b

File tree

2 files changed

+46
-1
lines changed

2 files changed

+46
-1
lines changed

src/google/adk/tools/_function_parameter_parse_util.py

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
from __future__ import annotations
1717

18+
from enum import Enum
1819
import inspect
1920
import logging
2021
import types as typing_types
@@ -75,7 +76,7 @@ def _raise_if_schema_unsupported(
7576
):
7677
if variant == GoogleLLMVariant.GEMINI_API:
7778
_raise_for_any_of_if_mldev(schema)
78-
_update_for_default_if_mldev(schema)
79+
# _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value
7980

8081

8182
def _is_default_value_compatible(
@@ -145,6 +146,20 @@ def _parse_schema_from_parameter(
145146
schema.type = _py_builtin_type_to_schema_type[param.annotation]
146147
_raise_if_schema_unsupported(variant, schema)
147148
return schema
149+
if isinstance(param.annotation, type) and issubclass(param.annotation, Enum):
150+
schema.type = types.Type.STRING
151+
schema.enum = [e.value for e in param.annotation]
152+
if param.default is not inspect.Parameter.empty:
153+
default_value = (
154+
param.default.value
155+
if isinstance(param.default, Enum)
156+
else param.default
157+
)
158+
if default_value not in schema.enum:
159+
raise ValueError(default_value_error_msg)
160+
schema.default = default_value
161+
_raise_if_schema_unsupported(variant, schema)
162+
return schema
148163
if (
149164
get_origin(param.annotation) is Union
150165
# only parse simple UnionType, example int | str | float | bool

tests/unittests/tools/test_build_function_declaration.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
from enum import Enum
1516
from typing import Dict
1617
from typing import List
1718

@@ -22,6 +23,7 @@
2223
# TODO: crewai requires python 3.10 as minimum
2324
# from crewai_tools import FileReadTool
2425
from pydantic import BaseModel
26+
import pytest
2527

2628

2729
def test_string_input():
@@ -220,6 +222,34 @@ def simple_function(
220222
assert function_decl.parameters.properties['input_dir'].items.type == 'OBJECT'
221223

222224

225+
def test_enums():
226+
227+
class InputEnum(Enum):
228+
AGENT = 'agent'
229+
TOOL = 'tool'
230+
231+
def simple_function(input: InputEnum = InputEnum.AGENT):
232+
return input.value
233+
234+
function_decl = _automatic_function_calling_util.build_function_declaration(
235+
func=simple_function
236+
)
237+
238+
assert function_decl.name == 'simple_function'
239+
assert function_decl.parameters.type == 'OBJECT'
240+
assert function_decl.parameters.properties['input'].type == 'STRING'
241+
assert function_decl.parameters.properties['input'].default == 'agent'
242+
assert function_decl.parameters.properties['input'].enum == ['agent', 'tool']
243+
244+
def simple_function_with_wrong_enum(input: InputEnum = 'WRONG_ENUM'):
245+
return input.value
246+
247+
with pytest.raises(ValueError):
248+
_automatic_function_calling_util.build_function_declaration(
249+
func=simple_function_with_wrong_enum
250+
)
251+
252+
223253
def test_basemodel_list():
224254
class ChildInput(BaseModel):
225255
input_str: str

0 commit comments

Comments
 (0)