From 123ae9954b66db1b51e96aa8c888ee3a3931ba3d Mon Sep 17 00:00:00 2001 From: t-miyak Date: Sat, 1 Nov 2025 17:31:57 +0900 Subject: [PATCH 1/4] feat: reflect pydantic field metadata (description and required) in function declarations --- .../tools/_function_parameter_parse_util.py | 14 +++- .../tools/test_build_function_declaration.py | 69 +++++++++++++++---- 2 files changed, 67 insertions(+), 16 deletions(-) diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index a0168fbe21..2402388d4e 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -289,7 +289,7 @@ def _parse_schema_from_parameter( schema.type = types.Type.OBJECT schema.properties = {} for field_name, field_info in param.annotation.model_fields.items(): - schema.properties[field_name] = _parse_schema_from_parameter( + field_schema = _parse_schema_from_parameter( variant, inspect.Parameter( field_name, @@ -298,6 +298,18 @@ def _parse_schema_from_parameter( ), func_name, ) + + if field_info.description: + field_schema.description = field_info.description + + schema.properties[field_name] = field_schema + + schema.required = [ + field_name + for field_name, field_info in param.annotation.model_fields.items() + if field_info.is_required() + ] + _raise_if_schema_unsupported(variant, schema) return schema if param.annotation is None: diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index edf3c7128e..af01a4504e 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -14,6 +14,7 @@ from typing import Dict from typing import List +from typing import Optional from google.adk.tools import _automatic_function_calling_util from google.adk.tools.tool_context import ToolContext @@ -22,6 +23,7 @@ # TODO: crewai requires python 3.10 as minimum # from crewai_tools import FileReadTool from pydantic import BaseModel +from pydantic import Field def test_string_input(): @@ -152,34 +154,71 @@ class SimpleFunction(BaseModel): def test_nested_basemodel_input(): - class ChildInput(BaseModel): - input_str: str - - class CustomInput(BaseModel): - child: ChildInput + """Test nested Pydantic models with and without Field annotations.""" - def simple_function(input: CustomInput) -> str: + class ChildInput(BaseModel): + name: str = Field(description='The name of the child') + age: int # No Field annotation + nickname: Optional[str] = Field( + default=None, description='Optional nickname' + ) + + class ParentInput(BaseModel): + title: str = Field(description='The title of the parent') + basic_field: str # No Field annotation + child: ChildInput = Field(description='Child information') + optional_field: Optional[str] = Field( + default='default_value', description='An optional field with default' + ) + + def simple_function(input: ParentInput) -> str: return {'result': input} function_decl = _automatic_function_calling_util.build_function_declaration( func=simple_function ) + # Check top-level structure assert function_decl.name == 'simple_function' assert function_decl.parameters.type == 'OBJECT' assert function_decl.parameters.properties['input'].type == 'OBJECT' + + # Check ParentInput properties with and without Field annotations + parent_props = function_decl.parameters.properties['input'].properties + assert parent_props['title'].type == 'STRING' + assert parent_props['title'].description == 'The title of the parent' + assert parent_props['basic_field'].type == 'STRING' + assert parent_props['basic_field'].description is None # No Field annotation + assert parent_props['child'].type == 'OBJECT' + assert parent_props['child'].description == 'Child information' + assert parent_props['optional_field'].type == 'STRING' assert ( - function_decl.parameters.properties['input'].properties['child'].type - == 'OBJECT' - ) - assert ( - function_decl.parameters.properties['input'] - .properties['child'] - .properties['input_str'] - .type - == 'STRING' + parent_props['optional_field'].description + == 'An optional field with default' ) + # Check ParentInput required fields + parent_required = function_decl.parameters.properties['input'].required + assert 'title' in parent_required + assert 'basic_field' in parent_required + assert 'child' in parent_required + assert 'optional_field' not in parent_required # Has default value + + # Check ChildInput properties with and without Field annotations + child_props = parent_props['child'].properties + assert child_props['name'].type == 'STRING' + assert child_props['name'].description == 'The name of the child' + assert child_props['age'].type == 'INTEGER' + assert child_props['age'].description is None # No Field annotation + assert child_props['nickname'].type == 'STRING' + assert child_props['nickname'].description == 'Optional nickname' + + # Check ChildInput required fields + child_required = parent_props['child'].required + assert 'name' in child_required + assert 'age' in child_required + assert 'nickname' not in child_required # Optional with default None + def test_basemodel_with_nested_basemodel(): class ChildInput(BaseModel): From 770959569d33d69d5414c0d133e50bd6277e54ca Mon Sep 17 00:00:00 2001 From: t-miyak Date: Sat, 1 Nov 2025 17:42:32 +0900 Subject: [PATCH 2/4] test: add test pattern of optional field without Field metadata --- tests/unittests/tools/test_build_function_declaration.py | 8 ++++++++ 1 file changed, 8 insertions(+) diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index 192340a16f..b3673d41b1 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -164,6 +164,7 @@ class ChildInput(BaseModel): nickname: Optional[str] = Field( default=None, description='Optional nickname' ) + email: Optional[str] = None # No Field annotation, Optional with default class ParentInput(BaseModel): title: str = Field(description='The title of the parent') @@ -172,6 +173,7 @@ class ParentInput(BaseModel): optional_field: Optional[str] = Field( default='default_value', description='An optional field with default' ) + status: Optional[str] = None # No Field annotation, Optional with default def simple_function(input: ParentInput) -> str: return {'result': input} @@ -198,6 +200,8 @@ def simple_function(input: ParentInput) -> str: parent_props['optional_field'].description == 'An optional field with default' ) + assert parent_props['status'].type == 'STRING' + assert parent_props['status'].description is None # No Field annotation # Check ParentInput required fields parent_required = function_decl.parameters.properties['input'].required @@ -205,6 +209,7 @@ def simple_function(input: ParentInput) -> str: assert 'basic_field' in parent_required assert 'child' in parent_required assert 'optional_field' not in parent_required # Has default value + assert 'status' not in parent_required # No Field annotation, Optional with default # Check ChildInput properties with and without Field annotations child_props = parent_props['child'].properties @@ -214,12 +219,15 @@ def simple_function(input: ParentInput) -> str: assert child_props['age'].description is None # No Field annotation assert child_props['nickname'].type == 'STRING' assert child_props['nickname'].description == 'Optional nickname' + assert child_props['email'].type == 'STRING' + assert child_props['email'].description is None # No Field annotation # Check ChildInput required fields child_required = parent_props['child'].required assert 'name' in child_required assert 'age' in child_required assert 'nickname' not in child_required # Optional with default None + assert 'email' not in child_required # No Field annotation, Optional with default def test_basemodel_with_nested_basemodel(): From 19db89b6856b635fb11dacbe2042b9c5c798632e Mon Sep 17 00:00:00 2001 From: t-miyak Date: Sat, 1 Nov 2025 19:49:56 +0900 Subject: [PATCH 3/4] perf: addressed review comment --- .../tools/_function_parameter_parse_util.py | 547 +++++++++--------- 1 file changed, 267 insertions(+), 280 deletions(-) diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index b15949a2f6..bd1c92ccc1 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -15,20 +15,16 @@ from __future__ import annotations -from enum import Enum import inspect import logging import types as typing_types -from typing import _GenericAlias -from typing import Any -from typing import get_args -from typing import get_origin -from typing import Literal -from typing import Union +from enum import Enum +from typing import Any, Literal, Union, _GenericAlias, get_args, get_origin -from google.genai import types import pydantic +from google.genai import types + from ..utils.variant_utils import GoogleLLMVariant _py_builtin_type_to_schema_type = { @@ -45,308 +41,299 @@ Any: None, } -logger = logging.getLogger('google_adk.' + __name__) +logger = logging.getLogger("google_adk." + __name__) def _is_builtin_primitive_or_compound( annotation: inspect.Parameter.annotation, ) -> bool: - return annotation in _py_builtin_type_to_schema_type.keys() + return annotation in _py_builtin_type_to_schema_type.keys() def _raise_for_any_of_if_mldev(schema: types.Schema): - if schema.any_of: - raise ValueError( - 'AnyOf is not supported in function declaration schema for Google AI.' - ) + if schema.any_of: + raise ValueError( + "AnyOf is not supported in function declaration schema for Google AI." + ) def _update_for_default_if_mldev(schema: types.Schema): - if schema.default is not None: - # TODO(kech): Remove this workaround once mldev supports default value. - schema.default = None - logger.warning( - 'Default value is not supported in function declaration schema for' - ' Google AI.' - ) + if schema.default is not None: + # TODO(kech): Remove this workaround once mldev supports default value. + schema.default = None + logger.warning( + "Default value is not supported in function declaration schema for" + " Google AI." + ) -def _raise_if_schema_unsupported( - variant: GoogleLLMVariant, schema: types.Schema -): - if variant == GoogleLLMVariant.GEMINI_API: - _raise_for_any_of_if_mldev(schema) - # _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value +def _raise_if_schema_unsupported(variant: GoogleLLMVariant, schema: types.Schema): + if variant == GoogleLLMVariant.GEMINI_API: + _raise_for_any_of_if_mldev(schema) + # _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value def _is_default_value_compatible( default_value: Any, annotation: inspect.Parameter.annotation ) -> bool: - # None type is expected to be handled external to this function - if _is_builtin_primitive_or_compound(annotation): - return isinstance(default_value, annotation) - - if ( - isinstance(annotation, _GenericAlias) - or isinstance(annotation, typing_types.GenericAlias) - or isinstance(annotation, typing_types.UnionType) - ): - origin = get_origin(annotation) - if origin in (Union, typing_types.UnionType): - return any( - _is_default_value_compatible(default_value, arg) - for arg in get_args(annotation) - ) - - if origin is dict: - return isinstance(default_value, dict) - - if origin is list: - if not isinstance(default_value, list): - return False - # most tricky case, element in list is union type - # need to apply any logic within all - # see test case test_generic_alias_complex_array_with_default_value - # a: typing.List[int | str | float | bool] - # default_value: [1, 'a', 1.1, True] - return all( - any( - _is_default_value_compatible(item, arg) - for arg in get_args(annotation) - ) - for item in default_value - ) - - if origin is Literal: - return default_value in get_args(annotation) - - # return False for any other unrecognized annotation - # let caller handle the raise - return False + # None type is expected to be handled external to this function + if _is_builtin_primitive_or_compound(annotation): + return isinstance(default_value, annotation) + + if ( + isinstance(annotation, _GenericAlias) + or isinstance(annotation, typing_types.GenericAlias) + or isinstance(annotation, typing_types.UnionType) + ): + origin = get_origin(annotation) + if origin in (Union, typing_types.UnionType): + return any( + _is_default_value_compatible(default_value, arg) + for arg in get_args(annotation) + ) + + if origin is dict: + return isinstance(default_value, dict) + + if origin is list: + if not isinstance(default_value, list): + return False + # most tricky case, element in list is union type + # need to apply any logic within all + # see test case test_generic_alias_complex_array_with_default_value + # a: typing.List[int | str | float | bool] + # default_value: [1, 'a', 1.1, True] + return all( + any( + _is_default_value_compatible(item, arg) + for arg in get_args(annotation) + ) + for item in default_value + ) + + if origin is Literal: + return default_value in get_args(annotation) + + # return False for any other unrecognized annotation + # let caller handle the raise + return False def _parse_schema_from_parameter( variant: GoogleLLMVariant, param: inspect.Parameter, func_name: str ) -> types.Schema: - """parse schema from parameter. - - from the simplest case to the most complex case. - """ - schema = types.Schema() - default_value_error_msg = ( - f'Default value {param.default} of parameter {param} of function' - f' {func_name} is not compatible with the parameter annotation' - f' {param.annotation}.' - ) - if _is_builtin_primitive_or_compound(param.annotation): - if param.default is not inspect.Parameter.empty: - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - schema.type = _py_builtin_type_to_schema_type[param.annotation] - _raise_if_schema_unsupported(variant, schema) - return schema - if isinstance(param.annotation, type) and issubclass(param.annotation, Enum): - schema.type = types.Type.STRING - schema.enum = [e.value for e in param.annotation] - if param.default is not inspect.Parameter.empty: - default_value = ( - param.default.value - if isinstance(param.default, Enum) - else param.default - ) - if default_value not in schema.enum: - raise ValueError(default_value_error_msg) - schema.default = default_value - _raise_if_schema_unsupported(variant, schema) - return schema - if ( - get_origin(param.annotation) is Union - # only parse simple UnionType, example int | str | float | bool - # complex types.UnionType will be invoked in raise branch - and all( - (_is_builtin_primitive_or_compound(arg) or arg is type(None)) - for arg in get_args(param.annotation) - ) - ): - schema.type = types.Type.OBJECT - schema.any_of = [] - unique_types = set() - for arg in get_args(param.annotation): - if arg.__name__ == 'NoneType': # Optional type - schema.nullable = True - continue - schema_in_any_of = _parse_schema_from_parameter( - variant, - inspect.Parameter( - 'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg - ), - func_name, - ) - if ( - schema_in_any_of.model_dump_json(exclude_none=True) - not in unique_types - ): - schema.any_of.append(schema_in_any_of) - unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True)) - if len(schema.any_of) == 1: # param: list | None -> Array - schema.type = schema.any_of[0].type - schema.any_of = None + """parse schema from parameter. + + from the simplest case to the most complex case. + """ + schema = types.Schema() + default_value_error_msg = ( + f"Default value {param.default} of parameter {param} of function" + f" {func_name} is not compatible with the parameter annotation" + f" {param.annotation}." + ) + if _is_builtin_primitive_or_compound(param.annotation): + if param.default is not inspect.Parameter.empty: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + schema.type = _py_builtin_type_to_schema_type[param.annotation] + _raise_if_schema_unsupported(variant, schema) + return schema + if isinstance(param.annotation, type) and issubclass(param.annotation, Enum): + schema.type = types.Type.STRING + schema.enum = [e.value for e in param.annotation] + if param.default is not inspect.Parameter.empty: + default_value = ( + param.default.value + if isinstance(param.default, Enum) + else param.default + ) + if default_value not in schema.enum: + raise ValueError(default_value_error_msg) + schema.default = default_value + _raise_if_schema_unsupported(variant, schema) + return schema if ( - param.default is not inspect.Parameter.empty - and param.default is not None - ): - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - if isinstance(param.annotation, _GenericAlias) or isinstance( - param.annotation, typing_types.GenericAlias - ): - origin = get_origin(param.annotation) - args = get_args(param.annotation) - if origin is dict: - schema.type = types.Type.OBJECT - if param.default is not inspect.Parameter.empty: - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - if origin is Literal: - if not all(isinstance(arg, str) for arg in args): - raise ValueError( - f'Literal type {param.annotation} must be a list of strings.' - ) - schema.type = types.Type.STRING - schema.enum = list(args) - if param.default is not inspect.Parameter.empty: - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - if origin is list: - schema.type = types.Type.ARRAY - schema.items = _parse_schema_from_parameter( - variant, - inspect.Parameter( - 'item', - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=args[0], - ), - func_name, - ) - if param.default is not inspect.Parameter.empty: - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - if origin is Union: - schema.any_of = [] - schema.type = types.Type.OBJECT - unique_types = set() - for arg in args: - if arg.__name__ == 'NoneType': # Optional type - schema.nullable = True - continue - schema_in_any_of = _parse_schema_from_parameter( - variant, - inspect.Parameter( - 'item', - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=arg, - ), - func_name, + get_origin(param.annotation) is Union + # only parse simple UnionType, example int | str | float | bool + # complex types.UnionType will be invoked in raise branch + and all( + (_is_builtin_primitive_or_compound(arg) or arg is type(None)) + for arg in get_args(param.annotation) ) - if ( - len(param.annotation.__args__) == 2 - and type(None) in param.annotation.__args__ - ): # Optional type - for optional_arg in param.annotation.__args__: + ): + schema.type = types.Type.OBJECT + schema.any_of = [] + unique_types = set() + for arg in get_args(param.annotation): + if arg.__name__ == "NoneType": # Optional type + schema.nullable = True + continue + schema_in_any_of = _parse_schema_from_parameter( + variant, + inspect.Parameter( + "item", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg + ), + func_name, + ) + if schema_in_any_of.model_dump_json(exclude_none=True) not in unique_types: + schema.any_of.append(schema_in_any_of) + unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True)) + if len(schema.any_of) == 1: # param: list | None -> Array + schema.type = schema.any_of[0].type + schema.any_of = None + if param.default is not inspect.Parameter.empty and param.default is not None: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + if isinstance(param.annotation, _GenericAlias) or isinstance( + param.annotation, typing_types.GenericAlias + ): + origin = get_origin(param.annotation) + args = get_args(param.annotation) + if origin is dict: + schema.type = types.Type.OBJECT + if param.default is not inspect.Parameter.empty: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + if origin is Literal: + if not all(isinstance(arg, str) for arg in args): + raise ValueError( + f"Literal type {param.annotation} must be a list of strings." + ) + schema.type = types.Type.STRING + schema.enum = list(args) + if param.default is not inspect.Parameter.empty: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + if origin is list: + schema.type = types.Type.ARRAY + schema.items = _parse_schema_from_parameter( + variant, + inspect.Parameter( + "item", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=args[0], + ), + func_name, + ) + if param.default is not inspect.Parameter.empty: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + if origin is Union: + schema.any_of = [] + schema.type = types.Type.OBJECT + unique_types = set() + for arg in args: + if arg.__name__ == "NoneType": # Optional type + schema.nullable = True + continue + schema_in_any_of = _parse_schema_from_parameter( + variant, + inspect.Parameter( + "item", + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=arg, + ), + func_name, + ) + if ( + len(param.annotation.__args__) == 2 + and type(None) in param.annotation.__args__ + ): # Optional type + for optional_arg in param.annotation.__args__: + if ( + hasattr(optional_arg, "__origin__") + and optional_arg.__origin__ is list + ): + # Optional type with list, for example Optional[list[str]] + schema.items = schema_in_any_of.items + if ( + schema_in_any_of.model_dump_json(exclude_none=True) + not in unique_types + ): + schema.any_of.append(schema_in_any_of) + unique_types.add( + schema_in_any_of.model_dump_json(exclude_none=True) + ) + if len(schema.any_of) == 1: # param: Union[List, None] -> Array + schema.type = schema.any_of[0].type + schema.any_of = None if ( - hasattr(optional_arg, '__origin__') - and optional_arg.__origin__ is list + param.default is not None + and param.default is not inspect.Parameter.empty ): - # Optional type with list, for example Optional[list[str]] - schema.items = schema_in_any_of.items - if ( - schema_in_any_of.model_dump_json(exclude_none=True) - not in unique_types - ): - schema.any_of.append(schema_in_any_of) - unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True)) - if len(schema.any_of) == 1: # param: Union[List, None] -> Array - schema.type = schema.any_of[0].type - schema.any_of = None - if ( - param.default is not None - and param.default is not inspect.Parameter.empty - ): - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - # all other generic alias will be invoked in raise branch - if ( - inspect.isclass(param.annotation) - # for user defined class, we only support pydantic model - and issubclass(param.annotation, pydantic.BaseModel) - ): + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + # all other generic alias will be invoked in raise branch if ( - param.default is not inspect.Parameter.empty - and param.default is not None + inspect.isclass(param.annotation) + # for user defined class, we only support pydantic model + and issubclass(param.annotation, pydantic.BaseModel) ): - schema.default = param.default - schema.type = types.Type.OBJECT - schema.properties = {} - for field_name, field_info in param.annotation.model_fields.items(): - field_schema = _parse_schema_from_parameter( - variant, - inspect.Parameter( - field_name, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=field_info.annotation, - ), - func_name, - ) - - if field_info.description: - field_schema.description = field_info.description - - schema.properties[field_name] = field_schema - - schema.required = [ - field_name - for field_name, field_info in param.annotation.model_fields.items() - if field_info.is_required() - ] + if param.default is not inspect.Parameter.empty and param.default is not None: + schema.default = param.default + schema.type = types.Type.OBJECT + schema.properties = {} + required_fields = [] + for field_name, field_info in param.annotation.model_fields.items(): + field_schema = _parse_schema_from_parameter( + variant, + inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=field_info.annotation, + ), + func_name, + ) - _raise_if_schema_unsupported(variant, schema) - return schema - if param.annotation is None: - # https://swagger.io/docs/specification/v3_0/data-models/data-types/#null - # null is not a valid type in schema, use object instead. - schema.type = types.Type.OBJECT - schema.nullable = True - _raise_if_schema_unsupported(variant, schema) - return schema - raise ValueError( - f'Failed to parse the parameter {param} of function {func_name} for' - ' automatic function calling. Automatic function calling works best with' - ' simpler function signature schema, consider manually parsing your' - f' function declaration for function {func_name}.' - ) + if field_info.description: + field_schema.description = field_info.description + + if field_info.is_required(): + required_fields.append(field_name) + + schema.properties[field_name] = field_schema + + schema.required = required_fields + + _raise_if_schema_unsupported(variant, schema) + return schema + if param.annotation is None: + # https://swagger.io/docs/specification/v3_0/data-models/data-types/#null + # null is not a valid type in schema, use object instead. + schema.type = types.Type.OBJECT + schema.nullable = True + _raise_if_schema_unsupported(variant, schema) + return schema + raise ValueError( + f"Failed to parse the parameter {param} of function {func_name} for" + " automatic function calling. Automatic function calling works best with" + " simpler function signature schema, consider manually parsing your" + f" function declaration for function {func_name}." + ) def _get_required_fields(schema: types.Schema) -> list[str]: - if not schema.properties: - return - return [ - field_name - for field_name, field_schema in schema.properties.items() - if not field_schema.nullable and field_schema.default is None - ] + if not schema.properties: + return + return [ + field_name + for field_name, field_schema in schema.properties.items() + if not field_schema.nullable and field_schema.default is None + ] From ad8a0e44c12b12537eefac4f2348acfbbbbfba43 Mon Sep 17 00:00:00 2001 From: t-miyak Date: Sat, 1 Nov 2025 19:55:19 +0900 Subject: [PATCH 4/4] fix: reformat --- .../tools/_function_parameter_parse_util.py | 549 +++++++++--------- .../tools/test_build_function_declaration.py | 8 +- 2 files changed, 287 insertions(+), 270 deletions(-) diff --git a/src/google/adk/tools/_function_parameter_parse_util.py b/src/google/adk/tools/_function_parameter_parse_util.py index bd1c92ccc1..108b07a120 100644 --- a/src/google/adk/tools/_function_parameter_parse_util.py +++ b/src/google/adk/tools/_function_parameter_parse_util.py @@ -15,15 +15,19 @@ from __future__ import annotations +from enum import Enum import inspect import logging import types as typing_types -from enum import Enum -from typing import Any, Literal, Union, _GenericAlias, get_args, get_origin - -import pydantic +from typing import _GenericAlias +from typing import Any +from typing import get_args +from typing import get_origin +from typing import Literal +from typing import Union from google.genai import types +import pydantic from ..utils.variant_utils import GoogleLLMVariant @@ -41,299 +45,308 @@ Any: None, } -logger = logging.getLogger("google_adk." + __name__) +logger = logging.getLogger('google_adk.' + __name__) def _is_builtin_primitive_or_compound( annotation: inspect.Parameter.annotation, ) -> bool: - return annotation in _py_builtin_type_to_schema_type.keys() + return annotation in _py_builtin_type_to_schema_type.keys() def _raise_for_any_of_if_mldev(schema: types.Schema): - if schema.any_of: - raise ValueError( - "AnyOf is not supported in function declaration schema for Google AI." - ) + if schema.any_of: + raise ValueError( + 'AnyOf is not supported in function declaration schema for Google AI.' + ) def _update_for_default_if_mldev(schema: types.Schema): - if schema.default is not None: - # TODO(kech): Remove this workaround once mldev supports default value. - schema.default = None - logger.warning( - "Default value is not supported in function declaration schema for" - " Google AI." - ) + if schema.default is not None: + # TODO(kech): Remove this workaround once mldev supports default value. + schema.default = None + logger.warning( + 'Default value is not supported in function declaration schema for' + ' Google AI.' + ) -def _raise_if_schema_unsupported(variant: GoogleLLMVariant, schema: types.Schema): - if variant == GoogleLLMVariant.GEMINI_API: - _raise_for_any_of_if_mldev(schema) - # _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value +def _raise_if_schema_unsupported( + variant: GoogleLLMVariant, schema: types.Schema +): + if variant == GoogleLLMVariant.GEMINI_API: + _raise_for_any_of_if_mldev(schema) + # _update_for_default_if_mldev(schema) # No need of this since GEMINI now supports default value def _is_default_value_compatible( default_value: Any, annotation: inspect.Parameter.annotation ) -> bool: - # None type is expected to be handled external to this function - if _is_builtin_primitive_or_compound(annotation): - return isinstance(default_value, annotation) - - if ( - isinstance(annotation, _GenericAlias) - or isinstance(annotation, typing_types.GenericAlias) - or isinstance(annotation, typing_types.UnionType) - ): - origin = get_origin(annotation) - if origin in (Union, typing_types.UnionType): - return any( - _is_default_value_compatible(default_value, arg) - for arg in get_args(annotation) - ) - - if origin is dict: - return isinstance(default_value, dict) - - if origin is list: - if not isinstance(default_value, list): - return False - # most tricky case, element in list is union type - # need to apply any logic within all - # see test case test_generic_alias_complex_array_with_default_value - # a: typing.List[int | str | float | bool] - # default_value: [1, 'a', 1.1, True] - return all( - any( - _is_default_value_compatible(item, arg) - for arg in get_args(annotation) - ) - for item in default_value - ) - - if origin is Literal: - return default_value in get_args(annotation) - - # return False for any other unrecognized annotation - # let caller handle the raise - return False + # None type is expected to be handled external to this function + if _is_builtin_primitive_or_compound(annotation): + return isinstance(default_value, annotation) + + if ( + isinstance(annotation, _GenericAlias) + or isinstance(annotation, typing_types.GenericAlias) + or isinstance(annotation, typing_types.UnionType) + ): + origin = get_origin(annotation) + if origin in (Union, typing_types.UnionType): + return any( + _is_default_value_compatible(default_value, arg) + for arg in get_args(annotation) + ) + + if origin is dict: + return isinstance(default_value, dict) + + if origin is list: + if not isinstance(default_value, list): + return False + # most tricky case, element in list is union type + # need to apply any logic within all + # see test case test_generic_alias_complex_array_with_default_value + # a: typing.List[int | str | float | bool] + # default_value: [1, 'a', 1.1, True] + return all( + any( + _is_default_value_compatible(item, arg) + for arg in get_args(annotation) + ) + for item in default_value + ) + + if origin is Literal: + return default_value in get_args(annotation) + + # return False for any other unrecognized annotation + # let caller handle the raise + return False def _parse_schema_from_parameter( variant: GoogleLLMVariant, param: inspect.Parameter, func_name: str ) -> types.Schema: - """parse schema from parameter. - - from the simplest case to the most complex case. - """ - schema = types.Schema() - default_value_error_msg = ( - f"Default value {param.default} of parameter {param} of function" - f" {func_name} is not compatible with the parameter annotation" - f" {param.annotation}." - ) - if _is_builtin_primitive_or_compound(param.annotation): - if param.default is not inspect.Parameter.empty: - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - schema.type = _py_builtin_type_to_schema_type[param.annotation] - _raise_if_schema_unsupported(variant, schema) - return schema - if isinstance(param.annotation, type) and issubclass(param.annotation, Enum): - schema.type = types.Type.STRING - schema.enum = [e.value for e in param.annotation] - if param.default is not inspect.Parameter.empty: - default_value = ( - param.default.value - if isinstance(param.default, Enum) - else param.default - ) - if default_value not in schema.enum: - raise ValueError(default_value_error_msg) - schema.default = default_value - _raise_if_schema_unsupported(variant, schema) - return schema + """parse schema from parameter. + + from the simplest case to the most complex case. + """ + schema = types.Schema() + default_value_error_msg = ( + f'Default value {param.default} of parameter {param} of function' + f' {func_name} is not compatible with the parameter annotation' + f' {param.annotation}.' + ) + if _is_builtin_primitive_or_compound(param.annotation): + if param.default is not inspect.Parameter.empty: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + schema.type = _py_builtin_type_to_schema_type[param.annotation] + _raise_if_schema_unsupported(variant, schema) + return schema + if isinstance(param.annotation, type) and issubclass(param.annotation, Enum): + schema.type = types.Type.STRING + schema.enum = [e.value for e in param.annotation] + if param.default is not inspect.Parameter.empty: + default_value = ( + param.default.value + if isinstance(param.default, Enum) + else param.default + ) + if default_value not in schema.enum: + raise ValueError(default_value_error_msg) + schema.default = default_value + _raise_if_schema_unsupported(variant, schema) + return schema + if ( + get_origin(param.annotation) is Union + # only parse simple UnionType, example int | str | float | bool + # complex types.UnionType will be invoked in raise branch + and all( + (_is_builtin_primitive_or_compound(arg) or arg is type(None)) + for arg in get_args(param.annotation) + ) + ): + schema.type = types.Type.OBJECT + schema.any_of = [] + unique_types = set() + for arg in get_args(param.annotation): + if arg.__name__ == 'NoneType': # Optional type + schema.nullable = True + continue + schema_in_any_of = _parse_schema_from_parameter( + variant, + inspect.Parameter( + 'item', inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg + ), + func_name, + ) + if ( + schema_in_any_of.model_dump_json(exclude_none=True) + not in unique_types + ): + schema.any_of.append(schema_in_any_of) + unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True)) + if len(schema.any_of) == 1: # param: list | None -> Array + schema.type = schema.any_of[0].type + schema.any_of = None if ( - get_origin(param.annotation) is Union - # only parse simple UnionType, example int | str | float | bool - # complex types.UnionType will be invoked in raise branch - and all( - (_is_builtin_primitive_or_compound(arg) or arg is type(None)) - for arg in get_args(param.annotation) - ) + param.default is not inspect.Parameter.empty + and param.default is not None ): - schema.type = types.Type.OBJECT - schema.any_of = [] - unique_types = set() - for arg in get_args(param.annotation): - if arg.__name__ == "NoneType": # Optional type - schema.nullable = True - continue - schema_in_any_of = _parse_schema_from_parameter( - variant, - inspect.Parameter( - "item", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=arg - ), - func_name, - ) - if schema_in_any_of.model_dump_json(exclude_none=True) not in unique_types: - schema.any_of.append(schema_in_any_of) - unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True)) - if len(schema.any_of) == 1: # param: list | None -> Array - schema.type = schema.any_of[0].type - schema.any_of = None - if param.default is not inspect.Parameter.empty and param.default is not None: - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - if isinstance(param.annotation, _GenericAlias) or isinstance( - param.annotation, typing_types.GenericAlias - ): - origin = get_origin(param.annotation) - args = get_args(param.annotation) - if origin is dict: - schema.type = types.Type.OBJECT - if param.default is not inspect.Parameter.empty: - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - if origin is Literal: - if not all(isinstance(arg, str) for arg in args): - raise ValueError( - f"Literal type {param.annotation} must be a list of strings." - ) - schema.type = types.Type.STRING - schema.enum = list(args) - if param.default is not inspect.Parameter.empty: - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - if origin is list: - schema.type = types.Type.ARRAY - schema.items = _parse_schema_from_parameter( - variant, - inspect.Parameter( - "item", - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=args[0], - ), - func_name, - ) - if param.default is not inspect.Parameter.empty: - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - if origin is Union: - schema.any_of = [] - schema.type = types.Type.OBJECT - unique_types = set() - for arg in args: - if arg.__name__ == "NoneType": # Optional type - schema.nullable = True - continue - schema_in_any_of = _parse_schema_from_parameter( - variant, - inspect.Parameter( - "item", - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=arg, - ), - func_name, - ) - if ( - len(param.annotation.__args__) == 2 - and type(None) in param.annotation.__args__ - ): # Optional type - for optional_arg in param.annotation.__args__: - if ( - hasattr(optional_arg, "__origin__") - and optional_arg.__origin__ is list - ): - # Optional type with list, for example Optional[list[str]] - schema.items = schema_in_any_of.items - if ( - schema_in_any_of.model_dump_json(exclude_none=True) - not in unique_types - ): - schema.any_of.append(schema_in_any_of) - unique_types.add( - schema_in_any_of.model_dump_json(exclude_none=True) - ) - if len(schema.any_of) == 1: # param: Union[List, None] -> Array - schema.type = schema.any_of[0].type - schema.any_of = None + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + if isinstance(param.annotation, _GenericAlias) or isinstance( + param.annotation, typing_types.GenericAlias + ): + origin = get_origin(param.annotation) + args = get_args(param.annotation) + if origin is dict: + schema.type = types.Type.OBJECT + if param.default is not inspect.Parameter.empty: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + if origin is Literal: + if not all(isinstance(arg, str) for arg in args): + raise ValueError( + f'Literal type {param.annotation} must be a list of strings.' + ) + schema.type = types.Type.STRING + schema.enum = list(args) + if param.default is not inspect.Parameter.empty: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + if origin is list: + schema.type = types.Type.ARRAY + schema.items = _parse_schema_from_parameter( + variant, + inspect.Parameter( + 'item', + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=args[0], + ), + func_name, + ) + if param.default is not inspect.Parameter.empty: + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + if origin is Union: + schema.any_of = [] + schema.type = types.Type.OBJECT + unique_types = set() + for arg in args: + if arg.__name__ == 'NoneType': # Optional type + schema.nullable = True + continue + schema_in_any_of = _parse_schema_from_parameter( + variant, + inspect.Parameter( + 'item', + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=arg, + ), + func_name, + ) + if ( + len(param.annotation.__args__) == 2 + and type(None) in param.annotation.__args__ + ): # Optional type + for optional_arg in param.annotation.__args__: if ( - param.default is not None - and param.default is not inspect.Parameter.empty + hasattr(optional_arg, '__origin__') + and optional_arg.__origin__ is list ): - if not _is_default_value_compatible(param.default, param.annotation): - raise ValueError(default_value_error_msg) - schema.default = param.default - _raise_if_schema_unsupported(variant, schema) - return schema - # all other generic alias will be invoked in raise branch + # Optional type with list, for example Optional[list[str]] + schema.items = schema_in_any_of.items + if ( + schema_in_any_of.model_dump_json(exclude_none=True) + not in unique_types + ): + schema.any_of.append(schema_in_any_of) + unique_types.add(schema_in_any_of.model_dump_json(exclude_none=True)) + if len(schema.any_of) == 1: # param: Union[List, None] -> Array + schema.type = schema.any_of[0].type + schema.any_of = None + if ( + param.default is not None + and param.default is not inspect.Parameter.empty + ): + if not _is_default_value_compatible(param.default, param.annotation): + raise ValueError(default_value_error_msg) + schema.default = param.default + _raise_if_schema_unsupported(variant, schema) + return schema + # all other generic alias will be invoked in raise branch + if ( + inspect.isclass(param.annotation) + # for user defined class, we only support pydantic model + and issubclass(param.annotation, pydantic.BaseModel) + ): if ( - inspect.isclass(param.annotation) - # for user defined class, we only support pydantic model - and issubclass(param.annotation, pydantic.BaseModel) + param.default is not inspect.Parameter.empty + and param.default is not None ): - if param.default is not inspect.Parameter.empty and param.default is not None: - schema.default = param.default - schema.type = types.Type.OBJECT - schema.properties = {} - required_fields = [] - for field_name, field_info in param.annotation.model_fields.items(): - field_schema = _parse_schema_from_parameter( - variant, - inspect.Parameter( - field_name, - inspect.Parameter.POSITIONAL_OR_KEYWORD, - annotation=field_info.annotation, - ), - func_name, - ) - - if field_info.description: - field_schema.description = field_info.description - - if field_info.is_required(): - required_fields.append(field_name) - - schema.properties[field_name] = field_schema - - schema.required = required_fields - - _raise_if_schema_unsupported(variant, schema) - return schema - if param.annotation is None: - # https://swagger.io/docs/specification/v3_0/data-models/data-types/#null - # null is not a valid type in schema, use object instead. - schema.type = types.Type.OBJECT - schema.nullable = True - _raise_if_schema_unsupported(variant, schema) - return schema - raise ValueError( - f"Failed to parse the parameter {param} of function {func_name} for" - " automatic function calling. Automatic function calling works best with" - " simpler function signature schema, consider manually parsing your" - f" function declaration for function {func_name}." - ) + schema.default = param.default + schema.type = types.Type.OBJECT + schema.properties = {} + required_fields = [] + for field_name, field_info in param.annotation.model_fields.items(): + field_schema = _parse_schema_from_parameter( + variant, + inspect.Parameter( + field_name, + inspect.Parameter.POSITIONAL_OR_KEYWORD, + annotation=field_info.annotation, + ), + func_name, + ) + + if field_info.description: + field_schema.description = field_info.description + + if field_info.is_required(): + required_fields.append(field_name) + + schema.properties[field_name] = field_schema + + schema.required = required_fields + + _raise_if_schema_unsupported(variant, schema) + return schema + if param.annotation is None: + # https://swagger.io/docs/specification/v3_0/data-models/data-types/#null + # null is not a valid type in schema, use object instead. + schema.type = types.Type.OBJECT + schema.nullable = True + _raise_if_schema_unsupported(variant, schema) + return schema + raise ValueError( + f'Failed to parse the parameter {param} of function {func_name} for' + ' automatic function calling. Automatic function calling works best with' + ' simpler function signature schema, consider manually parsing your' + f' function declaration for function {func_name}.' + ) def _get_required_fields(schema: types.Schema) -> list[str]: - if not schema.properties: - return - return [ - field_name - for field_name, field_schema in schema.properties.items() - if not field_schema.nullable and field_schema.default is None - ] + if not schema.properties: + return + return [ + field_name + for field_name, field_schema in schema.properties.items() + if not field_schema.nullable and field_schema.default is None + ] diff --git a/tests/unittests/tools/test_build_function_declaration.py b/tests/unittests/tools/test_build_function_declaration.py index b3673d41b1..8b526378a8 100644 --- a/tests/unittests/tools/test_build_function_declaration.py +++ b/tests/unittests/tools/test_build_function_declaration.py @@ -209,7 +209,9 @@ def simple_function(input: ParentInput) -> str: assert 'basic_field' in parent_required assert 'child' in parent_required assert 'optional_field' not in parent_required # Has default value - assert 'status' not in parent_required # No Field annotation, Optional with default + assert ( + 'status' not in parent_required + ) # No Field annotation, Optional with default # Check ChildInput properties with and without Field annotations child_props = parent_props['child'].properties @@ -227,7 +229,9 @@ def simple_function(input: ParentInput) -> str: assert 'name' in child_required assert 'age' in child_required assert 'nickname' not in child_required # Optional with default None - assert 'email' not in child_required # No Field annotation, Optional with default + assert ( + 'email' not in child_required + ) # No Field annotation, Optional with default def test_basemodel_with_nested_basemodel():