Skip to content

Commit 15cbbd4

Browse files
fix(LM): Change default temperature and max_tokens to be None (#8908)
* Set default parameters for reasoning model if not passed in.
* change warning level
* ruff
* remove default temperature and max_tokens
* Update tests to check that reasoning models with no params work
* Update dspy/clients/lm.py

Co-authored-by: Chen Qian <chen.qian@databricks.com>

---------

Co-authored-by: Chen Qian <chen.qian@databricks.com>
1 parent eacbcdd commit 15cbbd4

File tree

2 files changed

+13
-5
lines changed

2 files changed

+13
-5
lines changed

dspy/clients/lm.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -31,8 +31,8 @@ def __init__(
3131
self,
3232
model: str,
3333
model_type: Literal["chat", "text", "responses"] = "chat",
34-
temperature: float = 0.0,
35-
max_tokens: int = 4000,
34+
temperature: float | None = None,
35+
max_tokens: int | None = None,
3636
cache: bool = True,
3737
callbacks: list[BaseCallback] | None = None,
3838
num_retries: int = 3,
@@ -88,9 +88,10 @@ def __init__(
8888
model_pattern = re.match(r"^(?:o[1345]|gpt-5)(?:-(?:mini|nano))?", model_family)
8989

9090
if model_pattern:
91-
if max_tokens < 16000 or temperature != 1.0:
91+
92+
if (temperature and temperature != 1.0) or (max_tokens and max_tokens < 16000):
9293
raise ValueError(
93-
"OpenAI's reasoning models require passing temperature=1.0 and max_tokens >= 16000 to "
94+
"OpenAI's reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None to "
9495
"`dspy.LM(...)`, e.g., dspy.LM('openai/gpt-5', temperature=1.0, max_tokens=16000)"
9596
)
9697
self.kwargs = dict(temperature=temperature, max_completion_tokens=max_tokens, **kwargs)

tests/clients/test_lm.py

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -318,7 +318,7 @@ def test_reasoning_model_requirements(model_name):
318318
# Should raise assertion error if temperature or max_tokens requirements not met
319319
with pytest.raises(
320320
ValueError,
321-
match="reasoning models require passing temperature=1.0 and max_tokens >= 16000",
321+
match="reasoning models require passing temperature=1.0 or None and max_tokens >= 16000 or None",
322322
):
323323
dspy.LM(
324324
model=model_name,
@@ -334,6 +334,13 @@ def test_reasoning_model_requirements(model_name):
334334
)
335335
assert lm.kwargs["max_completion_tokens"] == 16_000
336336

337+
# Should pass with no parameters
338+
lm = dspy.LM(
339+
model=model_name,
340+
)
341+
assert lm.kwargs["temperature"] == None
342+
assert lm.kwargs["max_completion_tokens"] == None
343+
337344

338345
def test_dump_state():
339346
lm = dspy.LM(

0 commit comments

Comments (0)