Skip to content

Commit 6ad51eb

Browse files
Introduce dspy.Reasoning to capture native reasoning from reasoning models (#8986)
* support for native reasoning in CoT for reasoning models * ruff and test * Introduce dspy.Reasoning to handle ChainOfThought on reasoning models * remove unintended file * fix * make reasoning string-like * increment * go * polish the docstring * automatically turn on reasoning for COT on reasoning model * comments * fix tests * fix * add dspy.Reasoning * comments * add comment for backward compatibility --------- Co-authored-by: arnavsinghvi11 <arnav11.singhvi@gmail.com>
1 parent b82f06b commit 6ad51eb

File tree

16 files changed

+626
-21
lines changed

16 files changed

+626
-21
lines changed

dspy/__init__.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66

77
from dspy.evaluate import Evaluate # isort: skip
88
from dspy.clients import * # isort: skip
9-
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, File, History, Type, Tool, ToolCalls, Code # isort: skip
9+
from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, File, History, Type, Tool, ToolCalls, Code, Reasoning # isort: skip
1010
from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
1111
from dspy.utils.asyncify import asyncify
1212
from dspy.utils.syncify import syncify

dspy/adapters/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
from dspy.adapters.chat_adapter import ChatAdapter
33
from dspy.adapters.json_adapter import JSONAdapter
44
from dspy.adapters.two_step_adapter import TwoStepAdapter
5-
from dspy.adapters.types import Audio, Code, File, History, Image, Tool, ToolCalls, Type
5+
from dspy.adapters.types import Audio, Code, File, History, Image, Reasoning, Tool, ToolCalls, Type
66
from dspy.adapters.xml_adapter import XMLAdapter
77

88
__all__ = [
@@ -19,4 +19,5 @@
1919
"TwoStepAdapter",
2020
"Tool",
2121
"ToolCalls",
22+
"Reasoning",
2223
]

dspy/adapters/base.py

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66

77
from dspy.adapters.types import History, Type
88
from dspy.adapters.types.base_type import split_message_content_for_custom_types
9+
from dspy.adapters.types.reasoning import Reasoning
910
from dspy.adapters.types.tool import Tool, ToolCalls
1011
from dspy.experimental import Citations
1112
from dspy.signatures.signature import Signature
@@ -16,7 +17,7 @@
1617
if TYPE_CHECKING:
1718
from dspy.clients.lm import LM
1819

19-
_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations]
20+
_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations, Reasoning]
2021

2122

2223
class Adapter:
@@ -99,14 +100,14 @@ def _call_preprocess(
99100

100101
return signature_for_native_function_calling
101102

102-
# Handle custom types that use native response
103+
# Handle custom types that use native LM features, e.g., reasoning, citations, etc.
103104
for name, field in signature.output_fields.items():
104105
if (
105106
isinstance(field.annotation, type)
106107
and issubclass(field.annotation, Type)
107108
and field.annotation in self.native_response_types
108109
):
109-
signature = signature.delete(name)
110+
signature = field.annotation.adapt_to_native_lm_feature(signature, name, lm, lm_kwargs)
110111

111112
return signature
112113

@@ -116,6 +117,7 @@ def _call_postprocess(
116117
original_signature: type[Signature],
117118
outputs: list[dict[str, Any] | str],
118119
lm: "LM",
120+
lm_kwargs: dict[str, Any],
119121
) -> list[dict[str, Any]]:
120122
values = []
121123

@@ -152,14 +154,16 @@ def _call_postprocess(
152154
]
153155
value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)
154156

155-
# Parse custom types that does not rely on the adapter parsing
157+
# Parse custom types that do not rely on the `Adapter.parse()` method
156158
for name, field in original_signature.output_fields.items():
157159
if (
158160
isinstance(field.annotation, type)
159161
and issubclass(field.annotation, Type)
160162
and field.annotation in self.native_response_types
161163
):
162-
value[name] = field.annotation.parse_lm_response(output)
164+
parsed_value = field.annotation.parse_lm_response(output)
165+
if parsed_value is not None:
166+
value[name] = parsed_value
163167

164168
if output_logprobs:
165169
value["logprobs"] = output_logprobs
@@ -196,7 +200,7 @@ def __call__(
196200
inputs = self.format(processed_signature, demos, inputs)
197201

198202
outputs = lm(messages=inputs, **lm_kwargs)
199-
return self._call_postprocess(processed_signature, signature, outputs, lm)
203+
return self._call_postprocess(processed_signature, signature, outputs, lm, lm_kwargs)
200204

201205
async def acall(
202206
self,
@@ -210,7 +214,7 @@ async def acall(
210214
inputs = self.format(processed_signature, demos, inputs)
211215

212216
outputs = await lm.acall(messages=inputs, **lm_kwargs)
213-
return self._call_postprocess(processed_signature, signature, outputs, lm)
217+
return self._call_postprocess(processed_signature, signature, outputs, lm, lm_kwargs)
214218

215219
def format(
216220
self,

dspy/adapters/types/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
from dspy.adapters.types.file import File
55
from dspy.adapters.types.history import History
66
from dspy.adapters.types.image import Image
7+
from dspy.adapters.types.reasoning import Reasoning
78
from dspy.adapters.types.tool import Tool, ToolCalls
89

9-
__all__ = ["History", "Image", "Audio", "File", "Type", "Tool", "ToolCalls", "Code"]
10+
__all__ = ["History", "Image", "Audio", "File", "Type", "Tool", "ToolCalls", "Code", "Reasoning"]

dspy/adapters/types/base_type.py

Lines changed: 30 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,15 @@
11
import json
22
import re
3-
from typing import Any, Optional, get_args, get_origin
3+
from typing import TYPE_CHECKING, Any, Optional, get_args, get_origin
44

55
import json_repair
66
import pydantic
77
from litellm import ModelResponseStream
88

9+
if TYPE_CHECKING:
10+
from dspy.clients.lm import LM
11+
from dspy.signatures.signature import Signature
12+
913
CUSTOM_TYPE_START_IDENTIFIER = "<<CUSTOM-TYPE-START-IDENTIFIER>>"
1014
CUSTOM_TYPE_END_IDENTIFIER = "<<CUSTOM-TYPE-END-IDENTIFIER>>"
1115

@@ -70,6 +74,31 @@ def serialize_model(self):
7074
)
7175
return formatted
7276

77+
@classmethod
78+
def adapt_to_native_lm_feature(
79+
cls,
80+
signature: type["Signature"],
81+
field_name: str,
82+
lm: "LM",
83+
lm_kwargs: dict[str, Any],
84+
) -> type["Signature"]:
85+
"""Adapt the custom type to the native LM feature if possible.
86+
87+
When the LM and configuration support the related native LM feature, e.g., native tool calling, native
88+
reasoning, etc., we adapt the signature and `lm_kwargs` to enable the native LM feature.
89+
90+
Args:
91+
signature: The DSPy signature for the LM call.
92+
field_name: The name of the field in the signature to adapt to the native LM feature.
93+
lm: The LM instance.
94+
lm_kwargs: The keyword arguments for the LM call, subject to in-place updates if adaptation is required.
95+
96+
Returns:
97+
The adapted signature. If the custom type is not natively supported by the LM, return the original
98+
signature.
99+
"""
100+
return signature
101+
73102
@classmethod
74103
def is_streamable(cls) -> bool:
75104
"""Whether the custom type is streamable."""

dspy/adapters/types/citation.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -167,6 +167,12 @@ def __getitem__(self, index):
167167
"""Allow indexing into citations."""
168168
return self.citations[index]
169169

170+
@classmethod
def adapt_to_native_lm_feature(cls, signature, field_name, lm, lm_kwargs) -> type:
    """Drop the citations field from the signature when the LM is an Anthropic model.

    Args:
        signature: The DSPy signature class for the LM call.
        field_name: The name of the citations output field in `signature`.
        lm: The LM instance; only `lm.model` is inspected.
        lm_kwargs: The keyword arguments for the LM call. Not read or modified here.

    Returns:
        The signature class with `field_name` deleted if `lm.model` starts with "anthropic/", otherwise the
        original signature unchanged. (The return value is a signature class, not a bool.)
    """
    if lm.model.startswith("anthropic/"):
        return signature.delete(field_name)
    return signature
175+
170176
@classmethod
171177
def is_streamable(cls) -> bool:
172178
"""Whether the Citations type is streamable."""

dspy/adapters/types/reasoning.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from typing import TYPE_CHECKING, Any, Optional
2+
3+
import litellm
4+
import pydantic
5+
6+
from dspy.adapters.types.base_type import Type
7+
8+
if TYPE_CHECKING:
9+
from dspy.clients.lm import LM
10+
from dspy.signatures.signature import Signature
11+
12+
13+
class Reasoning(Type):
14+
"""Reasoning type in DSPy.
15+
16+
This type is useful when you want the DSPy output to include the reasoning of the LM. We build this type so that
17+
DSPy can support the reasoning model and non-reasoning model with the same code.
18+
19+
This is a str-like type, you can convert a string directly to a Reasoning object, and from DSPy adapters'
20+
perspective, `Reasoning` is treated as a string.
21+
"""
22+
23+
content: str
24+
25+
def format(self):
26+
return f"{self.content}"
27+
28+
@pydantic.model_validator(mode="before")
29+
@classmethod
30+
def validate_input(cls, data: Any):
31+
if isinstance(data, cls):
32+
return data
33+
34+
if isinstance(data, str):
35+
return {"content": data}
36+
37+
if isinstance(data, dict):
38+
if "content" not in data:
39+
raise ValueError("`content` field is required for `dspy.Reasoning`")
40+
if not isinstance(data["content"], str):
41+
raise ValueError(f"`content` field must be a string, but received type: {type(data['content'])}")
42+
return {"content": data["content"]}
43+
44+
raise ValueError(f"Received invalid value for `dspy.Reasoning`: {data}")
45+
46+
@classmethod
47+
def adapt_to_native_lm_feature(
48+
cls,
49+
signature: type["Signature"],
50+
field_name: str,
51+
lm: "LM",
52+
lm_kwargs: dict[str, Any],
53+
) -> type["Signature"]:
54+
if "reasoning_effort" in lm_kwargs:
55+
# `lm_kwargs` overrides `lm.kwargs`.
56+
reasoning_effort = lm_kwargs["reasoning_effort"]
57+
elif "reasoning_effort" in lm.kwargs:
58+
reasoning_effort = lm.kwargs["reasoning_effort"]
59+
else:
60+
# Turn on the native reasoning explicitly if Reasoning field is present in the signature and no explicit
61+
# reasoning effort is set in `lm_kwargs` or `lm.kwargs`.
62+
reasoning_effort = "low"
63+
64+
if reasoning_effort is None or not litellm.supports_reasoning(lm.model):
65+
# If users explicitly set `reasoning_effort` to None or the LM doesn't support reasoning, we don't enable
66+
# native reasoning.
67+
return signature
68+
69+
if "gpt-5" in lm.model and lm.model_type == "chat":
70+
# There is a caveat of Litellm as 1.79.0 that when using the chat completion API on GPT-5 family models,
71+
# the reasoning content is not available in the response. As a workaround, we don't enable the native
72+
# reasoning feature for GPT-5 family models when using the chat completion API.
73+
# Litellm issue: https://github.com/BerriAI/litellm/issues/14748
74+
return signature
75+
76+
lm_kwargs["reasoning_effort"] = reasoning_effort
77+
# Delete the reasoning field from the signature to use the native reasoning feature.
78+
return signature.delete(field_name)
79+
80+
@classmethod
81+
def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Reasoning"]:
82+
"""Parse the LM response into a Reasoning object."""
83+
if "reasoning_content" in response:
84+
return Reasoning(content=response["reasoning_content"])
85+
return None
86+
87+
@classmethod
88+
def parse_stream_chunk(cls, chunk) -> str | None:
89+
"""
90+
Parse a stream chunk into reasoning content if available.
91+
92+
Args:
93+
chunk: A stream chunk from the LM.
94+
95+
Returns:
96+
The reasoning content (str) if available, None otherwise.
97+
"""
98+
try:
99+
if choices := getattr(chunk, "choices", None):
100+
return getattr(choices[0].delta, "reasoning_content", None)
101+
except Exception:
102+
return None
103+
104+
@classmethod
105+
def is_streamable(cls) -> bool:
106+
return True
107+
108+
def __repr__(self) -> str:
109+
return f"{self.content!r}"
110+
111+
def __str__(self) -> str:
112+
return self.content
113+
114+
def __eq__(self, other: object) -> bool:
115+
if isinstance(other, Reasoning):
116+
return self.content == other.content
117+
if isinstance(other, str):
118+
return self.content == other

dspy/adapters/utils.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from pydantic.fields import FieldInfo
1313

1414
from dspy.adapters.types.base_type import Type as DspyType
15+
from dspy.adapters.types.reasoning import Reasoning
1516
from dspy.signatures.utils import get_dspy_field_type
1617

1718

@@ -84,7 +85,7 @@ def move_type_to_front(d):
8485
def translate_field_type(field_name, field_info):
8586
field_type = field_info.annotation
8687

87-
if get_dspy_field_type(field_info) == "input" or field_type is str:
88+
if get_dspy_field_type(field_info) == "input" or field_type is str or field_type is Reasoning:
8889
desc = ""
8990
elif field_type is bool:
9091
desc = "must be True or False"
@@ -190,6 +191,10 @@ def get_annotation_name(annotation):
190191
origin = get_origin(annotation)
191192
args = get_args(annotation)
192193
if origin is None:
194+
if annotation is Reasoning:
195+
# Keep backward compatibility with the old behavior in `dspy.ChainOfThought`, where reasoning
196+
# field type is treated as a string.
197+
return "str"
193198
if hasattr(annotation, "__name__"):
194199
return annotation.__name__
195200
else:

dspy/clients/base_lm.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,10 @@ def _process_completion(self, response, merged_kwargs):
204204
for c in response.choices:
205205
output = {}
206206
output["text"] = c.message.content if hasattr(c, "message") else c["text"]
207+
208+
if hasattr(c, "message") and hasattr(c.message, "reasoning_content") and c.message.reasoning_content:
209+
output["reasoning_content"] = c.message.reasoning_content
210+
207211
if merged_kwargs.get("logprobs"):
208212
output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
209213
if hasattr(c, "message") and getattr(c.message, "tool_calls", None):
@@ -219,7 +223,6 @@ def _process_completion(self, response, merged_kwargs):
219223
if all(len(output) == 1 for output in outputs):
220224
# Return a list if every output only has "text" key
221225
outputs = [output["text"] for output in outputs]
222-
223226
return outputs
224227

225228
def _extract_citations_from_response(self, choice):

dspy/clients/lm.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,10 @@ def _convert_chat_request_to_responses_request(request: dict[str, Any]):
493493
for item in c:
494494
content_blocks.append(_convert_content_item_to_responses_format(item))
495495
request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}]
496+
# Convert `reasoning_effort` to reasoning format supported by the Responses API
497+
if "reasoning_effort" in request:
498+
effort = request.pop("reasoning_effort")
499+
request["reasoning"] = {"effort": effort, "summary": "auto"}
496500

497501
# Convert `response_format` to `text.format` for Responses API
498502
if "response_format" in request:

0 commit comments

Comments
 (0)