Skip to content

Commit 71ac993

Browse files
authored
Merge pull request #540 from thomasahle/main
Improvements to Signature
2 parents 16fdf50 + 6da72e3 commit 71ac993

File tree

7 files changed

+384
-220
lines changed

7 files changed

+384
-220
lines changed

dspy/functional/functional.py

Lines changed: 30 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -1,3 +1,4 @@
1+
from collections import defaultdict
12
import inspect
23
import os
34
import openai
@@ -7,8 +8,9 @@
78
from typing import Annotated, List, Tuple # noqa: UP035
89
from dsp.templates import passages2text
910
import json
11+
from dspy.primitives.prediction import Prediction
1012

11-
from dspy.signatures.signature import ensure_signature
13+
from dspy.signatures.signature import ensure_signature, make_signature
1214

1315

1416
MAX_RETRIES = 3
@@ -71,7 +73,7 @@ def TypedChainOfThought(signature) -> dspy.Module: # noqa: N802
7173
class TypedPredictor(dspy.Module):
7274
def __init__(self, signature):
7375
super().__init__()
74-
self.signature = signature
76+
self.signature = ensure_signature(signature)
7577
self.predictor = dspy.Predict(signature)
7678

7779
def copy(self) -> "TypedPredictor":
@@ -81,7 +83,7 @@ def copy(self) -> "TypedPredictor":
8183
def _make_example(type_) -> str:
8284
# Note: DSPy will cache this call so we only pay the first time TypedPredictor is called.
8385
json_object = dspy.Predict(
84-
dspy.Signature(
86+
make_signature(
8587
"json_schema -> json_object",
8688
"Make a very succinct json object that validates with the following schema",
8789
),
@@ -127,8 +129,7 @@ def _prepare_signature(self) -> dspy.Signature:
127129
name,
128130
desc=field.json_schema_extra.get("desc", "")
129131
+ (
130-
". Respond with a single JSON object. JSON Schema: "
131-
+ json.dumps(type_.model_json_schema())
132+
". Respond with a single JSON object. JSON Schema: " + json.dumps(type_.model_json_schema())
132133
),
133134
format=lambda x, to_json=to_json: (x if isinstance(x, str) else to_json(x)),
134135
parser=lambda x, from_json=from_json: from_json(_unwrap_json(x)),
@@ -152,13 +153,20 @@ def forward(self, **kwargs) -> dspy.Prediction:
152153
for try_i in range(MAX_RETRIES):
153154
result = self.predictor(**modified_kwargs, new_signature=signature)
154155
errors = {}
155-
parsed_results = {}
156+
parsed_results = []
156157
# Parse the outputs
157-
for name, field in signature.output_fields.items():
158+
for i, completion in enumerate(result.completions):
158159
try:
159-
value = getattr(result, name)
160-
parser = field.json_schema_extra.get("parser", lambda x: x)
161-
parsed_results[name] = parser(value)
160+
parsed = {}
161+
for name, field in signature.output_fields.items():
162+
value = completion[name]
163+
parser = field.json_schema_extra.get("parser", lambda x: x)
164+
completion[name] = parser(value)
165+
parsed[name] = parser(value)
166+
# Instantiate the actual signature with the parsed values.
167+
# This allows pydantic to validate the fields defined in the signature.
168+
_dummy = self.signature(**kwargs, **parsed)
169+
parsed_results.append(parsed)
162170
except (pydantic.ValidationError, ValueError) as e:
163171
errors[name] = _format_error(e)
164172
# If we can, we add an example to the error message
@@ -168,11 +176,14 @@ def forward(self, **kwargs) -> dspy.Prediction:
168176
continue # Only add examples to JSON objects
169177
suffix, current_desc = current_desc[i:], current_desc[:i]
170178
prefix = "You MUST use this format: "
171-
if try_i + 1 < MAX_RETRIES \
172-
and prefix not in current_desc \
173-
and (example := self._make_example(field.annotation)):
179+
if (
180+
try_i + 1 < MAX_RETRIES
181+
and prefix not in current_desc
182+
and (example := self._make_example(field.annotation))
183+
):
174184
signature = signature.with_updated_fields(
175-
name, desc=current_desc + "\n" + prefix + example + "\n" + suffix,
185+
name,
186+
desc=current_desc + "\n" + prefix + example + "\n" + suffix,
176187
)
177188
if errors:
178189
# Add new fields for each error
@@ -187,11 +198,12 @@ def forward(self, **kwargs) -> dspy.Prediction:
187198
)
188199
else:
189200
# If there are no errors, we return the parsed results
190-
for name, value in parsed_results.items():
191-
setattr(result, name, value)
192-
return result
201+
return Prediction.from_completions(
202+
{key: [r[key] for r in parsed_results] for key in signature.output_fields}
203+
)
193204
raise ValueError(
194-
"Too many retries trying to get the correct output format. " + "Try simplifying the requirements.", errors,
205+
"Too many retries trying to get the correct output format. " + "Try simplifying the requirements.",
206+
errors,
195207
)
196208

197209

dspy/primitives/prediction.py

Lines changed: 10 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -4,29 +4,29 @@
44
class Prediction(Example):
55
def __init__(self, *args, **kwargs):
66
super().__init__(*args, **kwargs)
7-
7+
88
del self._demos
99
del self._input_keys
1010

1111
self._completions = None
12-
12+
1313
@classmethod
1414
def from_completions(cls, list_or_dict, signature=None):
1515
obj = cls()
1616
obj._completions = Completions(list_or_dict, signature=signature)
1717
obj._store = {k: v[0] for k, v in obj._completions.items()}
1818

1919
return obj
20-
20+
2121
def __repr__(self):
22-
store_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._store.items())
22+
store_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._store.items())
2323

2424
if self._completions is None or len(self._completions) == 1:
2525
return f"Prediction(\n {store_repr}\n)"
26-
26+
2727
num_completions = len(self._completions)
2828
return f"Prediction(\n {store_repr},\n completions=Completions(...)\n) ({num_completions-1} completions omitted)"
29-
29+
3030
def __str__(self):
3131
return self.__repr__()
3232

@@ -62,15 +62,15 @@ def __getitem__(self, key):
6262
if isinstance(key, int):
6363
if key < 0 or key >= len(self):
6464
raise IndexError("Index out of range")
65-
65+
6666
return Prediction(**{k: v[key] for k, v in self._completions.items()})
67-
67+
6868
return self._completions[key]
6969

7070
def __getattr__(self, name):
7171
if name in self._completions:
7272
return self._completions[name]
73-
73+
7474
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
7575

7676
def __len__(self):
@@ -82,7 +82,7 @@ def __contains__(self, key):
8282
return key in self._completions
8383

8484
def __repr__(self):
85-
items_repr = ',\n '.join(f"{k}={repr(v)}" for k, v in self._completions.items())
85+
items_repr = ",\n ".join(f"{k}={repr(v)}" for k, v in self._completions.items())
8686
return f"Completions(\n {items_repr}\n)"
8787

8888
def __str__(self):

0 commit comments

Comments
 (0)