Skip to content

Commit ecb3726

Browse files
committed
refactor(gepa): eliminate ReAct coupling in build_program
Replace ReAct-specific logic with generic approach: Before: - isinstance(ReAct) checks - Direct access to module.react/module.extract/module.tools - Separate if/elif branches for instruction updates After: - Program-level __dict__ traversal to find tools - Unified aggregation: plain strings → module config overrides - Single application loop (no duplication) Why __dict__ traversal: Tools can be declared as single attributes (self.tool), lists (self.tools=[...]), or dicts (self.tools={...}), and nested in any dspy.Module. Traversing __dict__ finds all tools regardless of how they're structured, without coupling to specific module types. This makes the code resilient to ReAct internal changes and works for any module using dspy.Tool.
1 parent e35603a commit ecb3726

File tree

1 file changed

+65
-69
lines changed

1 file changed

+65
-69
lines changed

dspy/teleprompt/gepa/gepa_utils.py

Lines changed: 65 additions & 69 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212
from dspy.adapters.types.base_type import Type
1313
from dspy.adapters.types.tool import Tool
1414
from dspy.evaluate import Evaluate
15-
from dspy.predict.react import ReAct
1615
from dspy.primitives import Example, Prediction
1716
from dspy.teleprompt.bootstrap_trace import TraceData
1817

@@ -106,9 +105,6 @@ def __init__(
106105

107106
self.propose_new_texts = self._build_propose_new_texts()
108107

109-
# Cache predictor names/signatures
110-
self.named_predictors = list(self.student.named_predictors())
111-
112108
def _build_propose_new_texts(self):
113109
"""Build proposal function that routes components to appropriate proposers."""
114110
# Init instruction proposer (custom or default)
@@ -202,79 +198,79 @@ def propose_component_texts(
202198
def build_program(self, candidate: dict[str, str]):
203199
new_prog = self.student.deepcopy()
204200

205-
# Apply regular predictor instructions
206-
for name, pred in new_prog.named_predictors():
207-
if name in candidate:
208-
pred.signature = pred.signature.with_instructions(candidate[name])
201+
# Start with plain string instructions from candidate
202+
improved_predictors = {
203+
k: v for k, v in candidate.items()
204+
if not k.startswith((REACT_MODULE_PREFIX, TOOL_MODULE_PREFIX))
205+
}
209206

210-
# Apply ReAct module updates (JSON configs for ReAct modules: react, extract, tools)
207+
improved_tools = {}
211208
if self.enable_tool_optimization:
212-
for _, module in new_prog.named_sub_modules():
213-
# Only process ReAct modules
214-
if not isinstance(module, ReAct):
209+
for key, value in candidate.items():
210+
if not key.startswith((REACT_MODULE_PREFIX, TOOL_MODULE_PREFIX)):
215211
continue
216212

217-
# Find module key using extract predictor name
218-
extract_predictor = module.extract.predict
219-
module_key = None
213+
config = json.loads(value)
214+
215+
# Parse module configs and override predictor instructions
216+
for pred_name, instruction in config.items():
217+
if isinstance(instruction, str):
218+
improved_predictors[pred_name] = instruction
220219

221-
for name, pred in new_prog.named_predictors():
222-
if pred is extract_predictor:
223-
module_key = f"{REACT_MODULE_PREFIX}:{name}"
224-
break
220+
if "tools" in config:
221+
improved_tools.update(config["tools"])
225222

226-
# Check if this module was optimized
227-
if module_key is None or module_key not in candidate:
223+
# Update predictor instructions
224+
for name, pred in new_prog.named_predictors():
225+
if name in improved_predictors:
226+
pred.signature = pred.signature.with_instructions(improved_predictors[name])
227+
228+
# Update tool descriptions
229+
if improved_tools:
230+
def collect_tools(obj):
231+
all_tools = {}
232+
visited = set()
233+
234+
def traverse(o):
235+
if id(o) in visited or not hasattr(o, "__dict__"):
236+
return
237+
visited.add(id(o))
238+
239+
for attr_val in o.__dict__.values():
240+
if isinstance(attr_val, Tool):
241+
all_tools[attr_val.name] = attr_val
242+
elif isinstance(attr_val, list):
243+
for item in attr_val:
244+
if isinstance(item, Tool):
245+
all_tools[item.name] = item
246+
elif isinstance(attr_val, dict):
247+
for item in attr_val.values():
248+
if isinstance(item, Tool):
249+
all_tools[item.name] = item
250+
elif isinstance(attr_val, dspy.Module):
251+
traverse(attr_val)
252+
253+
traverse(obj)
254+
return all_tools
255+
256+
all_tools = collect_tools(new_prog)
257+
258+
for tool_name, tool_config in improved_tools.items():
259+
if tool_name not in all_tools:
228260
continue
229261

230-
# Deserialize JSON containing optimized module configuration
231-
try:
232-
module_config = json.loads(candidate[module_key])
233-
logger.debug(f"Applying optimized module config to {module_key}")
234-
235-
# Find predictor names for this module
236-
react_pred_name = None
237-
extract_pred_name = None
238-
for pred_name, pred in new_prog.named_predictors():
239-
if pred is module.react:
240-
react_pred_name = pred_name
241-
elif pred is module.extract.predict:
242-
extract_pred_name = pred_name
243-
244-
# Apply react instruction using actual predictor name as key
245-
if react_pred_name and react_pred_name in module_config:
246-
module.react.signature = module.react.signature.with_instructions(module_config[react_pred_name])
247-
logger.debug(" Updated react instruction")
248-
249-
# Apply extract instruction using actual predictor name as key
250-
if extract_pred_name and extract_pred_name in module_config:
251-
module.extract.predict.signature = module.extract.predict.signature.with_instructions(module_config[extract_pred_name])
252-
logger.debug(" Updated extract instruction")
253-
254-
# Apply tool descriptions
255-
if "tools" in module_config:
256-
for tool_name, tool_config in module_config["tools"].items():
257-
tool = module.tools[tool_name]
258-
259-
# Update tool description
260-
if tool_config.get("desc"):
261-
tool.desc = tool_config["desc"]
262-
logger.debug(f" Updated tool '{tool_name}' description")
263-
264-
# Update tool arg descriptions
265-
arg_desc = tool_config.get("arg_desc")
266-
if arg_desc:
267-
tool.arg_desc = tool.arg_desc or {}
268-
tool.arg_desc.update(arg_desc)
269-
# Propagate to tool.args
270-
for arg_name, description in arg_desc.items():
271-
if arg_name in tool.args:
272-
tool.args[arg_name]["description"] = description
273-
logger.debug(f" Updated tool '{tool_name}' arg descriptions: {list(arg_desc.keys())}")
274-
275-
except json.JSONDecodeError as e:
276-
logger.error(f"Failed to parse JSON config for {module_key}: {e}")
277-
raise
262+
tool = all_tools[tool_name]
263+
264+
if tool_config.get("desc"):
265+
tool.desc = tool_config["desc"]
266+
267+
arg_desc = tool_config.get("arg_desc")
268+
if arg_desc:
269+
tool.arg_desc = tool.arg_desc or {}
270+
tool.arg_desc.update(arg_desc)
271+
for arg_name, description in arg_desc.items():
272+
if arg_name in tool.args:
273+
tool.args[arg_name]["description"] = description
278274

279275
return new_prog
280276

0 commit comments

Comments
 (0)