diff --git a/dspy/__init__.py b/dspy/__init__.py
index 01aad36757..3de46932b8 100644
--- a/dspy/__init__.py
+++ b/dspy/__init__.py
@@ -6,7 +6,7 @@
 from dspy.evaluate import Evaluate  # isort: skip
 from dspy.clients import *  # isort: skip
-from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, File, History, Type, Tool, ToolCalls, Code  # isort: skip
+from dspy.adapters import Adapter, ChatAdapter, JSONAdapter, XMLAdapter, TwoStepAdapter, Image, Audio, File, History, Type, Tool, ToolCalls, Code, Reasoning  # isort: skip
 from dspy.utils.logging_utils import configure_dspy_loggers, disable_logging, enable_logging
 from dspy.utils.asyncify import asyncify
 from dspy.utils.syncify import syncify
diff --git a/dspy/adapters/__init__.py b/dspy/adapters/__init__.py
index fded25398f..c217d7260e 100644
--- a/dspy/adapters/__init__.py
+++ b/dspy/adapters/__init__.py
@@ -2,7 +2,7 @@
 from dspy.adapters.chat_adapter import ChatAdapter
 from dspy.adapters.json_adapter import JSONAdapter
 from dspy.adapters.two_step_adapter import TwoStepAdapter
-from dspy.adapters.types import Audio, Code, File, History, Image, Tool, ToolCalls, Type
+from dspy.adapters.types import Audio, Code, File, History, Image, Reasoning, Tool, ToolCalls, Type
 from dspy.adapters.xml_adapter import XMLAdapter
 
 __all__ = [
@@ -19,4 +19,5 @@
     "TwoStepAdapter",
     "Tool",
     "ToolCalls",
+    "Reasoning",
 ]
diff --git a/dspy/adapters/base.py b/dspy/adapters/base.py
index 555164eb70..8696697d3a 100644
--- a/dspy/adapters/base.py
+++ b/dspy/adapters/base.py
@@ -6,6 +6,7 @@
 from dspy.adapters.types import History, Type
 from dspy.adapters.types.base_type import split_message_content_for_custom_types
+from dspy.adapters.types.reasoning import Reasoning
 from dspy.adapters.types.tool import Tool, ToolCalls
 from dspy.experimental import Citations
 from dspy.signatures.signature import Signature
@@ -16,7 +17,7 @@
 if TYPE_CHECKING:
     from dspy.clients.lm import LM
 
-_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations]
+_DEFAULT_NATIVE_RESPONSE_TYPES = [Citations, Reasoning]
 
 
 class Adapter:
@@ -99,14 +100,14 @@ def _call_preprocess(
 
             return signature_for_native_function_calling
 
-        # Handle custom types that use native response
+        # Handle custom types that use native LM features, e.g., reasoning, citations, etc.
        for name, field in signature.output_fields.items():
            if (
                isinstance(field.annotation, type)
                and issubclass(field.annotation, Type)
                and field.annotation in self.native_response_types
            ):
-                signature = signature.delete(name)
+                signature = field.annotation.adapt_to_native_lm_feature(signature, name, lm, lm_kwargs)
 
         return signature
 
@@ -116,6 +117,7 @@ def _call_postprocess(
         original_signature: type[Signature],
         outputs: list[dict[str, Any] | str],
         lm: "LM",
+        lm_kwargs: dict[str, Any],
     ) -> list[dict[str, Any]]:
         values = []
 
@@ -152,14 +154,16 @@ def _call_postprocess(
                     ]
                     value[tool_call_output_field_name] = ToolCalls.from_dict_list(tool_calls)
 
-            # Parse custom types that does not rely on the adapter parsing
+            # Parse custom types that do not rely on the `Adapter.parse()` method
             for name, field in original_signature.output_fields.items():
                 if (
                     isinstance(field.annotation, type)
                     and issubclass(field.annotation, Type)
                     and field.annotation in self.native_response_types
                 ):
-                    value[name] = field.annotation.parse_lm_response(output)
+                    parsed_value = field.annotation.parse_lm_response(output)
+                    if parsed_value is not None:
+                        value[name] = parsed_value
 
             if output_logprobs:
                 value["logprobs"] = output_logprobs
@@ -196,7 +200,7 @@ def __call__(
         inputs = self.format(processed_signature, demos, inputs)
         outputs = lm(messages=inputs, **lm_kwargs)
 
-        return self._call_postprocess(processed_signature, signature, outputs, lm)
+        return self._call_postprocess(processed_signature, signature, outputs, lm, lm_kwargs)
 
     async def acall(
         self,
@@ -210,7 +214,7 @@ async def acall(
         inputs = self.format(processed_signature, demos, inputs)
         outputs = await lm.acall(messages=inputs, **lm_kwargs)
 
-        return self._call_postprocess(processed_signature, signature, outputs, lm)
+        return self._call_postprocess(processed_signature, signature, outputs, lm, lm_kwargs)
 
     def format(
         self,
diff --git a/dspy/adapters/types/__init__.py b/dspy/adapters/types/__init__.py
index eb1481c862..5ec8043021 100644
--- a/dspy/adapters/types/__init__.py
+++ b/dspy/adapters/types/__init__.py
@@ -4,6 +4,7 @@
 from dspy.adapters.types.file import File
 from dspy.adapters.types.history import History
 from dspy.adapters.types.image import Image
+from dspy.adapters.types.reasoning import Reasoning
 from dspy.adapters.types.tool import Tool, ToolCalls
 
-__all__ = ["History", "Image", "Audio", "File", "Type", "Tool", "ToolCalls", "Code"]
+__all__ = ["History", "Image", "Audio", "File", "Type", "Tool", "ToolCalls", "Code", "Reasoning"]
diff --git a/dspy/adapters/types/base_type.py b/dspy/adapters/types/base_type.py
index b7004de537..afcacb47ad 100644
--- a/dspy/adapters/types/base_type.py
+++ b/dspy/adapters/types/base_type.py
@@ -1,11 +1,15 @@
 import json
 import re
-from typing import Any, Optional, get_args, get_origin
+from typing import TYPE_CHECKING, Any, Optional, get_args, get_origin
 
 import json_repair
 import pydantic
 from litellm import ModelResponseStream
 
+if TYPE_CHECKING:
+    from dspy.clients.lm import LM
+    from dspy.signatures.signature import Signature
+
 CUSTOM_TYPE_START_IDENTIFIER = "<>"
 CUSTOM_TYPE_END_IDENTIFIER = "<>"
 
@@ -70,6 +74,31 @@ def serialize_model(self):
         )
         return formatted
 
+    @classmethod
+    def adapt_to_native_lm_feature(
+        cls,
+        signature: type["Signature"],
+        field_name: str,
+        lm: "LM",
+        lm_kwargs: dict[str, Any],
+    ) -> type["Signature"]:
+        """Adapt the custom type to the native LM feature if possible.
+
+        When the LM and its configuration support the related native LM feature (e.g., native tool calling or
+        native reasoning), we adapt the signature and `lm_kwargs` to enable that feature.
+
+        Args:
+            signature: The DSPy signature for the LM call.
+            field_name: The name of the field in the signature to adapt to the native LM feature.
+            lm: The LM instance.
+            lm_kwargs: The keyword arguments for the LM call, updated in place if adaptation is required.
+
+        Returns:
+            The adapted signature. If the custom type is not natively supported by the LM, the original signature
+            is returned unchanged.
+        """
+        return signature
+
     @classmethod
     def is_streamable(cls) -> bool:
         """Whether the custom type is streamable."""
diff --git a/dspy/adapters/types/citation.py b/dspy/adapters/types/citation.py
index 4268f194da..e6ca01b2c4 100644
--- a/dspy/adapters/types/citation.py
+++ b/dspy/adapters/types/citation.py
@@ -167,6 +167,12 @@ def __getitem__(self, index):
         """Allow indexing into citations."""
         return self.citations[index]
 
+    @classmethod
+    def adapt_to_native_lm_feature(cls, signature, field_name, lm, lm_kwargs):
+        if lm.model.startswith("anthropic/"):
+            return signature.delete(field_name)
+        return signature
+
     @classmethod
     def is_streamable(cls) -> bool:
         """Whether the Citations type is streamable."""
diff --git a/dspy/adapters/types/reasoning.py b/dspy/adapters/types/reasoning.py
new file mode 100644
index 0000000000..6f468548e1
--- /dev/null
+++ b/dspy/adapters/types/reasoning.py
@@ -0,0 +1,118 @@
+from typing import TYPE_CHECKING, Any, Optional
+
+import litellm
+import pydantic
+
+from dspy.adapters.types.base_type import Type
+
+if TYPE_CHECKING:
+    from dspy.clients.lm import LM
+    from dspy.signatures.signature import Signature
+
+
+class Reasoning(Type):
+    """Reasoning type in DSPy.
+
+    This type is useful when you want the DSPy output to include the LM's reasoning. It lets the same DSPy code
+    support both reasoning and non-reasoning models.
+
+    This is a str-like type: a plain string can be converted directly into a Reasoning object, and from the DSPy
+    adapters' perspective, `Reasoning` is treated as a string.
+    """
+
+    content: str
+
+    def format(self):
+        return f"{self.content}"
+
+    @pydantic.model_validator(mode="before")
+    @classmethod
+    def validate_input(cls, data: Any):
+        if isinstance(data, cls):
+            return data
+
+        if isinstance(data, str):
+            return {"content": data}
+
+        if isinstance(data, dict):
+            if "content" not in data:
+                raise ValueError("`content` field is required for `dspy.Reasoning`")
+            if not isinstance(data["content"], str):
+                raise ValueError(f"`content` field must be a string, but received type: {type(data['content'])}")
+            return {"content": data["content"]}
+
+        raise ValueError(f"Received invalid value for `dspy.Reasoning`: {data}")
+
+    @classmethod
+    def adapt_to_native_lm_feature(
+        cls,
+        signature: type["Signature"],
+        field_name: str,
+        lm: "LM",
+        lm_kwargs: dict[str, Any],
+    ) -> type["Signature"]:
+        if "reasoning_effort" in lm_kwargs:
+            # `lm_kwargs` overrides `lm.kwargs`.
+            reasoning_effort = lm_kwargs["reasoning_effort"]
+        elif "reasoning_effort" in lm.kwargs:
+            reasoning_effort = lm.kwargs["reasoning_effort"]
+        else:
+            # Turn on native reasoning by default when a `Reasoning` field is present in the signature and no
+            # reasoning effort is set in `lm_kwargs` or `lm.kwargs`.
+            reasoning_effort = "low"
+
+        if reasoning_effort is None or not litellm.supports_reasoning(lm.model):
+            # If users explicitly set `reasoning_effort` to None or the LM doesn't support reasoning, we don't enable
+            # native reasoning.
+            return signature
+
+        if "gpt-5" in lm.model and lm.model_type == "chat":
+            # As of LiteLLM 1.79.0, the reasoning content is not available in the response when using the chat
+            # completion API with GPT-5 family models. As a workaround, we don't enable the native reasoning feature
+            # for GPT-5 family models when using the chat completion API.
+            # LiteLLM issue: https://github.com/BerriAI/litellm/issues/14748
+            return signature
+
+        lm_kwargs["reasoning_effort"] = reasoning_effort
+        # Delete the reasoning field from the signature to use the native reasoning feature.
+        return signature.delete(field_name)
+
+    @classmethod
+    def parse_lm_response(cls, response: str | dict[str, Any]) -> Optional["Reasoning"]:
+        """Parse the LM response into a Reasoning object."""
+        if "reasoning_content" in response:
+            return Reasoning(content=response["reasoning_content"])
+        return None
+
+    @classmethod
+    def parse_stream_chunk(cls, chunk) -> str | None:
+        """
+        Parse a stream chunk into reasoning content if available.
+
+        Args:
+            chunk: A stream chunk from the LM.
+
+        Returns:
+            The reasoning content (str) if available, None otherwise.
+        """
+        try:
+            if choices := getattr(chunk, "choices", None):
+                return getattr(choices[0].delta, "reasoning_content", None)
+        except Exception:
+            return None
+
+    @classmethod
+    def is_streamable(cls) -> bool:
+        return True
+
+    def __repr__(self) -> str:
+        return f"{self.content!r}"
+
+    def __str__(self) -> str:
+        return self.content
+
+    def __eq__(self, other: object) -> bool:
+        if isinstance(other, Reasoning):
+            return self.content == other.content
+        if isinstance(other, str):
+            return self.content == other
+        return NotImplemented
diff --git a/dspy/adapters/utils.py b/dspy/adapters/utils.py
index f38a77ee8a..fc87d2d6e0 100644
--- a/dspy/adapters/utils.py
+++ b/dspy/adapters/utils.py
@@ -12,6 +12,7 @@
 from pydantic.fields import FieldInfo
 
 from dspy.adapters.types.base_type import Type as DspyType
+from dspy.adapters.types.reasoning import Reasoning
 from dspy.signatures.utils import get_dspy_field_type
 
 
@@ -84,7 +85,7 @@ def move_type_to_front(d):
 def translate_field_type(field_name, field_info):
     field_type = field_info.annotation
 
-    if get_dspy_field_type(field_info) == "input" or field_type is str:
+    if get_dspy_field_type(field_info) == "input" or field_type is str or field_type is Reasoning:
         desc = ""
     elif field_type is bool:
         desc = "must be True or False"
@@ -190,6 +191,10 @@ def get_annotation_name(annotation):
     origin = get_origin(annotation)
     args = get_args(annotation)
     if origin is None:
+        if annotation is Reasoning:
+            # Keep backward compatibility with the old behavior of `dspy.ChainOfThought`, where the reasoning
+            # field type is treated as a string.
+            return "str"
         if hasattr(annotation, "__name__"):
             return annotation.__name__
         else:
diff --git a/dspy/clients/base_lm.py b/dspy/clients/base_lm.py
index 89f453f8a6..e8b893a09d 100644
--- a/dspy/clients/base_lm.py
+++ b/dspy/clients/base_lm.py
@@ -204,6 +204,10 @@ def _process_completion(self, response, merged_kwargs):
         for c in response.choices:
             output = {}
             output["text"] = c.message.content if hasattr(c, "message") else c["text"]
+
+            if hasattr(c, "message") and hasattr(c.message, "reasoning_content") and c.message.reasoning_content:
+                output["reasoning_content"] = c.message.reasoning_content
+
             if merged_kwargs.get("logprobs"):
                 output["logprobs"] = c.logprobs if hasattr(c, "logprobs") else c["logprobs"]
             if hasattr(c, "message") and getattr(c.message, "tool_calls", None):
@@ -219,7 +223,6 @@
         if all(len(output) == 1 for output in outputs):
             # Return a list if every output only has "text" key
             outputs = [output["text"] for output in outputs]
-
         return outputs
 
     def _extract_citations_from_response(self, choice):
diff --git a/dspy/clients/lm.py b/dspy/clients/lm.py
index 81d968703e..f4885d470e 100644
--- a/dspy/clients/lm.py
+++ b/dspy/clients/lm.py
@@ -493,6 +493,10 @@ def _convert_chat_request_to_responses_request(request: dict[str, Any]):
         for item in c:
             content_blocks.append(_convert_content_item_to_responses_format(item))
         request["input"] = [{"role": msg.get("role", "user"), "content": content_blocks}]
+    # Convert `reasoning_effort` to reasoning format supported by the Responses API
+    if "reasoning_effort" in request:
+        effort = request.pop("reasoning_effort")
+        request["reasoning"] = {"effort": effort, "summary": "auto"}
 
     # Convert `response_format` to `text.format` for Responses API
     if "response_format" in request:
diff --git a/dspy/streaming/streaming_listener.py b/dspy/streaming/streaming_listener.py
index 819ac2fd53..1ea93c2fd8 100644
--- a/dspy/streaming/streaming_listener.py
+++ b/dspy/streaming/streaming_listener.py
@@ -134,13 +134,6 @@ def receive(self, chunk: ModelResponseStream):
             else:
                 return
 
-        try:
-            chunk_message = chunk.choices[0].delta.content
-            if chunk_message is None:
-                return
-        except Exception:
-            return
-
         # Handle custom streamable types
         if self._output_type and issubclass(self._output_type, Type) and self._output_type.is_streamable():
             if parsed_chunk := self._output_type.parse_stream_chunk(chunk):
@@ -151,6 +144,14 @@
                     is_last_chunk=self.stream_end,
                 )
 
+        # For non-custom streamable types, the streaming chunks come from the content field of the ModelResponseStream.
+        try:
+            chunk_message = chunk.choices[0].delta.content
+            if chunk_message is None:
+                return
+        except Exception:
+            return
+
         if chunk_message and start_identifier in chunk_message and not isinstance(settings.adapter, JSONAdapter):
             # If the cache is hit, the chunk_message could be the full response. When it happens we can
            # directly end the stream listening.
In some models like gemini, each stream chunk can be multiple diff --git a/tests/adapters/test_chat_adapter.py b/tests/adapters/test_chat_adapter.py index 1c41f6346a..bbb7287572 100644 --- a/tests/adapters/test_chat_adapter.py +++ b/tests/adapters/test_chat_adapter.py @@ -611,6 +611,44 @@ def get_weather(city: str) -> str: ) +def test_chat_adapter_native_reasoning(): + class MySignature(dspy.Signature): + question: str = dspy.InputField() + reasoning: dspy.Reasoning = dspy.OutputField() + answer: str = dspy.OutputField() + + adapter = dspy.ChatAdapter() + + with mock.patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[ + Choices( + message=Message( + content="[[ ## answer ## ]]\nParis\n[[ ## completion ## ]]", + reasoning_content="Step-by-step thinking about the capital of France", + ), + ) + ], + model="anthropic/claude-3-7-sonnet-20250219", + ) + modified_signature = adapter._call_preprocess( + dspy.LM(model="anthropic/claude-3-7-sonnet-20250219", reasoning_effort="low", cache=False), + {}, + MySignature, + {"question": "What is the capital of France?"}, + ) + assert "reasoning" not in modified_signature.output_fields + + result = adapter( + dspy.LM(model="anthropic/claude-3-7-sonnet-20250219", reasoning_effort="low", cache=False), + {}, + MySignature, + [], + {"question": "What is the capital of France?"}, + ) + assert result[0]["reasoning"] == dspy.Reasoning(content="Step-by-step thinking about the capital of France") + + def test_format_system_message(): class MySignature(dspy.Signature): """Answer the question with multiple answers and scores""" diff --git a/tests/adapters/test_citation.py b/tests/adapters/test_citation.py index d1fdc8d084..3c2a7f75e6 100644 --- a/tests/adapters/test_citation.py +++ b/tests/adapters/test_citation.py @@ -155,7 +155,8 @@ class CitationSignature(Signature): CitationSignature.delete("citations"), CitationSignature, outputs, - dspy.LM(model="claude-3-5-sonnet-20241022") + dspy.LM(model="anthropic/claude-3-5-sonnet-20241022"), + lm_kwargs={}, ) assert len(result) == 1 diff --git a/tests/adapters/test_json_adapter.py b/tests/adapters/test_json_adapter.py index 2acd4e63c3..373c01c67e 100644 --- a/tests/adapters/test_json_adapter.py +++ b/tests/adapters/test_json_adapter.py @@ -887,6 +887,44 @@ def get_weather(city: str) -> str: assert call_kwargs["response_format"] == {"type": "json_object"} +def test_json_adapter_native_reasoning(): + class MySignature(dspy.Signature): + question: str = dspy.InputField() + reasoning: dspy.Reasoning = dspy.OutputField() + answer: str = dspy.OutputField() + + adapter = dspy.JSONAdapter() + + with mock.patch("litellm.completion") as mock_completion: + mock_completion.return_value = ModelResponse( + choices=[ + Choices( + message=Message( + content="{'answer': 'Paris'}", + reasoning_content="Step-by-step thinking about the capital of France", + ), + ) + ], + model="anthropic/claude-3-7-sonnet-20250219", + ) + modified_signature = adapter._call_preprocess( + dspy.LM(model="anthropic/claude-3-7-sonnet-20250219", reasoning_effort="low", cache=False), + {}, + MySignature, + {"question": "What is the capital of France?"}, + ) + assert "reasoning" not in modified_signature.output_fields + + result = adapter( + dspy.LM(model="anthropic/claude-3-7-sonnet-20250219", reasoning_effort="low", cache=False), + {}, + MySignature, + [], + {"question": "What is the capital of France?"}, + ) + assert result[0]["reasoning"] == dspy.Reasoning(content="Step-by-step thinking about the capital of 
France") + + def test_json_adapter_with_responses_api(): class TestSignature(dspy.Signature): question: str = dspy.InputField() diff --git a/tests/clients/test_lm.py b/tests/clients/test_lm.py index 3a228286f2..65bf63496c 100644 --- a/tests/clients/test_lm.py +++ b/tests/clients/test_lm.py @@ -623,6 +623,67 @@ def test_responses_api_tool_calls(litellm_test_server): assert dspy_responses.call_args.kwargs["model"] == "openai/dspy-test-model" +def test_reasoning_effort_responses_api(): + """Test that reasoning_effort gets normalized to reasoning format for Responses API.""" + with mock.patch("litellm.responses") as mock_responses: + # OpenAI model with Responses API - should normalize + lm = dspy.LM( + model="openai/gpt-5", model_type="responses", reasoning_effort="low", max_tokens=16000, temperature=1.0 + ) + lm("openai query") + call_kwargs = mock_responses.call_args.kwargs + assert "reasoning_effort" not in call_kwargs + assert call_kwargs["reasoning"] == {"effort": "low", "summary": "auto"} + + +def test_call_reasoning_model_with_chat_api(): + """Test that Chat API properly handles reasoning models and returns data in correct format.""" + # Create message with reasoning_content attribute + message = Message(content="The answer is 4", role="assistant") + # Add reasoning_content attribute + message.reasoning_content = "Step 1: I need to add 2 + 2\nStep 2: 2 + 2 = 4\nTherefore, the answer is 4" + + # Create choice with the message + mock_choice = Choices(message=message) + + # Mock response with reasoning content for chat completion + mock_response = ModelResponse( + choices=[mock_choice], + model="anthropic/claude-3-7-sonnet-20250219", + usage={"prompt_tokens": 10, "completion_tokens": 20, "total_tokens": 30}, + ) + + with mock.patch("litellm.completion", return_value=mock_response) as mock_completion: + with mock.patch("litellm.supports_reasoning", return_value=True): + # Create reasoning model with chat API + lm = dspy.LM( + model="anthropic/claude-3-7-sonnet-20250219", + model_type="chat", + temperature=1.0, + max_tokens=16000, + reasoning_effort="low", + cache=False, + ) + + # Test the call + result = lm("What is 2 + 2?") + + # Verify the response format + assert isinstance(result, list) + assert len(result) == 1 + assert isinstance(result[0], dict) + assert "text" in result[0] + assert "reasoning_content" in result[0] + assert result[0]["text"] == "The answer is 4" + assert "Step 1" in result[0]["reasoning_content"] + + # Verify mock was called with correct parameters + mock_completion.assert_called_once() + call_kwargs = mock_completion.call_args.kwargs + assert call_kwargs["model"] == "anthropic/claude-3-7-sonnet-20250219" + assert call_kwargs["reasoning_effort"] == "low" + + def test_api_key_not_saved_in_json(): lm = dspy.LM( model="openai/gpt-4o-mini", diff --git a/tests/streaming/test_streaming.py b/tests/streaming/test_streaming.py index a1e75be5dd..ae5ea57843 100644 --- a/tests/streaming/test_streaming.py +++ b/tests/streaming/test_streaming.py @@ -974,6 +974,10 @@ class CustomType(Type): def is_streamable(cls) -> bool: return True + @classmethod + def adapt_to_native_lm_feature(cls, signature, field_name, lm, lm_kwargs): + return signature.delete(field_name) + @classmethod def parse_stream_chunk(cls, chunk): return CustomType(message=chunk.choices[0].delta.content) @@ -1515,3 +1519,294 @@ def test_stream_listener_could_form_end_identifier_xml_adapter(): # Should return False for text that cannot form the pattern assert listener._could_form_end_identifier("hello world", 
"XMLAdapter") is False assert listener._could_form_end_identifier("some text", "XMLAdapter") is False + + +@pytest.mark.anyio +async def test_streaming_reasoning_model(): + """Test streaming behavior for reasoning-capable models using dspy.Reasoning. + + This test verifies that: + 1. Reasoning content is extracted from delta.reasoning_content in stream chunks + 2. Reasoning chunks are streamed independently from regular content + 3. The final prediction contains a Reasoning object with the full reasoning content + """ + + class ReasoningSignature(dspy.Signature): + question: str = dspy.InputField() + reasoning: dspy.Reasoning = dspy.OutputField() + answer: str = dspy.OutputField() + + class MyProgram(dspy.Module): + def __init__(self): + super().__init__() + self.predict = dspy.Predict(ReasoningSignature) + + def forward(self, question, **kwargs): + return self.predict(question=question, **kwargs) + + async def reasoning_stream(*args, **kwargs): + """Simulate streaming from a reasoning model like Claude 3.7 Sonnet""" + # Reasoning content comes through delta.reasoning_content + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[ + StreamingChoices(delta=Delta(reasoning_content="First, let's think about this problem step by step. ")) + ], + ) + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[StreamingChoices(delta=Delta(reasoning_content="We need to consider the context of a kitchen. "))], + ) + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[ + StreamingChoices( + delta=Delta(reasoning_content="The chicken likely wants to reach something on the other side.") + ) + ], + ) + # Regular answer content comes through delta.content + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[StreamingChoices(delta=Delta(content="[[ ## answer ## ]]\n"))], + ) + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[StreamingChoices(delta=Delta(content="To"))], + ) + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[StreamingChoices(delta=Delta(content=" get"))], + ) + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[StreamingChoices(delta=Delta(content=" to"))], + ) + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[StreamingChoices(delta=Delta(content=" the"))], + ) + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[StreamingChoices(delta=Delta(content=" other"))], + ) + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[StreamingChoices(delta=Delta(content=" side"))], + ) + yield ModelResponseStream( + model="anthropic/claude-3-7-sonnet-20250219", + choices=[StreamingChoices(delta=Delta(content="!\n\n[[ ## completed ## ]]"))], + ) + + with mock.patch("litellm.acompletion", side_effect=reasoning_stream): + with mock.patch("litellm.supports_reasoning", return_value=True): + program = dspy.streamify( + MyProgram(), + stream_listeners=[ + dspy.streaming.StreamListener(signature_field_name="reasoning"), + dspy.streaming.StreamListener(signature_field_name="answer"), + ], + ) + with dspy.context( + lm=dspy.LM("anthropic/claude-3-7-sonnet-20250219", cache=False), + adapter=dspy.ChatAdapter(native_response_types=[dspy.Reasoning]), + ): + output = program(question="Why did a chicken cross the kitchen?") + reasoning_chunks = [] + answer_chunks = [] + 
final_prediction = None + async for value in output: + if isinstance(value, dspy.streaming.StreamResponse): + if value.signature_field_name == "reasoning": + reasoning_chunks.append(value) + elif value.signature_field_name == "answer": + answer_chunks.append(value) + elif isinstance(value, dspy.Prediction): + final_prediction = value + + # Verify reasoning chunks were streamed + assert len(reasoning_chunks) == 3 + assert reasoning_chunks[0].chunk == "First, let's think about this problem step by step. " + assert reasoning_chunks[1].chunk == "We need to consider the context of a kitchen. " + assert reasoning_chunks[2].chunk == "The chicken likely wants to reach something on the other side." + + # Verify answer chunks were streamed + assert len(answer_chunks) > 0 + assert answer_chunks[0].chunk == "To" + full_answer = "".join([chunk.chunk for chunk in answer_chunks]) + assert full_answer == "To get to the other side!" + + # Verify final prediction has Reasoning object + assert final_prediction is not None + assert hasattr(final_prediction, "reasoning") + assert isinstance(final_prediction.reasoning, dspy.Reasoning) + expected_reasoning = ( + "First, let's think about this problem step by step. " + "We need to consider the context of a kitchen. " + "The chicken likely wants to reach something on the other side." + ) + assert final_prediction.reasoning.content == expected_reasoning + + +@pytest.mark.anyio +async def test_streaming_reasoning_fallback(): + """Test fallback behavior for non-reasoning models using dspy.Reasoning. + + This test verifies that: + 1. For non-reasoning models, reasoning is treated as a regular string field + 2. Reasoning content is streamed through regular adapter parsing (not reasoning_content) + 3. The Reasoning object is created from the parsed string content + 4. Streaming behavior is identical to regular string fields + """ + + class ReasoningSignature(dspy.Signature): + question: str = dspy.InputField() + reasoning: dspy.Reasoning = dspy.OutputField() + answer: str = dspy.OutputField() + + class MyProgram(dspy.Module): + def __init__(self): + super().__init__() + self.predict = dspy.Predict(ReasoningSignature) + + def forward(self, question, **kwargs): + return self.predict(question=question, **kwargs) + + async def non_reasoning_stream(*args, **kwargs): + """Simulate streaming from a non-reasoning model like GPT-4o-mini. + + The reasoning field is formatted by the adapter as a regular field, + and content comes through delta.content (not reasoning_content). 
+ """ + # Reasoning field marker (ChatAdapter format) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content="[[ ## reasoning ## ]]\n"))], + ) + # Reasoning content as regular text + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content="Let"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content="'s"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" think"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" step"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" by"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" step"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" about"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" this"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" question"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content="."))], + ) + # Answer field marker + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## answer ## ]]\n"))], + ) + # Answer content + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content="To"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" get"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" to"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" the"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" other"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content=" side"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content="!"))], + ) + yield ModelResponseStream( + model="gpt-4o-mini", + choices=[StreamingChoices(delta=Delta(content="\n\n[[ ## completed ## ]]"))], + ) + + with mock.patch("litellm.acompletion", side_effect=non_reasoning_stream): + with mock.patch("litellm.supports_reasoning", return_value=False): + program = dspy.streamify( + MyProgram(), + stream_listeners=[ + dspy.streaming.StreamListener(signature_field_name="reasoning"), + dspy.streaming.StreamListener(signature_field_name="answer"), + ], + ) + with dspy.context( + lm=dspy.LM("openai/gpt-4o-mini", cache=False), + adapter=dspy.ChatAdapter(), + ): + output = program(question="Why did a chicken cross the kitchen?") + reasoning_chunks = [] + answer_chunks = [] + final_prediction = None + async for value in output: + if isinstance(value, dspy.streaming.StreamResponse): + if value.signature_field_name == "reasoning": + reasoning_chunks.append(value) + elif value.signature_field_name == "answer": + answer_chunks.append(value) + elif isinstance(value, dspy.Prediction): + final_prediction = value + + # Verify reasoning was streamed as regular text + assert len(reasoning_chunks) > 0 + assert reasoning_chunks[0].chunk == "Let" + assert 
reasoning_chunks[1].chunk == "'s" + full_reasoning = "".join([chunk.chunk for chunk in reasoning_chunks]) + assert full_reasoning == "Let's think step by step about this question." + + # Verify answer chunks were streamed + assert len(answer_chunks) > 0 + assert answer_chunks[0].chunk == "To" + full_answer = "".join([chunk.chunk for chunk in answer_chunks]) + assert full_answer == "To get to the other side!" + + # Verify final prediction has Reasoning object created from string + assert final_prediction is not None + assert hasattr(final_prediction, "reasoning") + assert isinstance(final_prediction.reasoning, dspy.Reasoning) + assert final_prediction.reasoning.content == "Let's think step by step about this question." + # Verify Reasoning object is str-like + assert str(final_prediction.reasoning) == "Let's think step by step about this question."
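
Usage sketch (not part of the diff): a minimal example of how the new `dspy.Reasoning` output type is intended to be used, mirroring the signatures exercised in the tests above. The model name, `reasoning_effort` value, and question are illustrative. With a reasoning-capable model, the adapter removes the `reasoning` field from the prompt, forwards `reasoning_effort` to the LM, and fills `reasoning` from the model's native `reasoning_content`; with a non-reasoning model, `reasoning` is prompted and parsed as a plain string field.

import dspy

class QA(dspy.Signature):
    question: str = dspy.InputField()
    reasoning: dspy.Reasoning = dspy.OutputField()
    answer: str = dspy.OutputField()

# Illustrative model id; `reasoning_effort` opts into native reasoning when the model supports it.
dspy.configure(lm=dspy.LM("anthropic/claude-3-7-sonnet-20250219", reasoning_effort="low"))

prediction = dspy.Predict(QA)(question="What is the capital of France?")
print(prediction.answer)              # e.g. "Paris"
print(str(prediction.reasoning))      # Reasoning is str-like
assert isinstance(prediction.reasoning, dspy.Reasoning)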