libs/genai/langchain_google_genai/_common.py (9 additions, 0 deletions)
@@ -230,6 +230,15 @@ class _BaseGoogleGenerativeAI(BaseModel):
     ```
     """  # noqa: E501
 
+    seed: int | None = Field(
+        default=None,
+    )
+    """
+    Seed used in decoding.
+
+    If not set, the request uses a randomly generated seed.
+    """
+
     @property
     def lc_secrets(self) -> dict[str, str]:
         # Either could contain the API key
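For reviewers who want to try the new field: a minimal usage sketch, assuming the package is installed and `GOOGLE_API_KEY` is set in the environment; the model name is illustrative, not part of this PR.

```python
from langchain_google_genai import ChatGoogleGenerativeAI

# Constructor-level seed: every request made through this model reuses it.
# Leaving seed unset (None) keeps the default of a randomly generated seed
# per request.
llm = ChatGoogleGenerativeAI(model="gemini-2.0-flash", seed=42)  # example model name
assert llm.seed == 42
```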
libs/genai/langchain_google_genai/chat_models.py (1 addition, 0 deletions)
@@ -2213,6 +2213,7 @@ def _prepare_params(
                 or thinking_level is not None
                 else None
             ),
+            "seed": self.seed,
         }.items()
         if v is not None
     }
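For context on why a single line suffices here: `_prepare_params` assembles the request parameters with a dict comprehension that drops `None` values, so an unset seed is simply omitted from the request rather than sent explicitly. A standalone sketch of that pattern (function and parameter names here are illustrative, not the library's):

```python
def build_params(temperature: float | None, seed: int | None) -> dict:
    # Collect candidate parameters, then keep only those actually set.
    return {
        k: v
        for k, v in {"temperature": temperature, "seed": seed}.items()
        if v is not None
    }

print(build_params(temperature=1.0, seed=None))  # {'temperature': 1.0}
print(build_params(temperature=None, seed=42))   # {'seed': 42}
```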
libs/genai/langchain_google_genai/llms.py (1 addition, 0 deletions)
@@ -83,6 +83,7 @@ def validate_environment(self) -> Self:
             transport=self.transport,
             additional_headers=self.additional_headers,
             safety_settings=self.safety_settings,
+            seed=self.seed,
         )
 
         return self
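This change forwards the same field into the client built in `validate_environment`, so the completion-style `GoogleGenerativeAI` wrapper honors it too. A minimal sketch, again assuming an API key in the environment and an illustrative model name:

```python
from langchain_google_genai import GoogleGenerativeAI

# The seed set here is passed to the underlying client when the model is
# validated, so every request from this instance uses it.
llm = GoogleGenerativeAI(model="gemini-2.0-flash", seed=42)  # example model name
assert llm.seed == 42
```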
libs/genai/tests/integration_tests/test_chat_models.py (31 additions, 0 deletions)
@@ -1384,3 +1384,34 @@ def test_chat_google_genai_invoke_respects_max_tokens(max_tokens: int) -> None:
     assert output_tokens <= max_tokens, (
         f"Expected output_tokens <= {max_tokens}, got {output_tokens}"
     )
+
+
+@pytest.mark.flaky(retries=3, delay=1)
+def test_seed_provides_reproducibility() -> None:
+    llm = ChatGoogleGenerativeAI(model=_MODEL, thinking_budget=0)
+    n_generations = 3
+
+    # Explicit seed improves reproducibility
+    actual_results = set()
+    for _ in range(n_generations):
+        result = llm.invoke(
+            "Provide a number between 0 and 100.",
+            generation_config={
+                "top_p": 1,
+                "temperature": 1.0,
+                "max_output_tokens": 10,
+                "seed": 42,
+            },
+        )
+        actual_results.add(result.content)
+    assert len(actual_results) == 1
+
+    # Lack of seed means it's very unlikely to get the same number 3 times in a row
+    actual_results = set()
+    for _ in range(n_generations):
+        result = llm.invoke(
+            "Provide a number between 0 and 100.",
+            generation_config={"top_p": 1, "temperature": 1.0, "max_output_tokens": 10},
+        )
+        actual_results.add(result.content)
+    assert len(actual_results) > 1
libs/genai/tests/unit_tests/test_chat_models.py (17 additions, 0 deletions)
@@ -3170,3 +3170,20 @@ def test_kwargs_override_thinking_level() -> None:
     config = llm._prepare_params(stop=None, thinking_level="high")
     assert config.thinking_config is not None
     assert config.thinking_config.thinking_level == "high"
+
+
+def test_seed_initialization() -> None:
+    # Test explicitly provided seed
+    llm = ChatGoogleGenerativeAI(
+        model=MODEL_NAME,
+        google_api_key=SecretStr(FAKE_API_KEY),
+        seed=42,
+    )
+    assert llm.seed == 42
+
+    # Test default seed
+    llm = ChatGoogleGenerativeAI(
+        model=MODEL_NAME,
+        google_api_key=SecretStr(FAKE_API_KEY),
+    )
+    assert llm.seed is None