Skip to content

Commit 5f0aa61

Browse files
Merge pull request #147 from stanfordnlp/PoT
Program of Thought module
2 parents d83255d + 329bc1b commit 5f0aa61

File tree

5 files changed

+791
-9
lines changed

5 files changed

+791
-9
lines changed

dspy/predict/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,4 @@
44
from .chain_of_thought_with_hint import ChainOfThoughtWithHint
55
from .react import ReAct
66
from .aggregation import majority
7+
from .program_of_thought import ProgramOfThought

dspy/predict/program_of_thought.py

Lines changed: 114 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,114 @@
1+
import dsp
2+
import dspy
3+
from ..primitives.program import Module
4+
from ..primitives.python_interpreter import CodePrompt, PythonInterpreter
5+
import re
6+
7+
class ProgramOfThought(Module):
    """Answer a query by having the LM write Python code, running it, and
    then producing the final output field from the code and its output.

    Three ``dspy.ChainOfThought`` predictors are built from the given
    signature:

    * ``code_generate``   -- inputs -> ``generated_code``
    * ``code_regenerate`` -- inputs + ``previous_code`` + ``error``
      -> ``generated_code``
    * ``generate_answer`` -- inputs + ``final_generated_code`` +
      ``code_output`` -> the signature's single output field

    Args:
        signature: anything accepted by ``dspy.Predict``; it must declare
            exactly one output field.
        max_iters: maximum number of regeneration attempts after the first
            generated program fails to parse or execute.
    """

    def __init__(self, signature, max_iters=3):
        super().__init__()
        self.signature = signature = dspy.Predict(signature).signature
        self.max_iters = max_iters

        self.input_fields = signature.input_fields()
        self.output_fields = signature.output_fields()

        # Raised explicitly (rather than via ``assert``) so the check is not
        # stripped under ``python -O``; exception type and message are the
        # same as before.
        if len(self.output_fields) != 1:
            raise AssertionError("PoT only supports one output field.")

        self.code_generate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('generate'), **self._generate_signature('generate')))
        self.code_regenerate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('regenerate'), **self._generate_signature('regenerate')))
        self.generate_answer = dspy.ChainOfThought(dsp.Template(self._generate_instruction('answer'), **self._generate_signature('answer')))

    def _generate_signature(self, mode):
        """Return the field mapping for one of the three predictors.

        The mapping starts with the user's input fields and is extended with
        the extra fields for ``mode`` (``'generate'``, ``'regenerate'`` or
        ``'answer'``). Any other ``mode`` raises ``KeyError``.
        """
        # The constructor guarantees exactly one output field; use its name
        # instead of assuming it is called "answer".
        # NOTE(review): assumes every output field name is also a key of
        # ``self.signature.kwargs`` (the original code relied on this for
        # "answer") -- confirm.
        output_name = next(iter(self.output_fields))

        signature_dict = dict(self.input_fields)
        fields_for_mode = {
            'generate': {
                'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str)
            },
            'regenerate': {
                'previous_code': dspy.InputField(prefix="Previous Code:", desc="previously-generated python code that errored", format=str),
                'error': dspy.InputField(prefix="Error:", desc="error message from previously-generated python code"),
                'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str)
            },
            'answer': {
                'final_generated_code': dspy.InputField(prefix="Code:", desc="python code that answers the question", format=str),
                'code_output': dspy.InputField(prefix="Code Output:", desc="output of previously-generated python code"),
                output_name: self.signature.kwargs[output_name]
            }
        }
        signature_dict.update(fields_for_mode[mode])
        return signature_dict

    def _generate_instruction(self, mode):
        """Build the instruction text for the predictor of the given ``mode``.

        Field names are rendered in backticks and split into inputs and
        outputs according to whether the field is a ``dspy.InputField`` or a
        ``dspy.OutputField``.
        """
        # Build the field mapping once instead of once per field.
        fields = self._generate_signature(mode)
        mode_inputs = ', '.join(f"`{field_name}`" for field_name, field in fields.items() if isinstance(field, dspy.InputField))
        mode_outputs = ', '.join(f"`{field_name}`" for field_name, field in fields.items() if isinstance(field, dspy.OutputField))
        if mode == 'generate':
            instr = [
                f"You will be given {mode_inputs} and you will respond with {mode_outputs}.",
                f"Generating executable Python code that programmatically computes the correct {mode_outputs}.",
                f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}."
            ]
        elif mode == 'regenerate':
            instr = [
                f"You are given {mode_inputs} due to an error in previous code.",
                f"Your task is to correct the error and provide the new {mode_outputs}."
            ]
        else:  # mode == 'answer'
            instr = [
                f"Given the final code {mode_inputs}, provide the final {mode_outputs}."
            ]

        return '\n'.join(instr)

    def parse_code(self, code_data):
        """Extract a runnable code string from a code-generation result.

        Args:
            code_data: object exposing ``.get('generated_code', default)``.

        Returns:
            ``(code_block, None)`` on success, or ``(raw_code, message)``
            where ``raw_code`` is the truncated but otherwise unprocessed
            text and ``message`` describes why parsing failed.
        """
        # Keep only the text before the first '---' separator / triple newline.
        code = code_data.get('generated_code', '').split('---', 1)[0].split('\n\n\n', 1)[0]
        # Prefer the contents of a ```python fenced block when one is present.
        code_match = re.search(r'```python[ \n](.*?)[ \n]```?', code, re.DOTALL)
        # Literal backslash-n sequences are turned into real newlines.
        code_block = (code_match.group(1) if code_match else code).replace('\\n', '\n')
        if not code_block:
            return code, "Error: Empty code after parsing."
        if "\n" not in code_block and code_block.count('=') > 1:
            return code, "Error: Code format is not correct."
        lines = code_block.split('\n')
        last_line_match = re.match(r'^(\w+)\s*=', lines[-1].strip())
        if last_line_match and len(lines) > 1:
            # The last line is an assignment: append the assigned name so the
            # final line evaluates to that value.
            code_block += '\n' + last_line_match.group(1)
        else:
            # Insert a newline before each following `name =` assignment ...
            code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)', r'\1\n', code_block)
            # ... and move a trailing identifier onto its own line.
            code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$', r'\1\n\2', code_block)
        return code_block, None

    def execute_code(self, code):
        """Run ``code`` through ``PythonInterpreter``.

        Returns:
            ``(code, output, error)``: ``output`` is the stringified first
            element of the execution result and ``error`` is ``None`` on
            success; on failure ``output`` is ``None`` and ``error`` holds
            the message.
        """
        if not code:
            return code, None, 'Error: Empty code before execution.'
        code_prompt = CodePrompt(code, code_type="python")
        interpreter = PythonInterpreter(action_space={"print": print})
        try:
            output = str(code_prompt.execute(interpreter=interpreter)[0])
            return code, output, None
        except Exception as e:
            # Deliberately broad: the message is fed back to the LM so it can
            # repair the program, whatever went wrong.
            return code, None, str(e)

    def _parse_and_execute(self, code_data):
        """Parse a generation result and execute it only if parsing succeeded.

        Returns the same ``(code, output, error)`` triple as ``execute_code``.
        """
        parsed_code, error = self.parse_code(code_data)
        if error:
            # Surface the parse error instead of letting execution overwrite it.
            return parsed_code, None, error
        return self.execute_code(parsed_code)

    def forward(self, **kwargs):
        """Generate code, retry on errors, then produce the final prediction.

        Args:
            **kwargs: one value per input field of the signature.

        Returns:
            The result of ``generate_answer``, or ``None`` if the code still
            fails after ``max_iters`` regeneration attempts.
        """
        # Pass every declared input field through, not just "question".
        input_kwargs = {field_name: kwargs[field_name] for field_name in self.input_fields}

        code_data = self.code_generate(**input_kwargs)
        code, output, error = self._parse_and_execute(code_data)
        hop = 0
        while error and hop < self.max_iters:
            print('Error in code execution')
            code_data = self.code_regenerate(**input_kwargs, previous_code=code, error=error)
            # Re-run the regenerated program so `code`, `output` and `error`
            # describe the latest attempt rather than the first failure.
            code, output, error = self._parse_and_execute(code_data)
            hop += 1
        if error:
            # Give up only when an error actually remains; a fix on the final
            # attempt still yields an answer.
            print('Max hops reached. Error persists.')
            return None
        answer_gen_result = self.generate_answer(**input_kwargs, final_generated_code=code, code_output=output)
        return answer_gen_result

dspy/primitives/__init__.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
11
from .example import *
22
from .program import *
3-
from .prediction import *
3+
from .prediction import *
4+
from .python_interpreter import *

0 commit comments

Comments
 (0)