Skip to content

Commit c7f8cc0

Browse files
authored
Merge branch 'stanfordnlp:main' into main
2 parents b88693d + 09f9884 commit c7f8cc0

File tree

4 files changed

+38
-7
lines changed

4 files changed

+38
-7
lines changed

dspy/primitives/module.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,17 +4,48 @@
44

55
import ujson
66

7+
# NOTE: Note: It's important (temporary decision) to maintain named_parameters that's different in behavior from
8+
# named_sub_modules for the time being.
79

810
class BaseModule:
911
def __init__(self):
1012
pass
1113

1214
def named_parameters(self):
13-
"""Unlike PyTorch, handles lists of parameters too."""
15+
"""
16+
Unlike PyTorch, handles (non-recursive) lists of parameters too.
17+
"""
18+
19+
import dspy
1420
from dspy.predict.parameter import Parameter
1521

16-
# Remove the 'self.' prefix from the names
17-
return [(name[5:], param) for name, param in self.named_sub_modules(Parameter)]
22+
visited = set()
23+
named_parameters = []
24+
25+
def add_parameter(param_name, param_value):
26+
if isinstance(param_value, Parameter) and id(param_value) not in visited:
27+
visited.add(id(param_value))
28+
named_parameters.append((param_name, param_value))
29+
30+
for name, value in self.__dict__.items():
31+
if isinstance(value, Parameter):
32+
add_parameter(name, value)
33+
34+
elif isinstance(value, dspy.Module):
35+
# When a sub-module is pre-compiled, keep it frozen.
36+
if not getattr(value, "_compiled", False):
37+
for sub_name, param in value.named_parameters():
38+
add_parameter(f"{name}.{sub_name}", param)
39+
40+
elif isinstance(value, (list, tuple)):
41+
for idx, item in enumerate(value):
42+
add_parameter(f"{name}[{idx}]", item)
43+
44+
elif isinstance(value, dict):
45+
for key, item in value.items():
46+
add_parameter(f"{name}['{key}']", item)
47+
48+
return named_parameters
1849

1950
def named_sub_modules(self, type_=None, skip_compiled=False) -> Generator[tuple[str, "BaseModule"], None, None]:
2051
"""Find all sub-modules in the module, as well as their names.

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
44

55
[project]
66
name = "dspy-ai"
7-
version = "2.4.3"
7+
version = "2.4.9"
88
description = "DSPy"
99
readme = "README.md"
1010
authors = [{ name = "Omar Khattab", email = "okhattab@stanford.edu" }]

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@
1010

1111
setup(
1212
name="dspy-ai",
13-
version="2.4.3",
13+
version="2.4.9",
1414
description="DSPy",
1515
long_description=long_description,
1616
long_description_content_type='text/markdown',

tests/primitives/test_program.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -126,9 +126,9 @@ def test_complex_module_traversal():
126126
"self",
127127
"self.sub_module",
128128
"self.sub_module.nested_list[0]",
129-
"self.sub_module.nested_list[1][key]",
129+
"self.sub_module.nested_list[1][key]", # NOTE: named_sub_modules allows recursive structures
130130
"self.sub_module.nested_tuple[0]",
131-
"self.sub_module.nested_tuple[1][0]",
131+
"self.sub_module.nested_tuple[1][0]", # NEW: named_sub_modules allows recursive structures, but named_prameters does not
132132
# "self.sub_module.nested_tuple[1][1]", This should not be included, as it's the same module as the previous one
133133
}
134134
found_names = {name for name, _ in root.named_sub_modules()}

0 commit comments

Comments
 (0)