|
10 | 10 | from unittest.mock import ANY, AsyncMock, Mock, patch |
11 | 11 |
|
12 | 12 | import google.ai.generativelanguage as glm |
| 13 | +import proto # type: ignore[import-untyped] |
13 | 14 | import pytest |
14 | 15 | from google.ai.generativelanguage_v1beta.types import ( |
15 | 16 | Candidate, |
@@ -2957,6 +2958,110 @@ def test_response_schema_mime_type_validation() -> None: |
2957 | 2958 | assert llm_with_json_schema is not None |
2958 | 2959 |
|
2959 | 2960 |
|
| 2961 | +def _convert_proto_to_dict(obj: Any) -> Any: |
| 2962 | + """Recursively convert proto objects to dicts for comparison.""" |
| 2963 | + if isinstance(obj, dict): |
| 2964 | + return {k: _convert_proto_to_dict(v) for k, v in obj.items()} |
| 2965 | + elif isinstance(obj, (list, tuple)): |
| 2966 | + return [_convert_proto_to_dict(item) for item in obj] |
| 2967 | + elif hasattr(obj, "__class__") and "proto" in str(type(obj)): |
| 2968 | + # Try to convert proto object to dict |
| 2969 | + try: |
| 2970 | + if hasattr(obj, "__iter__") and not isinstance(obj, str): |
| 2971 | + converted = dict(obj) |
| 2972 | + # Recursively convert nested proto objects |
| 2973 | + return {k: _convert_proto_to_dict(v) for k, v in converted.items()} |
| 2974 | + else: |
| 2975 | + return obj |
| 2976 | + except (TypeError, ValueError): |
| 2977 | + return obj |
| 2978 | + return obj |
| 2979 | + |
| 2980 | + |
def test_response_format_provider_strategy() -> None:
    """Test that `response_format` from ProviderStrategy is correctly handled."""
    llm = ChatGoogleGenerativeAI(
        model=MODEL_NAME, google_api_key=SecretStr(FAKE_API_KEY)
    )

    schema_dict = {
        "type": "object",
        "properties": {
            "sentiment": {
                "type": "string",
                "enum": ["positive", "negative", "neutral"],
            },
            "confidence": {"type": "number", "minimum": 0, "maximum": 1},
        },
        "required": ["sentiment", "confidence"],
        "additionalProperties": False,
    }

    # OpenAI-style ProviderStrategy payload wrapping the schema above.
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "response_format_test",
            "schema": schema_dict,
        },
    }

    # Case 1: response_format alone is translated onto the generation config.
    # The stored schema may be a proto object, so normalize before comparing.
    config = llm._prepare_params(stop=None, response_format=response_format)
    assert _convert_proto_to_dict(config.response_json_schema) == schema_dict
    assert config.response_mime_type == "application/json"

    # Case 2: an explicit response_json_schema wins over response_format.
    different_schema = {
        "type": "object",
        "properties": {"age": {"type": "integer"}},
        "required": ["age"],
    }
    config = llm._prepare_params(
        stop=None,
        response_format=response_format,
        response_json_schema=different_schema,
    )
    assert _convert_proto_to_dict(config.response_json_schema) == different_schema
    assert config.response_mime_type == "application/json"

    # Case 3: response_format takes precedence over the legacy response_schema.
    old_schema = {
        "type": "object",
        "properties": {"old_field": {"type": "string"}},
        "required": ["old_field"],
    }
    config = llm._prepare_params(
        stop=None,
        response_schema=old_schema,
        response_format=response_format,
    )
    assert _convert_proto_to_dict(config.response_json_schema) == schema_dict
    assert config.response_mime_type == "application/json"

    # Case 4: an unrecognized response_format type falls back to
    # response_schema.
    invalid_response_format = {"type": "invalid_type"}
    config = llm._prepare_params(
        stop=None,
        response_format=invalid_response_format,
        response_schema=schema_dict,
    )
    assert _convert_proto_to_dict(config.response_json_schema) == schema_dict
    assert config.response_mime_type == "application/json"
| 3063 | + |
| 3064 | + |
2960 | 3065 | def test_is_new_gemini_model() -> None: |
2961 | 3066 | assert _is_gemini_3_or_later("gemini-3.0-pro") is True |
2962 | 3067 | assert _is_gemini_3_or_later("gemini-2.5-pro") is False |
|
0 commit comments