Skip to content

Commit ceff510

Browse files
committed
fix(genai): Add ProviderStrategy response_format support
- Add handling for response_format parameter from ProviderStrategy
1 parent 93bd5e6 commit ceff510

File tree

2 files changed

+126
-3
lines changed

2 files changed

+126
-3
lines changed

libs/genai/langchain_google_genai/chat_models.py

Lines changed: 21 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2220,14 +2220,23 @@ def _prepare_params(
22202220
gen_config = {**gen_config, **generation_config}
22212221

22222222
response_mime_type = kwargs.get("response_mime_type", self.response_mime_type)
2223-
if response_mime_type is not None:
2224-
gen_config["response_mime_type"] = response_mime_type
2225-
22262223
response_schema = kwargs.get("response_schema", self.response_schema)
22272224

22282225
# In case passed in as a direct kwarg
22292226
response_json_schema = kwargs.get("response_json_schema")
22302227

2228+
# Handle response_format from ProviderStrategy (OpenAI-style format)
2229+
# Format: {'type': 'json_schema', 'json_schema': {'schema': {...}}}
2230+
# Only extract from response_format if response_json_schema is not already set
2231+
if response_json_schema is None:
2232+
response_format = kwargs.get("response_format")
2233+
if response_format is not None and isinstance(response_format, dict):
2234+
if response_format.get("type") == "json_schema":
2235+
json_schema_obj = response_format.get("json_schema", {})
2236+
if isinstance(json_schema_obj, dict) and "schema" in json_schema_obj:
2237+
# Extract the actual schema from the nested structure
2238+
response_json_schema = json_schema_obj["schema"]
2239+
22312240
# Handle both response_schema and response_json_schema
22322241
# (Regardless, we use `response_json_schema` in the request)
22332242
schema_to_use = (
@@ -2236,6 +2245,15 @@ def _prepare_params(
22362245
else response_schema
22372246
)
22382247

2248+
# Automatically set response_mime_type to "application/json" when response_schema
2249+
# is provided but response_mime_type is not set. This enables seamless support
2250+
# for structured output strategies like ProviderStrategy.
2251+
if schema_to_use is not None and response_mime_type is None:
2252+
response_mime_type = "application/json"
2253+
2254+
if response_mime_type is not None:
2255+
gen_config["response_mime_type"] = response_mime_type
2256+
22392257
if schema_to_use is not None:
22402258
if response_mime_type != "application/json":
22412259
param_name = (

libs/genai/tests/unit_tests/test_chat_models.py

Lines changed: 105 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from unittest.mock import ANY, AsyncMock, Mock, patch
1111

1212
import google.ai.generativelanguage as glm
13+
import proto # type: ignore[import-untyped]
1314
import pytest
1415
from google.ai.generativelanguage_v1beta.types import (
1516
Candidate,
@@ -2957,6 +2958,110 @@ def test_response_schema_mime_type_validation() -> None:
29572958
assert llm_with_json_schema is not None
29582959

29592960

2961+
def _convert_proto_to_dict(obj: Any) -> Any:
2962+
"""Recursively convert proto objects to dicts for comparison."""
2963+
if isinstance(obj, dict):
2964+
return {k: _convert_proto_to_dict(v) for k, v in obj.items()}
2965+
elif isinstance(obj, (list, tuple)):
2966+
return [_convert_proto_to_dict(item) for item in obj]
2967+
elif hasattr(obj, "__class__") and "proto" in str(type(obj)):
2968+
# Try to convert proto object to dict
2969+
try:
2970+
if hasattr(obj, "__iter__") and not isinstance(obj, str):
2971+
converted = dict(obj)
2972+
# Recursively convert nested proto objects
2973+
return {k: _convert_proto_to_dict(v) for k, v in converted.items()}
2974+
else:
2975+
return obj
2976+
except (TypeError, ValueError):
2977+
return obj
2978+
return obj
2979+
2980+
2981+
def test_response_format_provider_strategy() -> None:
    """Test that `response_format` from ProviderStrategy is correctly handled."""
    llm = ChatGoogleGenerativeAI(
        model=MODEL_NAME, google_api_key=SecretStr(FAKE_API_KEY)
    )

    def _schema_of(config: Any) -> Any:
        # The schema may come back as a proto object; normalize for comparison.
        return _convert_proto_to_dict(config.response_json_schema)

    schema_dict = {
        "type": "object",
        "properties": {
            "sentiment": {
                "type": "string",
                "enum": ["positive", "negative", "neutral"],
            },
            "confidence": {"type": "number", "minimum": 0, "maximum": 1},
        },
        "required": ["sentiment", "confidence"],
        "additionalProperties": False,
    }

    # OpenAI-style payload as produced by ProviderStrategy.
    response_format = {
        "type": "json_schema",
        "json_schema": {
            "name": "response_format_test",
            "schema": schema_dict,
        },
    }

    # Case 1: the schema is extracted from response_format.
    params = llm._prepare_params(stop=None, response_format=response_format)
    assert _schema_of(params) == schema_dict
    assert params.response_mime_type == "application/json"

    # Case 2: an explicit response_json_schema takes precedence over
    # response_format.
    different_schema = {
        "type": "object",
        "properties": {"age": {"type": "integer"}},
        "required": ["age"],
    }
    params = llm._prepare_params(
        stop=None,
        response_format=response_format,
        response_json_schema=different_schema,
    )
    assert _schema_of(params) == different_schema
    assert params.response_mime_type == "application/json"

    # Case 3: response_format also wins over the legacy response_schema.
    old_schema = {
        "type": "object",
        "properties": {"old_field": {"type": "string"}},
        "required": ["old_field"],
    }
    params = llm._prepare_params(
        stop=None,
        response_schema=old_schema,
        response_format=response_format,
    )
    assert _schema_of(params) == schema_dict
    assert params.response_mime_type == "application/json"

    # Case 4: an unrecognized response_format type is ignored and the
    # request falls back to response_schema.
    params = llm._prepare_params(
        stop=None,
        response_format={"type": "invalid_type"},
        response_schema=schema_dict,
    )
    assert _schema_of(params) == schema_dict
    assert params.response_mime_type == "application/json"
3064+
29603065
def test_is_new_gemini_model() -> None:
29613066
assert _is_gemini_3_or_later("gemini-3.0-pro") is True
29623067
assert _is_gemini_3_or_later("gemini-2.5-pro") is False

0 commit comments

Comments (0)