|
4 | 4 |
|
5 | 5 | import ujson |
6 | 6 |
|
| 7 | +# NOTE: Note: It's important (temporary decision) to maintain named_parameters that's different in behavior from |
| 8 | +# named_sub_modules for the time being. |
7 | 9 |
|
8 | 10 | class BaseModule: |
9 | 11 | def __init__(self): |
10 | 12 | pass |
11 | 13 |
|
12 | 14 | def named_parameters(self): |
13 | | - """Unlike PyTorch, handles lists of parameters too.""" |
| 15 | + """ |
| 16 | + Unlike PyTorch, handles (non-recursive) lists of parameters too. |
| 17 | + """ |
| 18 | + |
| 19 | + import dspy |
14 | 20 | from dspy.predict.parameter import Parameter |
15 | 21 |
|
16 | | - # Remove the 'self.' prefix from the names |
17 | | - return [(name[5:], param) for name, param in self.named_sub_modules(Parameter)] |
| 22 | + visited = set() |
| 23 | + named_parameters = [] |
| 24 | + |
| 25 | + def add_parameter(param_name, param_value): |
| 26 | + if isinstance(param_value, Parameter) and id(param_value) not in visited: |
| 27 | + visited.add(id(param_value)) |
| 28 | + named_parameters.append((param_name, param_value)) |
| 29 | + |
| 30 | + for name, value in self.__dict__.items(): |
| 31 | + if isinstance(value, Parameter): |
| 32 | + add_parameter(name, value) |
| 33 | + |
| 34 | + elif isinstance(value, dspy.Module): |
| 35 | + # When a sub-module is pre-compiled, keep it frozen. |
| 36 | + if not getattr(value, "_compiled", False): |
| 37 | + for sub_name, param in value.named_parameters(): |
| 38 | + add_parameter(f"{name}.{sub_name}", param) |
| 39 | + |
| 40 | + elif isinstance(value, (list, tuple)): |
| 41 | + for idx, item in enumerate(value): |
| 42 | + add_parameter(f"{name}[{idx}]", item) |
| 43 | + |
| 44 | + elif isinstance(value, dict): |
| 45 | + for key, item in value.items(): |
| 46 | + add_parameter(f"{name}['{key}']", item) |
| 47 | + |
| 48 | + return named_parameters |
18 | 49 |
|
19 | 50 | def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]: |
20 | 51 | """Find all sub-modules in the module, as well as their names. |
|
0 commit comments