|
| 1 | +import dsp |
| 2 | +import dspy |
| 3 | +from ..primitives.program import Module |
| 4 | +from ..primitives.python_interpreter import CodePrompt, PythonInterpreter |
| 5 | +import re |
| 6 | + |
| 7 | +class ProgramOfThought(Module): |
| 8 | + def __init__(self, signature, max_iters=3): |
| 9 | + super().__init__() |
| 10 | + self.signature = signature = dspy.Predict(signature).signature |
| 11 | + self.max_iters = max_iters |
| 12 | + |
| 13 | + self.input_fields = signature.input_fields() |
| 14 | + self.output_fields = signature.output_fields() |
| 15 | + |
| 16 | + inputs_ = ', '.join([f"`{field_name}`" for field_name in self.input_fields.keys()]) |
| 17 | + outputs_ = ', '.join([f"`{field_name}`" for field_name in self.output_fields.keys()]) |
| 18 | + |
| 19 | + assert len(self.output_fields) == 1, "PoT only supports one output field." |
| 20 | + |
| 21 | + instr = [] |
| 22 | + instr.append(f"You will be given {inputs_} and you will respond with {outputs_}.") |
| 23 | + instr.append(f"Generating executable Python code that programmatically computes the correct {outputs_}.") |
| 24 | + instr.append(f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {outputs_}.") |
| 25 | + instr = '\n'.join(instr) |
| 26 | + |
| 27 | + self.code_generate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('generate'), **self._generate_signature('generate'))) |
| 28 | + self.code_regenerate = dspy.ChainOfThought(dsp.Template(self._generate_instruction('regenerate'), **self._generate_signature('regenerate'))) |
| 29 | + self.generate_answer = dspy.ChainOfThought(dsp.Template(self._generate_instruction('answer'), **self._generate_signature('answer'))) |
| 30 | + |
| 31 | + def _generate_signature(self, mode): |
| 32 | + signature_dict = dict(self.input_fields) |
| 33 | + fields_for_mode = { |
| 34 | + 'generate': { |
| 35 | + 'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str) |
| 36 | + }, |
| 37 | + 'regenerate': { |
| 38 | + 'previous_code': dspy.InputField(prefix="Previous Code:", desc="previously-generated python code that errored", format=str), |
| 39 | + 'error': dspy.InputField(prefix="Error:", desc="error message from previously-generated python code"), |
| 40 | + 'generated_code': dspy.OutputField(prefix="Code:", desc="python code that answers the question", format=str) |
| 41 | + }, |
| 42 | + 'answer': { |
| 43 | + 'final_generated_code': dspy.InputField(prefix="Code:", desc="python code that answers the question", format=str), |
| 44 | + 'code_output': dspy.InputField(prefix="Code Output:", desc="output of previously-generated python code"), |
| 45 | + 'answer': self.signature.kwargs["answer"] |
| 46 | + } |
| 47 | + } |
| 48 | + signature_dict.update(fields_for_mode[mode]) |
| 49 | + return signature_dict |
| 50 | + |
| 51 | + def _generate_instruction(self, mode): |
| 52 | + mode_inputs = ', '.join([f"`{field_name}`" for field_name in self._generate_signature(mode).keys() if isinstance(self._generate_signature(mode)[field_name], dspy.InputField)]) |
| 53 | + mode_outputs = ', '.join([f"`{field_name}`" for field_name in self._generate_signature(mode).keys() if isinstance(self._generate_signature(mode)[field_name], dspy.OutputField)]) |
| 54 | + if mode == 'generate': |
| 55 | + instr = [ |
| 56 | + f"You will be given {mode_inputs} and you will respond with {mode_outputs}.", |
| 57 | + f"Generating executable Python code that programmatically computes the correct {mode_outputs}.", |
| 58 | + f"After you're done with the computation, make sure the last line in your code evaluates to the correct value for {mode_outputs}." |
| 59 | + ] |
| 60 | + elif mode == 'regenerate': |
| 61 | + instr = [ |
| 62 | + f"You are given {mode_inputs} due to an error in previous code.", |
| 63 | + f"Your task is to correct the error and provide the new {mode_outputs}." |
| 64 | + ] |
| 65 | + else: # mode == 'answer' |
| 66 | + instr = [ |
| 67 | + f"Given the final code {mode_inputs}, provide the final {mode_outputs}." |
| 68 | + ] |
| 69 | + |
| 70 | + return '\n'.join(instr) |
| 71 | + |
| 72 | + def parse_code(self, code_data): |
| 73 | + code = code_data.get('generated_code', '').split('---', 1)[0].split('\n\n\n', 1)[0] |
| 74 | + code_match = re.search(r'```python[ \n](.*?)[ \n]```?', code, re.DOTALL) |
| 75 | + code_block = (code_match.group(1) if code_match else code).replace('\\n', '\n') |
| 76 | + if not code_block: |
| 77 | + return code, "Error: Empty code after parsing." |
| 78 | + if "\n" not in code_block and code_block.count('=') > 1: |
| 79 | + return code, "Error: Code format is not correct." |
| 80 | + lines = code_block.split('\n') |
| 81 | + last_line_match = re.match(r'^(\w+)\s*=', lines[-1].strip()) |
| 82 | + if last_line_match and len(lines) > 1: |
| 83 | + code_block += '\n' + last_line_match.group(1) |
| 84 | + else: |
| 85 | + code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)(?=[a-zA-Z_]\w* *=)', r'\1\n', code_block) |
| 86 | + code_block = re.sub(r'([a-zA-Z_]\w* *=.*?)([a-zA-Z_]\w*)$', r'\1\n\2', code_block) |
| 87 | + return code_block, None |
| 88 | + |
| 89 | + def execute_code(self, code): |
| 90 | + if not code: |
| 91 | + return code, None, 'Error: Empty code before execution.' |
| 92 | + code_prompt = CodePrompt(code, code_type="python") |
| 93 | + interpreter = PythonInterpreter(action_space={"print": print}) |
| 94 | + try: |
| 95 | + output = str(code_prompt.execute(interpreter=interpreter)[0]) |
| 96 | + return code, output, None |
| 97 | + except Exception as e: |
| 98 | + return code, None, str(e) |
| 99 | + |
| 100 | + def forward(self, **kwargs): |
| 101 | + code_data = self.code_generate(question=kwargs["question"]) |
| 102 | + parsed_code, error = self.parse_code(code_data) |
| 103 | + code, output, error = self.execute_code(parsed_code) |
| 104 | + hop = 0 |
| 105 | + while hop < self.max_iters and error: |
| 106 | + print('Error in code execution') |
| 107 | + code_data = self.code_regenerate(question=kwargs["question"], previous_code=code, error=error) |
| 108 | + parsed_code, error = self.parse_code(code_data) |
| 109 | + hop += 1 |
| 110 | + if hop == self.max_iters: |
| 111 | + print('Max hops reached. Error persists.') |
| 112 | + return None |
| 113 | + answer_gen_result = self.generate_answer(question=kwargs["question"], final_generated_code=code, code_output=output) |
| 114 | + return answer_gen_result |
0 commit comments