Skip to content

Commit ade853a

Browse files
authored
Merge pull request #750 from stanfordnlp/custom-rationale-typed-cot
Add Custom Rationale Type option in TypedChainOfThought Module
2 parents 6d8968a + 8bf4aa8 commit ade853a

File tree

2 files changed

+51
-5
lines changed

2 files changed

+51
-5
lines changed

dspy/functional/functional.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -67,17 +67,21 @@ def __init__(self):
6767
self.__dict__[name] = attr.copy()
6868

6969

70-
def TypedChainOfThought(signature, instructions=None, *, max_retries=3) -> dspy.Module: # noqa: N802
70+
def TypedChainOfThought(signature, instructions=None, reasoning=None, *, max_retries=3) -> dspy.Module: # noqa: N802
7171
"""Just like TypedPredictor, but adds a ChainOfThought OutputField."""
7272
signature = ensure_signature(signature, instructions)
7373
output_keys = ", ".join(signature.output_fields.keys())
74+
75+
DEFAULT_RATIONALE = dspy.OutputField(
76+
prefix="Reasoning: Let's think step by step in order to",
77+
desc="${produce the " + output_keys + "}. We ...",
78+
)
79+
reasoning = reasoning or DEFAULT_RATIONALE
80+
7481
return TypedPredictor(
7582
signature.prepend(
7683
"reasoning",
77-
dspy.OutputField(
78-
prefix="Reasoning: Let's think step by step in order to",
79-
desc="${produce the " + output_keys + "}. We ...",
80-
),
84+
reasoning,
8185
),
8286
max_retries=max_retries,
8387
)

tests/functional/test_functional.py

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -619,6 +619,48 @@ class ScoredSignature(dspy.Signature):
619619
Proposed Signature: Output""")
620620

621621

622+
def test_custom_reasoning_field():
623+
class Question(pydantic.BaseModel):
624+
value: str
625+
626+
class QuestionSignature(dspy.Signature):
627+
topic: str = dspy.InputField()
628+
question: Question = dspy.OutputField()
629+
630+
reasoning = dspy.OutputField(
631+
prefix="Custom Reasoning: Let's break this down. To generate a question about",
632+
desc="${topic}, we should ...",
633+
)
634+
635+
program = TypedChainOfThought(QuestionSignature, reasoning=reasoning)
636+
637+
expected = "What is the speed of light?"
638+
lm = DummyLM(["Thoughts", f'{{"value": "{expected}"}}'])
639+
dspy.settings.configure(lm=lm)
640+
641+
output = program(topic="Physics")
642+
643+
assert isinstance(output.question, Question)
644+
assert output.question.value == expected
645+
646+
assert lm.get_convo(-1) == textwrap.dedent("""\
647+
Given the fields `topic`, produce the fields `question`.
648+
649+
---
650+
651+
Follow the following format.
652+
653+
Topic: ${topic}
654+
Custom Reasoning: Let's break this down. To generate a question about ${topic}, we should ...
655+
Question: ${question}. Respond with a single JSON object. JSON Schema: {"properties": {"value": {"title": "Value", "type": "string"}}, "required": ["value"], "title": "Question", "type": "object"}
656+
657+
---
658+
659+
Topic: Physics
660+
Custom Reasoning: Let's break this down. To generate a question about Thoughts
661+
Question: {"value": "What is the speed of light?"}""")
662+
663+
622664
def test_generic_signature():
623665
T = TypeVar("T")
624666

0 commit comments

Comments
 (0)