52 changes: 39 additions & 13 deletions libs/genai/langchain_google_genai/chat_models.py
@@ -896,6 +896,8 @@
except (AttributeError, TypeError):
thought_sig = None

has_function_call = hasattr(part, "function_call") and part.function_call

Check failure (Ruff F841, GitHub Actions make lint, Python 3.10 & 3.12): langchain_google_genai/chat_models.py:899:9: Local variable `has_function_call` is assigned to but never used

if hasattr(part, "thought") and part.thought:
thinking_message = {
"type": "thinking",
Expand Down Expand Up @@ -1031,21 +1033,37 @@

# If this function_call Part has a signature, track it separately
if thought_sig:
if _FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY not in additional_kwargs:
additional_kwargs[_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY] = {}
additional_kwargs[_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY][
tool_call_id
] = (
_bytes_to_base64(thought_sig)
if isinstance(thought_sig, bytes)
else thought_sig
)
sig_block = {
"type": "function_call_signature",
"signature": thought_sig,
}
function_call_signatures.append(sig_block)

Check failure (Ruff F821, GitHub Actions make lint, Python 3.10 & 3.12): langchain_google_genai/chat_models.py:1040:17: Undefined name `function_call_signatures`

# Add function call signatures to content only if there's already other content
# This preserves backward compatibility where content is "" for
# function-only responses
if function_call_signatures and content is not None:

Check failure (Ruff F821, GitHub Actions make lint, Python 3.10 & 3.12): langchain_google_genai/chat_models.py:1045:12: Undefined name `function_call_signatures`
for sig_block in function_call_signatures:

Check failure (Ruff F821, GitHub Actions make lint, Python 3.10 & 3.12): langchain_google_genai/chat_models.py:1046:30: Undefined name `function_call_signatures`
content = _append_to_content(content, sig_block)

if content is None:
if _is_gemini_3_or_later(model_name or ""):
content = []
else:
content = ""
content = ""

if (
hasattr(response_candidate, "logprobs_result")
and response_candidate.logprobs_result
):
# Note: logprobs is flaky, sometimes available, sometimes not
# https://discuss.ai.google.dev/t/logprobs-is-not-enabled-for-gemini-models/107989/15
response_metadata["logprobs"] = MessageToDict(

Check failure (Ruff F821, GitHub Actions make lint, Python 3.10 & 3.12): langchain_google_genai/chat_models.py:1058:41: Undefined name `MessageToDict`
response_candidate.logprobs_result._pb,
preserving_proto_field_name=True,
)

if isinstance(content, list) and any(
isinstance(item, dict) and "executable_code" in item for item in content
):
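For orientation, the signature bookkeeping in this hunk lands in two places: a per-tool-call map in `additional_kwargs` (keyed by `tool_call_id`, with bytes signatures base64-encoded) and, only when the message already has other content, `function_call_signature` blocks appended to `content`. A minimal sketch of those structures follows; the map key string is an illustrative stand-in, since the value of `_FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY` is not shown in this diff:

```python
import base64

# Illustrative stand-in; the real constant is
# _FUNCTION_CALL_THOUGHT_SIGNATURES_MAP_KEY, whose value this diff doesn't show.
SIG_MAP_KEY = "__function_call_thought_signatures__"

thought_sig = b"\x8a\x01"     # raw signature bytes from the candidate Part
tool_call_id = "call_abc123"  # id assigned to the function call

additional_kwargs: dict = {}

# Per-tool-call map: tool_call_id -> base64-encoded signature string.
additional_kwargs.setdefault(SIG_MAP_KEY, {})[tool_call_id] = base64.b64encode(
    thought_sig
).decode("ascii")

# Content block, appended only when other content exists, so that
# function-only responses keep the backward-compatible content == "".
sig_block = {"type": "function_call_signature", "signature": thought_sig}
```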
@@ -1922,6 +1940,9 @@
stop: list[str] | None = None
"""Stop sequences for the model."""

logprobs: int | None = None
"""The number of logprobs to return."""

streaming: bool | None = None
"""Whether to stream responses from the model."""

@@ -2102,6 +2123,7 @@
"media_resolution": self.media_resolution,
"thinking_budget": self.thinking_budget,
"include_thoughts": self.include_thoughts,
"logprobs": self.logprobs,
"thinking_level": self.thinking_level,
}

@@ -2216,6 +2238,10 @@
}.items()
if v is not None
}
logprobs = getattr(self, "logprobs", None)
if logprobs:
gen_config["logprobs"] = logprobs
gen_config["response_logprobs"] = True
if generation_config:
gen_config = {**gen_config, **generation_config}
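With the field declared above and threaded into the request here, callers can opt in with a single parameter. A minimal usage sketch (model name illustrative; as the comment in the earlier hunk notes, the backend returns logprobs only intermittently):

```python
from langchain_google_genai import ChatGoogleGenerativeAI

# logprobs=5 requests the top-5 token logprobs and implies response_logprobs=True.
llm = ChatGoogleGenerativeAI(model="gemini-2.5-flash", logprobs=5)
result = llm.invoke("Say hi")

# Present in response_metadata only when the candidate carries logprobs_result.
print(result.response_metadata.get("logprobs"))
```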

83 changes: 80 additions & 3 deletions libs/genai/tests/unit_tests/test_chat_models.py
@@ -147,24 +147,98 @@ def test_initialization_inside_threadpool() -> None:
).result()


def test_logprobs() -> None:
"""Test that logprobs parameter is set correctly and is in the response."""
llm = ChatGoogleGenerativeAI(
model=MODEL_NAME,
google_api_key=SecretStr("secret-api-key"),
logprobs=10,
)
assert llm.logprobs == 10

# Create proper mock response with logprobs_result
raw_response = {
"candidates": [
{
"content": {"parts": [{"text": "Test response"}]},
"finish_reason": 1,
"safety_ratings": [],
"logprobs_result": {
"top_candidates": [
{
"candidates": [
{"token": "Test", "log_probability": -0.1},
]
}
]
},
}
],
"prompt_feedback": {"block_reason": 0, "safety_ratings": []},
"usage_metadata": {
"prompt_token_count": 5,
"candidates_token_count": 2,
"total_token_count": 7,
},
}
response = GenerateContentResponse(raw_response)

with patch(
"langchain_google_genai.chat_models._chat_with_retry"
) as mock_chat_with_retry:
mock_chat_with_retry.return_value = response
llm = ChatGoogleGenerativeAI(
model=MODEL_NAME,
google_api_key="test-key",
logprobs=1,
)
result = llm.invoke("test")
assert "logprobs" in result.response_metadata
assert result.response_metadata["logprobs"] == {
"top_candidates": [
{
"candidates": [
{"token": "Test", "log_probability": -0.1},
]
}
]
}

mock_chat_with_retry.assert_called_once()
request = mock_chat_with_retry.call_args.kwargs["request"]
assert request.generation_config.logprobs == 1
assert request.generation_config.response_logprobs is True


@pytest.mark.enable_socket
@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceAsyncClient")
@patch("langchain_google_genai._genai_extension.v1betaGenerativeServiceClient")
def test_client_transport(mock_client: Mock, mock_async_client: Mock) -> None:
"""Test client transport configuration."""
mock_client.return_value.transport = Mock()
mock_client.return_value.transport.kind = "grpc"
model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
assert model.client.transport.kind == "grpc"

mock_client.return_value.transport.kind = "rest"
model = ChatGoogleGenerativeAI(
model=MODEL_NAME, google_api_key="fake-key", transport="rest"
)
assert model.client.transport.kind == "rest"

async def check_async_client() -> None:
mock_async_client.return_value.transport = Mock()
mock_async_client.return_value.transport.kind = "grpc_asyncio"
model = ChatGoogleGenerativeAI(model=MODEL_NAME, google_api_key="fake-key")
_ = model.async_client
assert model.async_client.transport.kind == "grpc_asyncio"

# Test auto conversion of transport to "grpc_asyncio" from "rest"
model = ChatGoogleGenerativeAI(
model=MODEL_NAME, google_api_key=FAKE_API_KEY, transport="rest"
)
model.async_client_running = None
_ = model.async_client
assert model.async_client.transport.kind == "grpc_asyncio"

asyncio.run(check_async_client())
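From the caller's side, the behavior this test pins down is that the sync client honors the `transport` argument while the async client is always normalized to gRPC asyncio. A small sketch (model name and API key illustrative):

```python
from langchain_google_genai import ChatGoogleGenerativeAI

# Sync client follows the requested transport; the async client, once created,
# is converted to "grpc_asyncio" regardless of this argument.
llm = ChatGoogleGenerativeAI(
    model="gemini-2.5-flash", google_api_key="your-key", transport="rest"
)
assert llm.client.transport.kind == "rest"
```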
@@ -178,6 +252,7 @@ def test_initalization_without_async() -> None:
assert chat.async_client is None


@pytest.mark.enable_socket
def test_initialization_with_async() -> None:
async def initialize_chat_with_async_client() -> ChatGoogleGenerativeAI:
model = ChatGoogleGenerativeAI(
@@ -1720,6 +1795,7 @@ def test_grounding_metadata_multiple_parts() -> None:
assert grounding["grounding_supports"][0]["segment"]["part_index"] == 1


@pytest.mark.enable_socket
@pytest.mark.parametrize(
"is_async,mock_target,method_name",
[
@@ -1846,6 +1922,7 @@ def mock_stream() -> Iterator[GenerateContentResponse]:
assert "timeout" not in call_kwargs


@pytest.mark.enable_socket
@pytest.mark.parametrize(
"is_async,mock_target,method_name",
[