Skip to content

Commit 6412a5d

Browse files
committed
feat(gepa): add tool description optimization for multi-agent systems
- Add optimize_tool_descriptions parameter (default False) to GEPA - Extract tool descriptions from all nested modules via named_sub_modules() - Apply optimized descriptions in DspyAdapter.build_program() - Enables holistic optimization of tools across main and subagent modules - Tests: 4 new tests, all 16 pass (4 new + 12 existing)
1 parent cfb78f0 commit 6412a5d

File tree

3 files changed

+186
-2
lines changed

3 files changed

+186
-2
lines changed

dspy/teleprompt/gepa/gepa.py

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -273,6 +273,9 @@ def metric(
273273
warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when
274274
called with and without the pred_name. This flag (defaults to True) determines whether a warning is
275275
raised if a mismatch in module-level and predictor-level score is detected.
276+
optimize_tool_descriptions: Whether to optimize tool descriptions for modules with tools
277+
(e.g., ReAct agents). When enabled, tool descriptions are included in the optimization
278+
process alongside signature instructions. Default is False.
276279
seed: The random seed to use for reproducibility. Default is 0.
277280
gepa_kwargs: (Optional) provide additional kwargs to be passed to [gepa.optimize](https://github.com/gepa-ai/gepa/blob/main/src/gepa/api.py) method
278281
@@ -328,6 +331,7 @@ def __init__(
328331
wandb_init_kwargs: dict[str, Any] | None = None,
329332
track_best_outputs: bool = False,
330333
warn_on_score_mismatch: bool = True,
334+
optimize_tool_descriptions: bool = False,
331335
use_mlflow: bool = False,
332336
# Reproducibility
333337
seed: int | None = 0,
@@ -390,6 +394,7 @@ def __init__(
390394
self.wandb_api_key = wandb_api_key
391395
self.wandb_init_kwargs = wandb_init_kwargs
392396
self.warn_on_score_mismatch = warn_on_score_mismatch
397+
self.optimize_tool_descriptions = optimize_tool_descriptions
393398
self.use_mlflow = use_mlflow
394399

395400
if track_best_outputs:
@@ -518,11 +523,25 @@ def feedback_fn(
518523
rng=rng,
519524
reflection_lm=self.reflection_lm,
520525
custom_instruction_proposer=self.custom_instruction_proposer,
521-
warn_on_score_mismatch=self.warn_on_score_mismatch
526+
warn_on_score_mismatch=self.warn_on_score_mismatch,
527+
optimize_tool_descriptions=self.optimize_tool_descriptions
522528
)
523529

524530
# Instantiate GEPA with the simpler adapter-based API
525531
base_program = {name: pred.signature.instructions for name, pred in student.named_predictors()}
532+
533+
if self.optimize_tool_descriptions:
534+
tool_descriptions = {}
535+
for _, module in student.named_sub_modules():
536+
if hasattr(module, 'tools'):
537+
for tool_name, tool in module.tools.items():
538+
tool_key = f"tool:{tool_name}"
539+
if tool_key not in tool_descriptions:
540+
tool_descriptions[tool_key] = tool.desc
541+
if tool_descriptions:
542+
logger.info(f"Including {len(tool_descriptions)} tool descriptions for optimization")
543+
base_program.update(tool_descriptions)
544+
526545
gepa_result: GEPAResult = optimize(
527546
seed_candidate=base_program,
528547
trainset=trainset,

dspy/teleprompt/gepa/gepa_utils.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,8 @@ def __init__(
7676
rng: random.Random | None = None,
7777
reflection_lm=None,
7878
custom_instruction_proposer: "ProposalFn | None" = None,
79-
warn_on_score_mismatch: bool = True
79+
warn_on_score_mismatch: bool = True,
80+
optimize_tool_descriptions: bool = False,
8081
):
8182
self.student = student_module
8283
self.metric_fn = metric_fn
@@ -88,6 +89,7 @@ def __init__(
8889
self.reflection_lm = reflection_lm
8990
self.custom_instruction_proposer = custom_instruction_proposer
9091
self.warn_on_score_mismatch = warn_on_score_mismatch
92+
self.optimize_tool_descriptions = optimize_tool_descriptions
9193

9294
if self.custom_instruction_proposer is not None:
9395
# We are only overriding the propose_new_texts method when a custom
@@ -124,6 +126,15 @@ def build_program(self, candidate: dict[str, str]):
124126
for name, pred in new_prog.named_predictors():
125127
if name in candidate:
126128
pred.signature = pred.signature.with_instructions(candidate[name])
129+
130+
if self.optimize_tool_descriptions:
131+
for _, module in new_prog.named_sub_modules():
132+
if hasattr(module, 'tools'):
133+
for tool_name, tool in module.tools.items():
134+
tool_key = f"tool:{tool_name}"
135+
if tool_key in candidate:
136+
tool.desc = candidate[tool_key]
137+
127138
return new_prog
128139

129140
def evaluate(self, batch, candidate, capture_traces=False):
Lines changed: 154 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,154 @@
1+
import dspy
2+
from dspy import Example
3+
from dspy.utils.dummies import DummyLM
4+
5+
6+
def calculator(expression: str) -> str:
7+
try:
8+
return str(eval(expression))
9+
except Exception:
10+
return "Error"
11+
12+
13+
def search(query: str) -> str:
14+
return f"Search results for: {query}"
15+
16+
17+
def simple_metric(example, prediction, trace=None, pred_name=None, pred_trace=None):
18+
score = 1.0 if example.answer in str(prediction.answer) else 0.0
19+
return dspy.Prediction(score=score, feedback="Correct" if score == 1.0 else "Wrong")
20+
21+
22+
def test_build_program_applies_tool_descriptions():
23+
"""Test that build_program applies tool descriptions from candidate dict."""
24+
from dspy.teleprompt.gepa.gepa_utils import DspyAdapter
25+
26+
calc_tool = dspy.Tool(calculator, name="calculator", desc="Old description")
27+
react = dspy.ReAct("question -> answer", tools=[calc_tool])
28+
29+
adapter = DspyAdapter(
30+
student_module=react,
31+
metric_fn=simple_metric,
32+
feedback_map={},
33+
failure_score=0.0,
34+
optimize_tool_descriptions=True,
35+
)
36+
37+
candidate = {
38+
"react": "New instruction for ReAct",
39+
"tool:calculator": "Optimized calculator description",
40+
}
41+
42+
new_prog = adapter.build_program(candidate)
43+
44+
assert new_prog.react.signature.instructions == "New instruction for ReAct"
45+
assert new_prog.tools["calculator"].desc == "Optimized calculator description"
46+
47+
48+
def test_gepa_with_tool_optimization_enabled():
49+
"""Test GEPA end-to-end with optimize_tool_descriptions=True."""
50+
calc_tool = dspy.Tool(calculator, name="calculator", desc="Does math")
51+
react = dspy.ReAct("question -> answer", tools=[calc_tool])
52+
53+
lm = DummyLM(
54+
[
55+
{"next_thought": "Calculate", "next_tool_name": "calculator", "next_tool_args": {"expression": "2+2"}},
56+
{"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}},
57+
{"reasoning": "Used calculator", "answer": "4"},
58+
]
59+
)
60+
reflection_lm = DummyLM([{"improved_instruction": "Better"}])
61+
62+
dspy.settings.configure(lm=lm)
63+
64+
optimizer = dspy.GEPA(
65+
metric=simple_metric,
66+
reflection_lm=reflection_lm,
67+
max_metric_calls=3,
68+
optimize_tool_descriptions=True,
69+
)
70+
71+
trainset = [Example(question="What is 2+2?", answer="4").with_inputs("question")]
72+
73+
optimized = optimizer.compile(react, trainset=trainset)
74+
75+
assert optimized is not None
76+
assert hasattr(optimized, "tools")
77+
assert "calculator" in optimized.tools
78+
79+
80+
def test_gepa_with_multi_agent_architecture():
81+
"""Test that tool optimization discovers tools from nested subagent modules."""
82+
class MultiAgentSystem(dspy.Module):
83+
def __init__(self):
84+
super().__init__()
85+
# Subagent as module attribute (reuse existing search function)
86+
search_tool = dspy.Tool(search, name="search", desc="Searches")
87+
self.subagent = dspy.ReAct("task -> result", tools=[search_tool])
88+
89+
# Main agent with subagent wrapped as tool
90+
def spawn_subagent(task: str) -> str:
91+
return self.subagent(task=task).result
92+
93+
spawn_tool = dspy.Tool(spawn_subagent, name="spawn_subagent", desc="Spawns subagent")
94+
calc_tool = dspy.Tool(calculator, name="calculator", desc="Does math")
95+
self.main_agent = dspy.ReAct("q -> a", tools=[spawn_tool, calc_tool])
96+
97+
system = MultiAgentSystem()
98+
99+
# Test extraction using named_sub_modules pattern
100+
tool_descriptions = {}
101+
for _, module in system.named_sub_modules():
102+
if hasattr(module, 'tools'):
103+
for tool_name, tool in module.tools.items():
104+
tool_key = f"tool:{tool_name}"
105+
if tool_key not in tool_descriptions:
106+
tool_descriptions[tool_key] = tool.desc
107+
108+
# All tools from all nested agents should be discovered
109+
assert "tool:calculator" in tool_descriptions
110+
assert "tool:spawn_subagent" in tool_descriptions
111+
assert "tool:search" in tool_descriptions
112+
assert "tool:finish" in tool_descriptions
113+
114+
115+
def test_gepa_optimizes_multi_agent_system_end_to_end():
116+
"""Test GEPA.compile() optimizes ALL tools from nested multi-agent system."""
117+
class MultiAgentSystem(dspy.Module):
118+
def __init__(self):
119+
super().__init__()
120+
search_tool = dspy.Tool(search, name="search", desc="Searches")
121+
self.subagent = dspy.ReAct("task -> result", tools=[search_tool])
122+
123+
def spawn_subagent(task: str) -> str:
124+
return self.subagent(task=task).result
125+
126+
spawn_tool = dspy.Tool(spawn_subagent, name="spawn_subagent", desc="Spawns subagent")
127+
calc_tool = dspy.Tool(calculator, name="calculator", desc="Does math")
128+
self.main_agent = dspy.ReAct("q -> a", tools=[spawn_tool, calc_tool])
129+
130+
def forward(self, question):
131+
return self.main_agent(q=question)
132+
133+
system = MultiAgentSystem()
134+
135+
# Setup LMs
136+
lm = DummyLM([{"q": "question", "a": "answer"}])
137+
reflection_lm = DummyLM([{"improved_instruction": "Better"}])
138+
dspy.settings.configure(lm=lm)
139+
140+
# Run GEPA optimization
141+
optimizer = dspy.GEPA(
142+
metric=simple_metric,
143+
reflection_lm=reflection_lm,
144+
max_metric_calls=3,
145+
optimize_tool_descriptions=True,
146+
)
147+
148+
trainset = [Example(question="test", answer="answer").with_inputs("question")]
149+
optimized = optimizer.compile(system, trainset=trainset)
150+
151+
# Verify optimized system preserves structure with all tools
152+
assert "search" in optimized.subagent.tools
153+
assert "calculator" in optimized.main_agent.tools
154+
assert "spawn_subagent" in optimized.main_agent.tools

0 commit comments

Comments
 (0)