|
14 | 14 | from dspy.signatures.signature import ensure_signature, make_signature |
15 | 15 |
|
16 | 16 |
|
17 | | -def predictor(func) -> dspy.Module: |
18 | | - """Decorator that creates a predictor module based on the provided function.""" |
19 | | - signature = _func_to_signature(func) |
20 | | - *_, output_key = signature.output_fields.keys() |
21 | | - return _StripOutput(TypedPredictor(signature), output_key) |
22 | | - |
23 | | - |
24 | | -def cot(func) -> dspy.Module: |
25 | | - """Decorator that creates a chain of thought module based on the provided function.""" |
26 | | - signature = _func_to_signature(func) |
27 | | - *_, output_key = signature.output_fields.keys() |
28 | | - return _StripOutput(TypedChainOfThought(signature), output_key) |
| 17 | +def predictor(*args, **kwargs): |
| 18 | + def _predictor(func) -> dspy.Module: |
| 19 | + """Decorator that creates a predictor module based on the provided function.""" |
| 20 | + signature = _func_to_signature(func) |
| 21 | + *_, output_key = signature.output_fields.keys() |
| 22 | + return _StripOutput(TypedPredictor(signature, **kwargs), output_key) |
| 23 | + |
| 24 | + # if we have only a single callable argument, the decorator was invoked with no key word arguments |
| 25 | + # so we just return the wrapped function |
| 26 | + if len(args) == 1 and callable(args[0]) and len(kwargs) == 0: |
| 27 | + return _predictor(args[0]) |
| 28 | + return _predictor |
| 29 | + |
| 30 | + |
| 31 | +def cot(*args, **kwargs): |
| 32 | + def _cot(func) -> dspy.Module: |
| 33 | + """Decorator that creates a chain of thought module based on the provided function.""" |
| 34 | + signature = _func_to_signature(func) |
| 35 | + *_, output_key = signature.output_fields.keys() |
| 36 | + return _StripOutput(TypedChainOfThought(signature, **kwargs), output_key) |
| 37 | + |
| 38 | + # if we have only a single callable argument, the decorator was invoked with no key word arguments |
| 39 | + # so we just return the wrapped function |
| 40 | + if len(args) == 1 and callable(args[0]) and len(kwargs) == 0: |
| 41 | + return _cot(args[0]) |
| 42 | + return _cot |
29 | 43 |
|
30 | 44 |
|
31 | 45 | class _StripOutput(dspy.Module): |
|
0 commit comments