Skip to content

Commit c2d639f

Browse files
authored
Merge pull request #574 from thomasahle/main
New Typed Signature Optimizer
2 parents b2816c4 + 26bb358 commit c2d639f

File tree

17 files changed

+1374
-189
lines changed

17 files changed

+1374
-189
lines changed

dsp/templates/template_v2.py

Lines changed: 32 additions & 41 deletions
Original file line numberDiff line numberDiff line change
@@ -72,6 +72,8 @@ def query(self, example: Example, is_demo: bool = False) -> str:
7272
"""Retrieves the input variables from the example and formats them into a query string."""
7373
result: list[str] = []
7474

75+
# If not a demo, find the last field that doesn't have a value set in `example` and set it to ""
76+
# This creates the "Output:" prefix at the end of the prompt.
7577
if not is_demo:
7678
has_value = [
7779
field.input_variable in example
@@ -80,40 +82,40 @@ def query(self, example: Example, is_demo: bool = False) -> str:
8082
for field in self.fields
8183
]
8284

83-
for i in range(1, len(has_value)):
84-
if has_value[i - 1] and not any(has_value[i:]):
85-
example[self.fields[i].input_variable] = ""
86-
break
85+
# If there are no inputs, set the first field to ""
86+
if not any(has_value):
87+
example[self.fields[0].input_variable] = ""
88+
# Otherwise find the first field without a value.
89+
else:
90+
for i in range(1, len(has_value)):
91+
if has_value[i - 1] and not any(has_value[i:]):
92+
example[self.fields[i].input_variable] = ""
93+
break
8794

8895
for field in self.fields:
89-
if (
90-
field.input_variable in example
91-
and example[field.input_variable] is not None
92-
):
96+
if field.input_variable in example and example[field.input_variable] is not None:
9397
if field.input_variable in self.format_handlers:
9498
format_handler = self.format_handlers[field.input_variable]
9599
else:
100+
96101
def format_handler(x):
97102
assert type(x) == str, f"Need format_handler for {field.input_variable} of type {type(x)}"
98103
return " ".join(x.split())
99104

100105
formatted_value = format_handler(example[field.input_variable])
101-
separator = '\n' if field.separator == ' ' and '\n' in formatted_value else field.separator
106+
separator = "\n" if field.separator == " " and "\n" in formatted_value else field.separator
102107

103108
result.append(
104109
f"{field.name}{separator}{formatted_value}",
105110
)
106111

107-
if self._has_augmented_guidelines() and (example.get('augmented', False)):
112+
if self._has_augmented_guidelines() and (example.get("augmented", False)):
108113
return "\n\n".join([r for r in result if r])
109114
return "\n".join([r for r in result if r])
110115

111116
def guidelines(self, show_guidelines=True) -> str:
112117
"""Returns the task guidelines as described in the lm prompt"""
113-
if (not show_guidelines) or (
114-
hasattr(dsp.settings, "show_guidelines")
115-
and not dsp.settings.show_guidelines
116-
):
118+
if (not show_guidelines) or (hasattr(dsp.settings, "show_guidelines") and not dsp.settings.show_guidelines):
117119
return ""
118120

119121
result = "Follow the following format.\n\n"
@@ -128,11 +130,13 @@ def guidelines(self, show_guidelines=True) -> str:
128130

129131
def _has_augmented_guidelines(self):
130132
return len(self.fields) > 3 or any(
131-
("\n" in field.separator) or ('\n' in field.description) for field in self.fields
133+
("\n" in field.separator) or ("\n" in field.description) for field in self.fields
132134
)
133135

134136
def extract(
135-
self, example: Union[Example, dict[str, Any]], raw_pred: str,
137+
self,
138+
example: Union[Example, dict[str, Any]],
139+
raw_pred: str,
136140
) -> Example:
137141
"""Extracts the answer from the LM raw prediction using the template structure
138142
@@ -149,10 +153,7 @@ def extract(
149153

150154
idx = 0
151155
while idx < len(self.fields):
152-
if (
153-
self.fields[idx].input_variable not in example
154-
or example[self.fields[idx].input_variable] is None
155-
):
156+
if self.fields[idx].input_variable not in example or example[self.fields[idx].input_variable] is None:
156157
break
157158
idx += 1
158159

@@ -166,16 +167,16 @@ def extract(
166167

167168
if offset >= 0:
168169
if dspy.settings.release >= 20231003:
169-
example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip('---').strip()
170-
raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip('---').strip()
170+
example[self.fields[idx].output_variable] = raw_pred[:offset].strip().rstrip("---").strip()
171+
raw_pred = raw_pred[offset + len(next_field_name) :].strip().rstrip("---").strip()
171172
else:
172173
example[self.fields[idx].output_variable] = raw_pred[:offset].strip()
173174
raw_pred = raw_pred[offset + len(next_field_name) :].strip()
174175

175176
idx += 1
176177
else:
177178
if dspy.settings.release >= 20231003:
178-
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip()
179+
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()
179180
else:
180181
example[self.fields[idx].output_variable] = raw_pred.strip()
181182

@@ -187,7 +188,7 @@ def extract(
187188
assert idx == len(self.fields) - 1, (idx, len(self.fields))
188189

189190
if dspy.settings.release >= 20231003:
190-
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip('---').strip()
191+
example[self.fields[idx].output_variable] = raw_pred.strip().rstrip("---").strip()
191192
else:
192193
example[self.fields[idx].output_variable] = raw_pred.strip()
193194

@@ -198,7 +199,7 @@ def extract(
198199
def __call__(self, example, show_guidelines=True) -> str:
199200
example = dsp.Example(example)
200201

201-
if hasattr(dsp.settings, 'query_only') and dsp.settings.query_only:
202+
if hasattr(dsp.settings, "query_only") and dsp.settings.query_only:
202203
return self.query(example)
203204

204205
# The training data should not contain the output variable
@@ -209,29 +210,20 @@ def __call__(self, example, show_guidelines=True) -> str:
209210
self.query(demo, is_demo=True)
210211
for demo in example.demos
211212
if (
212-
(not demo.get('augmented', False))
213+
(not demo.get("augmented", False))
213214
and ( # validate that the training example has the same primitive input var as the template
214-
self.fields[-1].input_variable in demo
215-
and demo[self.fields[-1].input_variable] is not None
215+
self.fields[-1].input_variable in demo and demo[self.fields[-1].input_variable] is not None
216216
)
217217
)
218218
]
219219

220-
ademos = [
221-
self.query(demo, is_demo=True)
222-
for demo in example.demos
223-
if demo.get('augmented', False)
224-
]
220+
ademos = [self.query(demo, is_demo=True) for demo in example.demos if demo.get("augmented", False)]
225221

226222
# Move the rdemos to ademos if rdemo has all the fields filled in
227223
rdemos_ = []
228224
new_ademos = []
229225
for rdemo in rdemos:
230-
if all(
231-
(field.name in rdemo)
232-
for field in self.fields
233-
if field.input_variable in example
234-
):
226+
if all((field.name in rdemo) for field in self.fields if field.input_variable in example):
235227
import dspy
236228

237229
if dspy.settings.release >= 20230928:
@@ -244,7 +236,6 @@ def __call__(self, example, show_guidelines=True) -> str:
244236
ademos = new_ademos + ademos
245237
rdemos = rdemos_
246238

247-
248239
long_query = self._has_augmented_guidelines()
249240

250241
if long_query:
@@ -253,10 +244,10 @@ def __call__(self, example, show_guidelines=True) -> str:
253244
query = self.query(example)
254245

255246
# if it has more lines than fields
256-
if len(query.split('\n')) > len(self.fields):
247+
if len(query.split("\n")) > len(self.fields):
257248
long_query = True
258249

259-
if not example.get('augmented', False):
250+
if not example.get("augmented", False):
260251
example["augmented"] = True
261252
query = self.query(example)
262253

dspy/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,3 @@
1-
# from .evaluation import *
2-
# FIXME:
31
import dsp
42
from dsp.modules.hf_client import ChatModuleClient, HFClientSGLang, HFClientVLLM, HFServerTGI
53

@@ -8,6 +6,9 @@
86
from .retrieve import *
97
from .signatures import *
108

9+
# Functional must be imported after primitives, predict and signatures
10+
from .functional import * # isort: skip
11+
1112
settings = dsp.settings
1213

1314
AzureOpenAI = dsp.AzureOpenAI

0 commit comments

Comments
 (0)