@@ -72,6 +72,8 @@ def query(self, example: Example, is_demo: bool = False) -> str:
7272 """Retrieves the input variables from the example and formats them into a query string."""
7373 result : list [str ] = []
7474
75+ # If not a demo, find the last field that doesn't have a value set in `example` and set it to ""
76+ # This creates the "Output:" prefix at the end of the prompt.
7577 if not is_demo :
7678 has_value = [
7779 field .input_variable in example
@@ -80,40 +82,40 @@ def query(self, example: Example, is_demo: bool = False) -> str:
8082 for field in self .fields
8183 ]
8284
83- for i in range (1 , len (has_value )):
84- if has_value [i - 1 ] and not any (has_value [i :]):
85- example [self .fields [i ].input_variable ] = ""
86- break
85+ # If there are no inputs, set the first field to ""
86+ if not any (has_value ):
87+ example [self .fields [0 ].input_variable ] = ""
88+ # Otherwise find the first field without a value.
89+ else :
90+ for i in range (1 , len (has_value )):
91+ if has_value [i - 1 ] and not any (has_value [i :]):
92+ example [self .fields [i ].input_variable ] = ""
93+ break
8794
8895 for field in self .fields :
89- if (
90- field .input_variable in example
91- and example [field .input_variable ] is not None
92- ):
96+ if field .input_variable in example and example [field .input_variable ] is not None :
9397 if field .input_variable in self .format_handlers :
9498 format_handler = self .format_handlers [field .input_variable ]
9599 else :
100+
96101 def format_handler (x ):
97102 assert type (x ) == str , f"Need format_handler for { field .input_variable } of type { type (x )} "
98103 return " " .join (x .split ())
99104
100105 formatted_value = format_handler (example [field .input_variable ])
101- separator = ' \n ' if field .separator == ' ' and ' \n ' in formatted_value else field .separator
106+ separator = " \n " if field .separator == " " and " \n " in formatted_value else field .separator
102107
103108 result .append (
104109 f"{ field .name } { separator } { formatted_value } " ,
105110 )
106111
107- if self ._has_augmented_guidelines () and (example .get (' augmented' , False )):
112+ if self ._has_augmented_guidelines () and (example .get (" augmented" , False )):
108113 return "\n \n " .join ([r for r in result if r ])
109114 return "\n " .join ([r for r in result if r ])
110115
111116 def guidelines (self , show_guidelines = True ) -> str :
112117 """Returns the task guidelines as described in the lm prompt"""
113- if (not show_guidelines ) or (
114- hasattr (dsp .settings , "show_guidelines" )
115- and not dsp .settings .show_guidelines
116- ):
118+ if (not show_guidelines ) or (hasattr (dsp .settings , "show_guidelines" ) and not dsp .settings .show_guidelines ):
117119 return ""
118120
119121 result = "Follow the following format.\n \n "
@@ -128,11 +130,13 @@ def guidelines(self, show_guidelines=True) -> str:
128130
129131 def _has_augmented_guidelines (self ):
130132 return len (self .fields ) > 3 or any (
131- ("\n " in field .separator ) or (' \n ' in field .description ) for field in self .fields
133+ ("\n " in field .separator ) or (" \n " in field .description ) for field in self .fields
132134 )
133135
134136 def extract (
135- self , example : Union [Example , dict [str , Any ]], raw_pred : str ,
137+ self ,
138+ example : Union [Example , dict [str , Any ]],
139+ raw_pred : str ,
136140 ) -> Example :
137141 """Extracts the answer from the LM raw prediction using the template structure
138142
@@ -149,10 +153,7 @@ def extract(
149153
150154 idx = 0
151155 while idx < len (self .fields ):
152- if (
153- self .fields [idx ].input_variable not in example
154- or example [self .fields [idx ].input_variable ] is None
155- ):
156+ if self .fields [idx ].input_variable not in example or example [self .fields [idx ].input_variable ] is None :
156157 break
157158 idx += 1
158159
@@ -166,16 +167,16 @@ def extract(
166167
167168 if offset >= 0 :
168169 if dspy .settings .release >= 20231003 :
169- example [self .fields [idx ].output_variable ] = raw_pred [:offset ].strip ().rstrip (' ---' ).strip ()
170- raw_pred = raw_pred [offset + len (next_field_name ) :].strip ().rstrip (' ---' ).strip ()
170+ example [self .fields [idx ].output_variable ] = raw_pred [:offset ].strip ().rstrip (" ---" ).strip ()
171+ raw_pred = raw_pred [offset + len (next_field_name ) :].strip ().rstrip (" ---" ).strip ()
171172 else :
172173 example [self .fields [idx ].output_variable ] = raw_pred [:offset ].strip ()
173174 raw_pred = raw_pred [offset + len (next_field_name ) :].strip ()
174175
175176 idx += 1
176177 else :
177178 if dspy .settings .release >= 20231003 :
178- example [self .fields [idx ].output_variable ] = raw_pred .strip ().rstrip (' ---' ).strip ()
179+ example [self .fields [idx ].output_variable ] = raw_pred .strip ().rstrip (" ---" ).strip ()
179180 else :
180181 example [self .fields [idx ].output_variable ] = raw_pred .strip ()
181182
@@ -187,7 +188,7 @@ def extract(
187188 assert idx == len (self .fields ) - 1 , (idx , len (self .fields ))
188189
189190 if dspy .settings .release >= 20231003 :
190- example [self .fields [idx ].output_variable ] = raw_pred .strip ().rstrip (' ---' ).strip ()
191+ example [self .fields [idx ].output_variable ] = raw_pred .strip ().rstrip (" ---" ).strip ()
191192 else :
192193 example [self .fields [idx ].output_variable ] = raw_pred .strip ()
193194
@@ -198,7 +199,7 @@ def extract(
198199 def __call__ (self , example , show_guidelines = True ) -> str :
199200 example = dsp .Example (example )
200201
201- if hasattr (dsp .settings , ' query_only' ) and dsp .settings .query_only :
202+ if hasattr (dsp .settings , " query_only" ) and dsp .settings .query_only :
202203 return self .query (example )
203204
204205 # The training data should not contain the output variable
@@ -209,29 +210,20 @@ def __call__(self, example, show_guidelines=True) -> str:
209210 self .query (demo , is_demo = True )
210211 for demo in example .demos
211212 if (
212- (not demo .get (' augmented' , False ))
213+ (not demo .get (" augmented" , False ))
213214 and ( # validate that the training example has the same primitive input var as the template
214- self .fields [- 1 ].input_variable in demo
215- and demo [self .fields [- 1 ].input_variable ] is not None
215+ self .fields [- 1 ].input_variable in demo and demo [self .fields [- 1 ].input_variable ] is not None
216216 )
217217 )
218218 ]
219219
220- ademos = [
221- self .query (demo , is_demo = True )
222- for demo in example .demos
223- if demo .get ('augmented' , False )
224- ]
220+ ademos = [self .query (demo , is_demo = True ) for demo in example .demos if demo .get ("augmented" , False )]
225221
226222 # Move the rdemos to ademos if rdemo has all the fields filled in
227223 rdemos_ = []
228224 new_ademos = []
229225 for rdemo in rdemos :
230- if all (
231- (field .name in rdemo )
232- for field in self .fields
233- if field .input_variable in example
234- ):
226+ if all ((field .name in rdemo ) for field in self .fields if field .input_variable in example ):
235227 import dspy
236228
237229 if dspy .settings .release >= 20230928 :
@@ -244,7 +236,6 @@ def __call__(self, example, show_guidelines=True) -> str:
244236 ademos = new_ademos + ademos
245237 rdemos = rdemos_
246238
247-
248239 long_query = self ._has_augmented_guidelines ()
249240
250241 if long_query :
@@ -253,10 +244,10 @@ def __call__(self, example, show_guidelines=True) -> str:
253244 query = self .query (example )
254245
255246 # if it has more lines than fields
256- if len (query .split (' \n ' )) > len (self .fields ):
247+ if len (query .split (" \n " )) > len (self .fields ):
257248 long_query = True
258249
259- if not example .get (' augmented' , False ):
250+ if not example .get (" augmented" , False ):
260251 example ["augmented" ] = True
261252 query = self .query (example )
262253
0 commit comments