1+ from collections import defaultdict
12import inspect
23import os
34import openai
78from typing import Annotated , List , Tuple # noqa: UP035
89from dsp .templates import passages2text
910import json
11+ from dspy .primitives .prediction import Prediction
1012
11- from dspy .signatures .signature import ensure_signature
13+ from dspy .signatures .signature import ensure_signature , make_signature
1214
1315
1416MAX_RETRIES = 3
@@ -71,7 +73,7 @@ def TypedChainOfThought(signature) -> dspy.Module: # noqa: N802
7173class TypedPredictor (dspy .Module ):
7274 def __init__ (self , signature ):
7375 super ().__init__ ()
74- self .signature = signature
76+ self .signature = ensure_signature ( signature )
7577 self .predictor = dspy .Predict (signature )
7678
7779 def copy (self ) -> "TypedPredictor" :
@@ -81,7 +83,7 @@ def copy(self) -> "TypedPredictor":
8183 def _make_example (type_ ) -> str :
8284 # Note: DSPy will cache this call so we only pay the first time TypedPredictor is called.
8385 json_object = dspy .Predict (
84- dspy . Signature (
86+ make_signature (
8587 "json_schema -> json_object" ,
8688 "Make a very succinct json object that validates with the following schema" ,
8789 ),
@@ -127,8 +129,7 @@ def _prepare_signature(self) -> dspy.Signature:
127129 name ,
128130 desc = field .json_schema_extra .get ("desc" , "" )
129131 + (
130- ". Respond with a single JSON object. JSON Schema: "
131- + json .dumps (type_ .model_json_schema ())
132+ ". Respond with a single JSON object. JSON Schema: " + json .dumps (type_ .model_json_schema ())
132133 ),
133134 format = lambda x , to_json = to_json : (x if isinstance (x , str ) else to_json (x )),
134135 parser = lambda x , from_json = from_json : from_json (_unwrap_json (x )),
@@ -152,13 +153,20 @@ def forward(self, **kwargs) -> dspy.Prediction:
152153 for try_i in range (MAX_RETRIES ):
153154 result = self .predictor (** modified_kwargs , new_signature = signature )
154155 errors = {}
155- parsed_results = {}
156+ parsed_results = []
156157 # Parse the outputs
157- for name , field in signature . output_fields . items ( ):
158+ for i , completion in enumerate ( result . completions ):
158159 try :
159- value = getattr (result , name )
160- parser = field .json_schema_extra .get ("parser" , lambda x : x )
161- parsed_results [name ] = parser (value )
160+ parsed = {}
161+ for name , field in signature .output_fields .items ():
162+ value = completion [name ]
163+ parser = field .json_schema_extra .get ("parser" , lambda x : x )
164+ completion [name ] = parser (value )
165+ parsed [name ] = parser (value )
166+ # Instantiate the actual signature with the parsed values.
167+ # This allow pydantic to validate the fields defined in the signature.
168+ _dummy = self .signature (** kwargs , ** parsed )
169+ parsed_results .append (parsed )
162170 except (pydantic .ValidationError , ValueError ) as e :
163171 errors [name ] = _format_error (e )
164172 # If we can, we add an example to the error message
@@ -168,11 +176,14 @@ def forward(self, **kwargs) -> dspy.Prediction:
168176 continue # Only add examples to JSON objects
169177 suffix , current_desc = current_desc [i :], current_desc [:i ]
170178 prefix = "You MUST use this format: "
171- if try_i + 1 < MAX_RETRIES \
172- and prefix not in current_desc \
173- and (example := self ._make_example (field .annotation )):
179+ if (
180+ try_i + 1 < MAX_RETRIES
181+ and prefix not in current_desc
182+ and (example := self ._make_example (field .annotation ))
183+ ):
174184 signature = signature .with_updated_fields (
175- name , desc = current_desc + "\n " + prefix + example + "\n " + suffix ,
185+ name ,
186+ desc = current_desc + "\n " + prefix + example + "\n " + suffix ,
176187 )
177188 if errors :
178189 # Add new fields for each error
@@ -187,11 +198,12 @@ def forward(self, **kwargs) -> dspy.Prediction:
187198 )
188199 else :
189200 # If there are no errors, we return the parsed results
190- for name , value in parsed_results . items ():
191- setattr ( result , name , value )
192- return result
201+ return Prediction . from_completions (
202+ { key : [ r [ key ] for r in parsed_results ] for key in signature . output_fields }
203+ )
193204 raise ValueError (
194- "Too many retries trying to get the correct output format. " + "Try simplifying the requirements." , errors ,
205+ "Too many retries trying to get the correct output format. " + "Try simplifying the requirements." ,
206+ errors ,
195207 )
196208
197209
0 commit comments