Skip to content

Commit 3132a7f

Browse files
authored
Merge pull request #783 from limnick/main
allow passing of kwargs to typed predictor decorators
2 parents 38c26cc + b63b7af commit 3132a7f

File tree

1 file changed

+26
-12
lines changed

1 file changed

+26
-12
lines changed

dspy/functional/functional.py

Lines changed: 26 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -14,18 +14,32 @@
1414
from dspy.signatures.signature import ensure_signature, make_signature
1515

1616

17-
def predictor(func) -> dspy.Module:
18-
"""Decorator that creates a predictor module based on the provided function."""
19-
signature = _func_to_signature(func)
20-
*_, output_key = signature.output_fields.keys()
21-
return _StripOutput(TypedPredictor(signature), output_key)
22-
23-
24-
def cot(func) -> dspy.Module:
25-
"""Decorator that creates a chain of thought module based on the provided function."""
26-
signature = _func_to_signature(func)
27-
*_, output_key = signature.output_fields.keys()
28-
return _StripOutput(TypedChainOfThought(signature), output_key)
17+
def predictor(*args, **kwargs):
18+
def _predictor(func) -> dspy.Module:
19+
"""Decorator that creates a predictor module based on the provided function."""
20+
signature = _func_to_signature(func)
21+
*_, output_key = signature.output_fields.keys()
22+
return _StripOutput(TypedPredictor(signature, **kwargs), output_key)
23+
24+
# if we have only a single callable argument, the decorator was invoked with no key word arguments
25+
# so we just return the wrapped function
26+
if len(args) == 1 and callable(args[0]) and len(kwargs) == 0:
27+
return _predictor(args[0])
28+
return _predictor
29+
30+
31+
def cot(*args, **kwargs):
32+
def _cot(func) -> dspy.Module:
33+
"""Decorator that creates a chain of thought module based on the provided function."""
34+
signature = _func_to_signature(func)
35+
*_, output_key = signature.output_fields.keys()
36+
return _StripOutput(TypedChainOfThought(signature, **kwargs), output_key)
37+
38+
# if we have only a single callable argument, the decorator was invoked with no key word arguments
39+
# so we just return the wrapped function
40+
if len(args) == 1 and callable(args[0]) and len(kwargs) == 0:
41+
return _cot(args[0])
42+
return _cot
2943

3044

3145
class _StripOutput(dspy.Module):

0 commit comments

Comments
 (0)