diff --git a/.gitignore b/.gitignore index 14f48d7f31..6d2756500a 100644 --- a/.gitignore +++ b/.gitignore @@ -67,3 +67,8 @@ docs/docs/**/*.json* test_before_pypi/ .github/.internal_dspyai/dist/ + +AGENTS.md +/tasks +/notebooks +.cache/ diff --git a/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md b/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md index 624e580ad1..8ac2c31c42 100644 --- a/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md +++ b/docs/docs/api/optimizers/GEPA/GEPA_Advanced.md @@ -443,3 +443,373 @@ gepa = dspy.GEPA( auto="medium" ) ``` + +## ReAct Component Optimization + +### What is optimize_react_components? + +Enable `optimize_react_components=True` to apply specialized optimization to `dspy.ReAct` modules while using default optimization for other modules. + +A [`dspy.ReAct`](../../learn/programming/tools.md#approach-1-using-dspyreact-fully-managed) module has three parts: a **react predictor** (iteratively reasons and selects tools), an **extract predictor** (extracts final answers from trajectories), and **tools** with their schemas. + +**What gets optimized for ReAct modules:** + +GEPA can improve textual components across all parts: +- **React instruction** - Guides reasoning and tool selection (always optimized) +- **Extract instruction** - Guides answer extraction from trajectories (optional) +- **Tool descriptions** - Describes what each tool does (optional) +- **Tool argument descriptions** - Describes tool parameters (optional) + +The reflection LM decides which optional components to improve based on observed failures. Non-ReAct modules in your program are optimized using GEPA's default signature optimization. + +**Why this matters:** + +Unlike optimizing signature instructions alone (which improves individual predictors), ReAct optimization improves the **entire agent workflow** - from initial reasoning through tool execution to final answer extraction. + +ReAct agents often fail when their components contradict each other. A clear tool description doesn't help if the react instruction never considers using that tool. GEPA analyzes execution traces to learn how all components should work together. + +### ReAct Optimization Prompt + +GEPA uses a specialized prompt to jointly optimize all ReAct components. The prompt receives complete ReAct trajectories and current component texts: + +```python +class GenerateImprovedReActDescriptionsFromFeedback(dspy.Signature): + """Improve a ReAct agent based on execution examples and feedback. + + These components are progressively optimized - refine what needs improvement. + Analyze the trajectories to identify successful patterns and failure causes. + Generate improved texts to help the agent succeed on similar tasks. + Place improved texts at their appropriate level of abstraction and/or specificity. 
+ """ + + current_react_instruction = dspy.InputField( + desc="Current ReAct module instruction guiding the ReAct agent's reasoning and tool selection" + ) + current_extract_instruction = dspy.InputField( + desc="Current Extract module instruction for extracting final answers from trajectories" + ) + current_tools = dspy.InputField( + annotation=list[dspy.Tool], + desc="Available tools with their complete schemas" + ) + examples_with_feedback = dspy.InputField( + desc="Execution examples with feedback showing successes and failures" + ) + + improved_react_instruction: str | None = dspy.OutputField( + desc="ReAct instruction for reasoning and tool selection", + default=None + ) + improved_extract_instruction: str | None = dspy.OutputField( + desc="Extract instruction for answer extraction", + default=None + ) + # Note: Tool descriptions and arg descriptions are added dynamically via signature.append() + # with field descriptions like "Purpose of tool" and "Usage of parameter" +``` + +The reflection LM receives all current components and execution traces, then decides which components to improve. Tool-specific fields (`improved_tool_{name}_desc`, `improved_tool_{name}_arg_{param}_desc`) are generated dynamically for each tool and parameter. + +**Writing Metrics for ReAct Optimization** + +GEPA optimizes ReAct modules more effectively when metrics provide feedback about the agent's execution. Here's how to write metrics that help: + +```python +def react_metric(example, pred, trace=None, pred_name=None, pred_trace=None): + """Evaluate ReAct agent performance with trajectory feedback.""" + # Check if the answer is correct + answer_match = pred.answer == example.answer + score = 1.0 if answer_match else 0.0 + + # Provide feedback to help GEPA understand what happened + feedback = "Correct answer" if answer_match else "Incorrect answer" + + return dspy.Prediction(score=score, feedback=feedback) +``` + +You can make feedback more informative by examining the trajectory: + +```python +def react_metric_with_trajectory(example, pred, trace=None, pred_name=None, pred_trace=None): + """Evaluate with trajectory analysis.""" + # Check if the answer is correct + answer_match = pred.answer == example.answer + score = 1.0 if answer_match else 0.0 + + # Access the ReAct trajectory to understand agent behavior + trajectory = getattr(pred, 'trajectory', {}) + + # Extract tool names from trajectory (excluding 'finish') + tools_used = [] + for key in trajectory: + if key.startswith('tool_name_'): + tool_name = trajectory[key] + if tool_name != 'finish': + tools_used.append(tool_name) + + # Build feedback message + if answer_match: + feedback = "Correct answer" + else: + feedback = "Incorrect answer" + + if tools_used: + feedback += f". Tools: {', '.join(tools_used)}" + + return dspy.Prediction(score=score, feedback=feedback) +``` + +The trajectory contains the agent's step-by-step execution. Use it to provide feedback about: + +- **Tool selection**: Were appropriate tools chosen? +- **Reasoning quality**: Did the agent think through the problem? +- **Efficiency**: Were there unnecessary steps? + +The reflection LM uses your feedback to jointly improve react instructions, tool descriptions, and extraction logic. + +### How It Works + +When `optimize_react_components=True`, GEPA: + +1. **Discovers ReAct modules** - Finds all `dspy.ReAct` instances in your program (including nested modules) +2. **Extracts components** - Collects react instructions, extract instructions, and tool schemas from each ReAct module +3. 
**Routes to proposers** - Separates components by type and routes them appropriately: + - **With custom `instruction_proposer`**: Your custom proposer receives all components (both regular instructions and ReAct components) and handles the optimization logic + - **With default proposer**: Regular instructions use default instruction proposer, ReAct components use specialized `ReActModuleProposer` +4. **Optimizes jointly** - ReAct proposer improves all four components together based on execution feedback +5. **Applies updates** - Updates your ReAct modules with improved instructions and tool descriptions + +Non-ReAct modules (like `dspy.Predict` or `dspy.ChainOfThought`) continue using standard GEPA optimization. + +### When to Use optimize_react_components + +Enable `optimize_react_components=True` when you use `dspy.ReAct` in your program and need better agent performance. GEPA jointly optimizes all ReAct components (react instruction, extract instruction, tool descriptions, tool argument descriptions) based on execution feedback. Common scenarios: + +1. **Agent loops with repeated tool calls** - Agent keeps calling `web_search` multiple times with similar queries instead of synthesizing information. GEPA improves react instruction to encourage synthesis and tool descriptions to clarify when searches are sufficient. + +2. **Wrong tool selection** - Agent with `search` and `calculator` tools keeps searching when it should calculate, or vice versa. GEPA refines react instruction and tool descriptions to clarify "use search for factual queries, calculator for numerical analysis." + +3. **Agent gives up without trying tools** - Agent responds "I don't know" without using available tools that could answer the question. GEPA improves react instruction to be more proactive about tool usage. + +4. **Extraction failures** - Agent executes tools correctly but fails to extract the final answer from the trajectory. GEPA improves extract instruction to better identify and format answers from tool outputs. + +5. **Multi-agent delegation issues** - Parent agent has delegation tools to specialized sub-agents but doesn't understand when to use each. GEPA optimizes all ReAct components across both parent and sub-agent modules for coherent delegation. + +See the usage examples below for basic ReAct agents and multi-agent systems. + +### Usage Examples + +#### Basic ReAct Agent + +```python +import dspy + +def search_web(query: str) -> str: + return f"Search results for: {query}" + +def calculate(expression: str) -> float: + return eval(expression) + +# Create ReAct agent with tools (poor initial descriptions) +search_tool = dspy.Tool(search_web, name="search", desc="Finds things") +calc_tool = dspy.Tool(calculate, name="calculator", desc="Does calculations") + +agent = dspy.ReAct("question -> answer", tools=[search_tool, calc_tool]) + +# Enable tool optimization +gepa = dspy.GEPA( + metric=my_metric, + reflection_lm=dspy.LM(model="gpt-5-mini"), + optimize_react_components=True, + component_selector="all", # Optimize all components together + auto="medium" +) + +optimized_agent = gepa.compile(agent, trainset=train_examples, valset=val_examples) + +# View optimized tool descriptions +print("Optimized search tool:", optimized_agent.tools["search"].desc) +print("Optimized calculator tool:", optimized_agent.tools["calculator"].desc) +``` + +**Example output after optimization:** +``` +Optimized search tool: Use when you need to find current information, facts, or data + from external sources. 
Provide specific search queries to get relevant results. + +Optimized calculator tool: Use for arithmetic operations and mathematical expressions. + Accepts Python-compatible expressions with numbers and operators (+, -, *, /, **). + Do not use for date calculations or string manipulations. +``` + +#### Multi-Agent System + +GEPA automatically discovers and optimizes tools in nested agents: + +```python +import dspy + +def search_web(query: str) -> str: + return f"Search results for: {query}" + +def calculate(expression: str) -> float: + return eval(expression) + +search_tool = dspy.Tool(search_web, name="search", desc="Searches") +calc_tool = dspy.Tool(calculate, name="calculator", desc="Computes") + +class ResearchAssistant(dspy.Module): + def __init__(self): + super().__init__() + self.researcher = dspy.ReAct("query -> findings", tools=[search_tool]) + + def delegate_research(query: str) -> str: + return self.researcher(query=query).findings + + research_tool = dspy.Tool(delegate_research, name="research", desc="Helps with questions") + self.assistant = dspy.ReAct("question -> answer", tools=[research_tool, calc_tool]) + + def forward(self, question): + return self.assistant(question=question) + +# Optimizes ALL tools: calculator, research, search +gepa = dspy.GEPA( + metric=my_metric, + reflection_lm=dspy.LM(model="gpt-5-mini"), + optimize_react_components=True, + component_selector="all", + auto="medium" +) + +optimized_system = gepa.compile(ResearchAssistant(), trainset=train, valset=val) + +# View optimized nested tool descriptions +print(optimized_system.researcher.tools["search"].desc) +print(optimized_system.assistant.tools["research"].desc) +print(optimized_system.assistant.tools["calculator"].desc) +``` + +### Inspecting Optimized ReAct Components + +After optimization, all ReAct components are automatically updated in your program. Access them directly: + +```python +optimized_agent = gepa.compile(agent, trainset=train, valset=val) + +# ReAct instruction (guides reasoning and tool selection) +print("React instruction:", optimized_agent.react.signature.instructions) + +# Extract instruction (guides answer extraction from trajectory) +print("Extract instruction:", optimized_agent.extract.predict.signature.instructions) + +# Tool descriptions +for tool_name, tool in optimized_agent.tools.items(): + if tool_name != 'finish': # Skip the built-in finish tool + print(f"Tool '{tool_name}' description:", tool.desc) + # Tool argument descriptions + print(f" Argument descriptions:", tool.arg_desc) +``` + +### Custom Instruction Proposers and ReAct Optimization + +**Important:** When you provide a custom `instruction_proposer`, it receives ALL components (regular predictors AND ReAct modules). You must set `optimize_react_components=True` to enable ReAct module discovery and serialization, then handle the optimization logic yourself. + +**How it works internally:** + +1. **Component Discovery** - GEPA discovers components in your program: + - Regular predictors → keys like `"predict"`, `"chain_of_thought"` + - ReAct modules → keys like `"react_module"` or `"react_module:agent_name"` + +2. **ReAct Serialization** - When `optimize_react_components=True`, GEPA serializes ReAct modules as JSON: + ```json + { + "react": "instruction for reasoning and tool selection", + "extract": "instruction for answer extraction", + "tools": { + "tool_name": { + "desc": "what the tool does", + "args": {"param": {"type": "string"}}, + "arg_desc": {"param": "description of param"} + } + } + } + ``` + +3. 
**Custom Proposer Receives**:
   - `candidate: dict[str, str]` - **All values are strings**
     - Regular component: `candidate["predict"]` → `"Your instruction here"`
     - ReAct component: `candidate["react_module"]` → `'{"react": "...", "extract": "...", "tools": {...}}'` (JSON as a string)
   - `reflective_dataset: dict[str, list[ReflectiveExample]]` - **GEPA provides this**
     - Contains execution traces: inputs, outputs (including the full ReAct trajectory), and your metric's feedback
     - For ReAct: `Generated_Outputs` includes the entire trajectory with all tool calls and reasoning
     - Use this to understand what went wrong and to guide your improvements
   - `components_to_update: list[str]` - Component keys to optimize this round

4. **Your Responsibility**:
   - For ReAct components: use `json.loads()` to parse, improve any of the four parts, then `json.dumps()` the result
   - For regular components: improve the instruction string directly
   - Return a `dict[str, str]` with the same keys

**What this means:**
- Your custom proposer receives ALL components: regular signatures AND ReAct modules
- GEPA still handles discovery and JSON serialization, but YOU implement the optimization logic
- ReAct components arrive under keys like `"react_module"` or `"react_module:agent_name"`

#### Implementing a Custom Proposer for ReAct

If you need custom optimization logic beyond the default, you can build your own proposer. The best way to start is the reference implementation: [`ReActModuleProposer`](https://github.com/stanfordnlp/dspy/blob/main/dspy/teleprompt/gepa/instruction_proposal.py).

**Understanding ReAct component structure**

When GEPA optimizes ReAct modules, it serializes them as JSON strings containing all the pieces you can improve:

```json
{
  "react": "instruction for reasoning and tool selection",
  "extract": "instruction for answer extraction",
  "tools": {
    "search": {
      "desc": "Search the web for information",
      "args": {"query": {"type": "string"}},
      "arg_desc": {"query": "The search query to execute"}
    }
  }
}
```

**What you can improve:**
- **`react`** - How the agent reasons and decides which tools to use
- **`extract`** - How the agent extracts the final answer from execution results
- **`tools[*].desc`** - What each tool does and when to use it
- **`tools[*].arg_desc`** - What each parameter means and how to use it

**What to preserve:**
- **`tools[*].args`** - The tool's parameter schema (types, required fields, etc.)
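
Getting this split wrong is costly: if a proposer drops or rewrites `args`, the rebuilt program can no longer call its tools with valid schemas. Here is a minimal sketch of an edit pass that touches only the editable fields, assuming the JSON shape shown above (the `rewrite` callback is a hypothetical stand-in for your own improvement logic):

```python
import json

def edit_react_component(component_json: str, rewrite) -> str:
    """Apply `rewrite` (str -> str) to the editable fields only."""
    config = json.loads(component_json)
    config["react"] = rewrite(config["react"])      # reasoning / tool-selection instruction
    config["extract"] = rewrite(config["extract"])  # answer-extraction instruction
    for tool in config.get("tools", {}).values():
        tool["desc"] = rewrite(tool["desc"])
        for arg, desc in tool.get("arg_desc", {}).items():
            tool["arg_desc"][arg] = rewrite(desc)
        # tool["args"] passes through untouched - it is the schema, not prose
    return json.dumps(config, indent=2)
```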
+ +**Your proposer's interface** + +Your custom proposer is a callable that receives component instructions and execution feedback, then returns improved versions: + +```python +def your_custom_proposer( + candidate: dict[str, str], # Current instructions for all components + reflective_dataset: dict[str, list], # Execution examples with feedback + components_to_update: list[str], # Which components to optimize this round +) -> dict[str, str]: # Return improved instructions + """ + For ReAct components: + - Use json.loads() to parse the JSON string + - Improve what needs fixing based on the feedback + - Use json.dumps() to serialize back + + For regular components: + - Just return the improved instruction string + """ + # Your optimization logic here + pass +``` + +**The reference shows how to:** +- Parse and rebuild the JSON structure +- Generate dynamic fields for tools/parameters +- Use execution feedback to guide improvements diff --git a/docs/docs/api/optimizers/GEPA/overview.md b/docs/docs/api/optimizers/GEPA/overview.md index 0125702bea..c36065b6aa 100644 --- a/docs/docs/api/optimizers/GEPA/overview.md +++ b/docs/docs/api/optimizers/GEPA/overview.md @@ -117,6 +117,12 @@ Practical Recipe for GEPA-Friendly Feedback: - **Multi-Objective Tasks** (e.g., PUPA): Decompose aggregate scores to reveal contributions from each objective, highlighting tradeoffs (e.g., quality vs. privacy). - **Stacked Pipelines** (e.g., code generation: parse → compile → run → profile → evaluate): Expose stage-specific failures; natural-language traces often suffice for LLM self-correction. +## ReAct Component Optimization + +GEPA can optimize ReAct modules holistically. When `optimize_react_components=True`, GEPA jointly optimizes all four components of ReAct modules: react instructions, extract instructions, tool descriptions, and tool argument descriptions. This helps agents make better decisions by learning from execution traces how all components work together. + +For details on how ReAct optimization works, when to use it, and usage examples, see [ReAct Component Optimization](GEPA_Advanced.md#react-component-optimization) in the Advanced Features guide. + ## Custom Instruction Proposal For advanced customization of GEPA's instruction proposal mechanism, including custom instruction proposers and component selectors, see [Advanced Features](GEPA_Advanced.md). 
diff --git a/dspy/teleprompt/gepa/gepa.py b/dspy/teleprompt/gepa/gepa.py index c35e916691..9888f82513 100644 --- a/dspy/teleprompt/gepa/gepa.py +++ b/dspy/teleprompt/gepa/gepa.py @@ -1,16 +1,26 @@ import inspect +import json import logging import random from dataclasses import dataclass -from typing import Any, Literal, Optional, Protocol, Union +from typing import Any, Literal, Optional, Protocol, Union, get_args, get_origin from gepa import GEPAResult from gepa.core.adapter import ProposalFn from gepa.proposer.reflective_mutation.base import ReflectionComponentSelector +from dspy.adapters.types.tool import Tool from dspy.clients.lm import LM +from dspy.predict.react import ReAct from dspy.primitives import Example, Module, Prediction -from dspy.teleprompt.gepa.gepa_utils import DspyAdapter, DSPyTrace, PredictorFeedbackFn, ScoreWithFeedback +from dspy.teleprompt.gepa.gepa_utils import ( + REACT_MODULE_PREFIX, + TOOL_MODULE_PREFIX, + DspyAdapter, + DSPyTrace, + PredictorFeedbackFn, + ScoreWithFeedback, +) from dspy.teleprompt.teleprompt import Teleprompter from dspy.utils.annotation import experimental @@ -36,18 +46,18 @@ def __call__( - gold: The gold example. - pred: The predicted output. - trace: Optional. The trace of the program's execution. - - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which + - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which the feedback is being requested. - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for. Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name` and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`. - If available at the predictor level, the metric should return dspy.Prediction(score: float, feedback: str) corresponding + If available at the predictor level, the metric should return dspy.Prediction(score: float, feedback: str) corresponding to the predictor. If not available at the predictor level, the metric can also return a text feedback at the program level (using just the gold, pred and trace). - If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: + If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: f"This trajectory got a score of {score}." """ ... @@ -172,18 +182,18 @@ def metric( - gold: The gold example. - pred: The predicted output. - trace: Optional. The trace of the program's execution. - - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which + - pred_name: Optional. The name of the target predictor currently being optimized by GEPA, for which the feedback is being requested. - pred_trace: Optional. The trace of the target predictor's execution GEPA is seeking feedback for. Note the `pred_name` and `pred_trace` arguments. During optimization, GEPA will call the metric to obtain feedback for individual predictors being optimized. GEPA provides the name of the predictor in `pred_name` and the sub-trace (of the trace) corresponding to the predictor in `pred_trace`. 
- If available at the predictor level, the metric should return {'score': float, 'feedback': str} corresponding + If available at the predictor level, the metric should return {'score': float, 'feedback': str} corresponding to the predictor. If not available at the predictor level, the metric can also return a text feedback at the program level (using just the gold, pred and trace). - If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: + If no feedback is returned, GEPA will use a simple text feedback consisting of just the score: f"This trajectory got a score of {score}." \""" ... @@ -207,41 +217,41 @@ def metric( max_full_evals: The maximum number of full evaluations to perform. max_metric_calls: The maximum number of metric calls to perform. reflection_minibatch_size: The number of examples to use for reflection in a single GEPA step. Default is 3. - candidate_selection_strategy: The strategy to use for candidate selection. Default is "pareto", - which stochastically selects candidates from the Pareto frontier of all validation scores. + candidate_selection_strategy: The strategy to use for candidate selection. Default is "pareto", + which stochastically selects candidates from the Pareto frontier of all validation scores. Options: "pareto", "current_best". - reflection_lm: The language model to use for reflection. Required parameter. GEPA benefits from - a strong reflection model. Consider using `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)` + reflection_lm: The language model to use for reflection. Required parameter. GEPA benefits from + a strong reflection model. Consider using `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)` for optimal performance. skip_perfect_score: Whether to skip examples with perfect scores during reflection. Default is True. instruction_proposer: Optional custom instruction proposer implementing GEPA's ProposalFn protocol. - **Default: None (recommended for most users)** - Uses GEPA's proven instruction proposer from - the [GEPA library](https://github.com/gepa-ai/gepa), which implements the - [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). This default - proposer is highly capable and was validated across diverse experiments reported in the GEPA + **Default: None (recommended for most users)** - Uses GEPA's proven instruction proposer from + the [GEPA library](https://github.com/gepa-ai/gepa), which implements the + [`ProposalFn`](https://github.com/gepa-ai/gepa/blob/main/src/gepa/core/adapter.py). This default + proposer is highly capable and was validated across diverse experiments reported in the GEPA paper and tutorials. - See documentation on custom instruction proposers + See documentation on custom instruction proposers [here](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#custom-instruction-proposers). 
- + **Advanced Feature**: Only needed for specialized scenarios: - **Multi-modal handling**: Processing dspy.Image inputs alongside textual information - - **Nuanced control over constraints**: Fine-grained control over instruction length, format, + - **Nuanced control over constraints**: Fine-grained control over instruction length, format, and structural requirements beyond standard feedback mechanisms - - **Domain-specific knowledge injection**: Specialized terminology or context that cannot be + - **Domain-specific knowledge injection**: Specialized terminology or context that cannot be provided through feedback_func alone - - **Provider-specific prompting**: Optimizations for specific LLM providers (OpenAI, Anthropic) + - **Provider-specific prompting**: Optimizations for specific LLM providers (OpenAI, Anthropic) with unique formatting preferences - - **Coupled component updates**: Coordinated updates of multiple components together rather + - **Coupled component updates**: Coordinated updates of multiple components together rather than independent optimization - **External knowledge integration**: Runtime access to databases, APIs, or knowledge bases - - The default proposer handles the vast majority of use cases effectively. Use - MultiModalInstructionProposer() from dspy.teleprompt.gepa.instruction_proposal for visual + + The default proposer handles the vast majority of use cases effectively. Use + MultiModalInstructionProposer() from dspy.teleprompt.gepa.instruction_proposal for visual content or implement custom ProposalFn for highly specialized requirements. - - Note: When both instruction_proposer and reflection_lm are set, the instruction_proposer is called - in the reflection_lm context. However, reflection_lm is optional when using a custom instruction_proposer. + + Note: When both instruction_proposer and reflection_lm are set, the instruction_proposer is called + in the reflection_lm context. However, reflection_lm is optional when using a custom instruction_proposer. Custom instruction proposers can invoke their own LLMs if needed. component_selector: Custom component selector implementing the [ReflectionComponentSelector](https://github.com/gepa-ai/gepa/blob/main/src/gepa/proposer/reflective_mutation/base.py) protocol, or a string specifying a built-in selector strategy. Controls which components (predictors) are selected @@ -256,23 +266,28 @@ def metric( max_merge_invocations: The maximum number of merge invocations to perform. Default is 5. num_threads: The number of threads to use for evaluation with `Evaluate`. Optional. failure_score: The score to assign to failed examples. Default is 0.0. - perfect_score: The maximum score achievable by the metric. Default is 1.0. Used by GEPA + perfect_score: The maximum score achievable by the metric. Default is 1.0. Used by GEPA to determine if all examples in a minibatch are perfect. - log_dir: The directory to save the logs. GEPA saves elaborate logs, along with all candidate - programs, in this directory. Running GEPA with the same `log_dir` will resume the run + log_dir: The directory to save the logs. GEPA saves elaborate logs, along with all candidate + programs, in this directory. Running GEPA with the same `log_dir` will resume the run from the last checkpoint. - track_stats: Whether to return detailed results and all proposed programs in the `detailed_results` + track_stats: Whether to return detailed results and all proposed programs in the `detailed_results` attribute of the optimized program. Default is False. 
use_wandb: Whether to use wandb for logging. Default is False. - wandb_api_key: The API key to use for wandb. If not provided, wandb will use the API key + wandb_api_key: The API key to use for wandb. If not provided, wandb will use the API key from the environment variable `WANDB_API_KEY`. wandb_init_kwargs: Additional keyword arguments to pass to `wandb.init`. - track_best_outputs: Whether to track the best outputs on the validation set. track_stats must - be True if track_best_outputs is True. The optimized program's `detailed_results.best_outputs_valset` + track_best_outputs: Whether to track the best outputs on the validation set. track_stats must + be True if track_best_outputs is True. The optimized program's `detailed_results.best_outputs_valset` will contain the best outputs for each task in the validation set. - warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when - called with and without the pred_name. This flag (defaults to True) determines whether a warning is + warn_on_score_mismatch: GEPA (currently) expects the metric to return the same module-level score when + called with and without the pred_name. This flag (defaults to True) determines whether a warning is raised if a mismatch in module-level and predictor-level score is detected. + enable_tool_optimization: Whether to enable joint optimization of tool-using modules. + When enabled, GEPA jointly optimizes predictor instructions and tool descriptions together + for both dspy.ReAct modules and custom predictors that use dspy.Tool. See the + [ReAct Component Optimization guide](https://dspy.ai/api/optimizers/GEPA/GEPA_Advanced/#react-component-optimization) + for details on when to use this feature and how it works. Default is False. seed: The random seed to use for reproducibility. Default is 0. gepa_kwargs: (Optional) Additional keyword arguments to pass directly to [gepa.optimize](https://github.com/gepa-ai/gepa/blob/main/src/gepa/api.py). Useful for accessing advanced GEPA features not directly exposed through DSPy's GEPA interface. @@ -307,21 +322,21 @@ def metric( Budget Configuration: Exactly one of `auto`, `max_full_evals`, or `max_metric_calls` must be provided. The `auto` parameter provides preset configurations: "light" for quick experimentation, "medium" for balanced optimization, and "heavy" for thorough optimization. - + Reflection Configuration: The `reflection_lm` parameter is required and should be a strong language model. GEPA performs best with models like `dspy.LM(model='gpt-5', temperature=1.0, max_tokens=32000)`. The reflection process analyzes failed examples to generate feedback for program improvement. - + Merge Configuration: GEPA can merge successful program variants using `use_merge=True`. The `max_merge_invocations` parameter controls how many merge attempts are made during optimization. - - Evaluation Configuration: Use `num_threads` to parallelize evaluation. The `failure_score` and + + Evaluation Configuration: Use `num_threads` to parallelize evaluation. The `failure_score` and `perfect_score` parameters help GEPA understand your metric's range and optimize accordingly. - + Logging Configuration: Set `log_dir` to save detailed logs and enable checkpoint resuming. Use `track_stats=True` to access detailed optimization results via the `detailed_results` attribute. Enable `use_wandb=True` for experiment tracking and visualization. - + Reproducibility: Set `seed` to ensure consistent results across runs with the same configuration. 
""" def __init__( @@ -355,6 +370,7 @@ def __init__( wandb_init_kwargs: dict[str, Any] | None = None, track_best_outputs: bool = False, warn_on_score_mismatch: bool = True, + enable_tool_optimization: bool = False, use_mlflow: bool = False, # Reproducibility seed: int | None = 0, @@ -417,6 +433,7 @@ def __init__( self.wandb_api_key = wandb_api_key self.wandb_init_kwargs = wandb_init_kwargs self.warn_on_score_mismatch = warn_on_score_mismatch + self.enable_tool_optimization = enable_tool_optimization self.use_mlflow = use_mlflow if track_best_outputs: @@ -546,11 +563,112 @@ def feedback_fn( reflection_lm=self.reflection_lm, custom_instruction_proposer=self.custom_instruction_proposer, warn_on_score_mismatch=self.warn_on_score_mismatch, + enable_tool_optimization=self.enable_tool_optimization, reflection_minibatch_size=self.reflection_minibatch_size, ) # Instantiate GEPA with the simpler adapter-based API - base_program = {name: pred.signature.instructions for name, pred in student.named_predictors()} + base_program = {} + + # First, process ReAct modules to claim their predictors + if self.enable_tool_optimization: + for module_path, module in student.named_sub_modules(): + if not isinstance(module, ReAct): + continue + + # Verify DSPy's two-predictor ReAct design + assert hasattr(module, "extract") and hasattr(module.extract, "predict"), \ + f"ReAct module '{module_path}' missing extract.predict - DSPy design may have changed" + + # Get predictor names via object identity + extract_predictor = module.extract.predict + react_predictor = module.react + extract_predictor_name = None + react_predictor_name = None + for name, pred in student.named_predictors(): + if pred is extract_predictor: + extract_predictor_name = name + elif pred is react_predictor: + react_predictor_name = name + + # Use extract.predict as the key since it is the target predictor for feedback lookup + module_key = f"{REACT_MODULE_PREFIX}:{extract_predictor_name}" + + # Build JSON config with dynamic predictor names as keys + config = { + react_predictor_name: react_predictor.signature.instructions, + extract_predictor_name: extract_predictor.signature.instructions, + "tools": { + tool_name: { + "desc": tool.desc, + "args": tool.args, + } + for tool_name, tool in module.tools.items() + if tool_name != "finish" # Skip the built-in finish tool + } + } + + base_program[module_key] = json.dumps(config, indent=2) + else: + # Warn if ReAct modules found but tool optimization disabled + for module_path, module in student.named_sub_modules(): + if isinstance(module, ReAct): + logger.warning( + f"Detected ReAct module at '{module_path}'. Consider using " + "`enable_tool_optimization=True` to jointly optimize react instructions, " + "extract instructions, tool descriptions, and tool argument descriptions." 
+ ) + + # Then, process individual predictors (skip if already part of a module config) + for name, pred in student.named_predictors(): + if self.enable_tool_optimization: + # Skip if predictor is part of a module config (e.g., ReAct) + found = False + for val in base_program.values(): + try: + config = json.loads(val) + if name in config: + found = True + break + except (json.JSONDecodeError, TypeError, ValueError): + pass + + if found: + continue + + # Detect tool-using predictors via type checking + def is_tool_field(annotation) -> bool: + """Check if a field annotation is Tool or contains Tool.""" + if annotation is Tool: + return True + origin = get_origin(annotation) + if origin is not None: + args = get_args(annotation) + for arg in args: + if is_tool_field(arg): # Recursive for nested types + return True + return False + + # Add tool module if predictor uses tools + if any(is_tool_field(field.annotation) for field in pred.signature.input_fields.values()): + module_key = f"{TOOL_MODULE_PREFIX}:{name}" + base_program[module_key] = json.dumps({ + name: pred.signature.instructions, + "tools": {} # Populated from traces + }, indent=2) + continue + + # Add regular predictor (no tool optimization or no tools detected) + base_program[name] = pred.signature.instructions + + # Log base_program keys for debugging + logger.info(f"Initialized base_program with {len(base_program)} components:") + for key in sorted(base_program.keys()): + if key.startswith(REACT_MODULE_PREFIX): + logger.info(f" {key}: ") + else: + logger.info(f" {key}: ") + gepa_result: GEPAResult = optimize( seed_candidate=base_program, trainset=trainset, diff --git a/dspy/teleprompt/gepa/gepa_utils.py b/dspy/teleprompt/gepa/gepa_utils.py index d2e6772cef..9aa7127489 100644 --- a/dspy/teleprompt/gepa/gepa_utils.py +++ b/dspy/teleprompt/gepa/gepa_utils.py @@ -1,3 +1,4 @@ +import json import logging import random from typing import Any, Callable, Protocol, TypedDict @@ -9,12 +10,19 @@ from dspy.adapters.chat_adapter import ChatAdapter from dspy.adapters.types import History from dspy.adapters.types.base_type import Type +from dspy.adapters.types.tool import Tool from dspy.evaluate import Evaluate from dspy.primitives import Example, Prediction from dspy.teleprompt.bootstrap_trace import TraceData logger = logging.getLogger(__name__) + +# Constants for module optimization +REACT_MODULE_PREFIX = "react_module" +TOOL_MODULE_PREFIX = "tool_module" + + class LoggerAdapter: def __init__(self, logger: logging.Logger): self.logger = logger @@ -22,6 +30,7 @@ def __init__(self, logger: logging.Logger): def log(self, x: str): self.logger.info(x) + DSPyTrace = list[tuple[Any, dict[str, Any], Prediction]] @@ -31,15 +40,17 @@ class ReflectiveExample(TypedDict): Each example contains the predictor inputs, generated outputs, and feedback from evaluation. """ - Inputs: dict[str, Any] # Predictor inputs (may include str, dspy.Image, etc.) - Generated_Outputs: dict[str, Any] | str # Success: dict with output fields, Failure: error message string - Feedback: str # Always a string - from metric function or parsing error message + + Inputs: dict[str, Any] # Predictor inputs (may include str, dspy.Image, etc.) 
+ Generated_Outputs: dict[str, Any] | str # Success: dict with output fields, Failure: error message string + Feedback: str # Always a string - from metric function or parsing error message class ScoreWithFeedback(Prediction): score: float feedback: str + class PredictorFeedbackFn(Protocol): def __call__( predictor_output: dict[str, Any], @@ -64,6 +75,7 @@ def __call__( """ ... + class DspyAdapter(GEPAAdapter[Example, TraceData, Prediction]): def __init__( self, @@ -77,6 +89,7 @@ def __init__( reflection_lm=None, custom_instruction_proposer: "ProposalFn | None" = None, warn_on_score_mismatch: bool = True, + enable_tool_optimization: bool = False, reflection_minibatch_size: int | None = None, ): self.student = student_module @@ -89,43 +102,173 @@ def __init__( self.reflection_lm = reflection_lm self.custom_instruction_proposer = custom_instruction_proposer self.warn_on_score_mismatch = warn_on_score_mismatch + self.enable_tool_optimization = enable_tool_optimization + + self.propose_new_texts = self._build_propose_new_texts() self.reflection_minibatch_size = reflection_minibatch_size + def _build_propose_new_texts(self): + """Build proposal function that routes components to appropriate proposers.""" + # Init instruction proposer (custom or default) if self.custom_instruction_proposer is not None: - # We are only overriding the propose_new_texts method when a custom - # instruction proposer is provided. Otherwise, we use the GEPA - # default propose_new_texts. + instruction_proposer = self.custom_instruction_proposer + else: + from gepa.strategies.instruction_proposal import InstructionProposalSignature - def custom_propose_new_texts( + def default_instruction_proposer( candidate: dict[str, str], reflective_dataset: dict[str, list[dict[str, Any]]], - components_to_update: list[str] + components_to_update: list[str], ) -> dict[str, str]: - if self.reflection_lm is not None: - with dspy.context(lm=self.reflection_lm): - return self.custom_instruction_proposer( + lm = self.reflection_lm or dspy.settings.lm + updated_components: dict[str, str] = {} + for name in components_to_update: + base_instruction = candidate[name] + dataset_with_feedback = reflective_dataset[name] + updated_components[name] = InstructionProposalSignature.run( + lm=(lambda x: lm(x)[0]), + input_dict={ + "current_instruction_doc": base_instruction, + "dataset_with_feedback": dataset_with_feedback, + }, + )["new_instruction"] + return updated_components + + instruction_proposer = default_instruction_proposer + + # Init tool module proposer if tool optimization is enabled + tool_module_proposer = None + if self.enable_tool_optimization: + from .instruction_proposal import ToolModuleProposer + + tool_module_proposer = ToolModuleProposer() + + def propose_component_texts( + candidate: dict[str, str], + reflective_dataset: dict[str, list[dict[str, Any]]], + components_to_update: list[str], + ) -> dict[str, str]: + # If custom proposer provided, override everything with custom proposer + if self.custom_instruction_proposer: + with dspy.context(lm=self.reflection_lm or dspy.settings.lm): + return instruction_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=components_to_update, + ) + + # Otherwise, route to appropriate proposers + # Separate into two categories: components with tools vs regular instructions + tool_module_components = [] + instruction_components = [] + + for c in components_to_update: + if c.startswith((REACT_MODULE_PREFIX, TOOL_MODULE_PREFIX)): + 
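                    # Module-prefixed keys carry JSON-serialized configs (instructions + tool schemas), not plain instruction strings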
tool_module_components.append(c) + else: + instruction_components.append(c) + + results: dict[str, str] = {} + + with dspy.context(lm=self.reflection_lm or dspy.settings.lm): + # Handle regular instruction components + if instruction_components: + results.update( + instruction_proposer( candidate=candidate, reflective_dataset=reflective_dataset, - components_to_update=components_to_update + components_to_update=instruction_components, ) - else: - return self.custom_instruction_proposer( - candidate=candidate, - reflective_dataset=reflective_dataset, - components_to_update=components_to_update ) - self.propose_new_texts = custom_propose_new_texts + # Handle components with tools (ReAct and Tool modules) + if tool_module_components: + results.update( + tool_module_proposer( + candidate=candidate, + reflective_dataset=reflective_dataset, + components_to_update=tool_module_components, + ) + ) - # Cache predictor names/signatures - self.named_predictors = list(self.student.named_predictors()) + return results + return propose_component_texts def build_program(self, candidate: dict[str, str]): new_prog = self.student.deepcopy() + + # Start with plain string instructions from candidate + improved_predictors = { + k: v for k, v in candidate.items() + if not k.startswith((REACT_MODULE_PREFIX, TOOL_MODULE_PREFIX)) + } + + improved_tools = {} + if self.enable_tool_optimization: + for key, value in candidate.items(): + if not key.startswith((REACT_MODULE_PREFIX, TOOL_MODULE_PREFIX)): + continue + + config = json.loads(value) + + for pred_name, instruction in config.items(): + if isinstance(instruction, str): + improved_predictors[pred_name] = instruction + + improved_tools.update(config.get("tools", {})) + + # Update predictor instructions for name, pred in new_prog.named_predictors(): - if name in candidate: - pred.signature = pred.signature.with_instructions(candidate[name]) + if name in improved_predictors: + pred.signature = pred.signature.with_instructions(improved_predictors[name]) + + # Update tool descriptions + if improved_tools: + def collect_tools(obj): + all_tools = {} + visited = set() + + def traverse(o): + if id(o) in visited or not hasattr(o, "__dict__"): + return + visited.add(id(o)) + + for attr_val in o.__dict__.values(): + if isinstance(attr_val, Tool): + all_tools[attr_val.name] = attr_val + elif isinstance(attr_val, list): + for item in attr_val: + if isinstance(item, Tool): + all_tools[item.name] = item + elif isinstance(attr_val, dict): + for item in attr_val.values(): + if isinstance(item, Tool): + all_tools[item.name] = item + elif isinstance(attr_val, dspy.Module): + traverse(attr_val) + + traverse(obj) + return all_tools + + all_tools = collect_tools(new_prog) + + for tool_name, tool_config in improved_tools.items(): + if tool_name not in all_tools: + continue + + tool = all_tools[tool_name] + + # Update tool description if present. + if tool_config.get("desc") is not None: + tool.desc = tool_config["desc"] + + # Update arg descriptions if present. 
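+                # Each args entry is JSON-Schema-like, e.g. {"query": {"type": "string", "description": "..."}}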
+ args_schema = tool_config.get("args") or {} + for arg_name, arg_schema in args_schema.items(): + if arg_schema.get("description") is not None: + tool.args[arg_name]["description"] = arg_schema["description"] + return new_prog def evaluate(self, batch, candidate, capture_traces=False): @@ -176,19 +319,50 @@ def evaluate(self, batch, candidate, capture_traces=False): scores = [s["score"] if hasattr(s, "score") else s for s in scores] return EvaluationBatch(outputs=outputs, scores=scores, trajectories=None) - def make_reflective_dataset(self, candidate, eval_batch, components_to_update) -> dict[str, list[ReflectiveExample]]: + def make_reflective_dataset( + self, candidate, eval_batch, components_to_update + ) -> dict[str, list[ReflectiveExample]]: from dspy.teleprompt.bootstrap_trace import FailedPrediction program = self.build_program(candidate) ret_d: dict[str, list[ReflectiveExample]] = {} + + # collect unique tools from traces for each tool-using predictor, serialize to candidate at end + tools_by_predictor: dict[str, dict[str, Tool]] = {} + for pred_name in components_to_update: + # Extract predictor name from component key + if pred_name.startswith(REACT_MODULE_PREFIX): + target_name = pred_name.removeprefix(f"{REACT_MODULE_PREFIX}:") + + elif pred_name.startswith(TOOL_MODULE_PREFIX): + target_name = pred_name.removeprefix(f"{TOOL_MODULE_PREFIX}:") + tools_by_predictor[pred_name] = {} + + # Helper function for extracting tools (only needed for tool modules) + def extract_tools_from_value(value, tools_dict): + """Extract Tool objects from value (handles single, list, dict).""" + if isinstance(value, Tool): + tools_dict[value.name] = value + elif isinstance(value, (list, tuple, set)): + for item in value: + extract_tools_from_value(item, tools_dict) + elif isinstance(value, dict): + for item in value.values(): + extract_tools_from_value(item, tools_dict) + + else: + target_name = pred_name + + # Find the predictor object module = None for name, m in program.named_predictors(): - if name == pred_name: + if name == target_name: module = m break - assert module is not None + assert module is not None, f"Predictor not found: {target_name}" + # Create reflective examples from traces items: list[ReflectiveExample] = [] for data in eval_batch.trajectories or []: trace = data["trace"] @@ -204,16 +378,31 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - if len(trace_instances) == 0: continue - selected = None - for t in trace_instances: - if isinstance(t[2], FailedPrediction): - selected = t - break + # Extract tools that are used in the trace instances + if pred_name.startswith(TOOL_MODULE_PREFIX): + for t in trace_instances: + trace_inputs = t[1] + for input_value in trace_inputs.values(): + extract_tools_from_value(input_value, tools_by_predictor[pred_name]) - if selected is None: - if isinstance(prediction, FailedPrediction): - continue - selected = self.rng.choice(trace_instances) + # TODO: Workaround for ReAct's multiple predictor calls with partial trajectories. + # Using last trace ensures full aggregated trajectory (same as extract predictor). + # After PR #8999 merges (https://github.com/stanfordnlp/dspy/pull/8999), test if we can + # remove this and use extract predictor trace directly like other modules traces. 
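+            # Traces accumulate in call order, so the last instance carries the complete aggregated trajectory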
+ if pred_name.startswith(REACT_MODULE_PREFIX): + selected = trace_instances[-1] + + else: + selected = None + for t in trace_instances: + if isinstance(t[2], FailedPrediction): + selected = t + break + + if selected is None: + if isinstance(prediction, FailedPrediction): + continue + selected = self.rng.choice(trace_instances) inputs = selected[1] outputs = selected[2] @@ -265,7 +454,8 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - d["Feedback"] = "Your output failed to parse. Follow this structure:\n" + structure_instruction # d['score'] = self.failure_score else: - feedback_fn = self.feedback_map[pred_name] + # Use actual predictor name for feedback lookup + feedback_fn = self.feedback_map[target_name] fb = feedback_fn( predictor_output=outputs, predictor_inputs=inputs, @@ -283,10 +473,26 @@ def make_reflective_dataset(self, candidate, eval_batch, components_to_update) - items.append(d) if len(items) == 0: - # raise Exception(f"No valid predictions found for module {module.signature}.") + logger.warning(f" No valid reflective examples found for {pred_name}") continue + ret_d[pred_name] = items + # Update candidate configs with extracted tools (after all traces processed) + for pred_name, tools_dict in tools_by_predictor.items(): + if not tools_dict: + continue + + config = json.loads(candidate[pred_name]) + config["tools"] = { + tool_name: { + "desc": tool.desc, + "args": tool.args, + } + for tool_name, tool in tools_dict.items() + } + candidate[pred_name] = json.dumps(config, indent=2) + if len(ret_d) == 0: raise Exception("No valid predictions found for any module.") diff --git a/dspy/teleprompt/gepa/instruction_proposal.py b/dspy/teleprompt/gepa/instruction_proposal.py index 23810b9a02..4207319ccb 100644 --- a/dspy/teleprompt/gepa/instruction_proposal.py +++ b/dspy/teleprompt/gepa/instruction_proposal.py @@ -1,10 +1,14 @@ +import json +import logging from typing import Any from gepa.core.adapter import ProposalFn import dspy from dspy.adapters.types.base_type import Type -from dspy.teleprompt.gepa.gepa_utils import ReflectiveExample +from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX, ReflectiveExample + +logger = logging.getLogger(__name__) class GenerateEnhancedMultimodalInstructionFromFeedback(dspy.Signature): @@ -310,3 +314,194 @@ def __call__( updated_components[component_name] = new_instruction return updated_components + +class GenerateImprovedToolModuleDescriptionsFromFeedback(dspy.Signature): + """Improve a tool-using module based on execution examples and feedback. + + These components are progressively optimized - refine what needs improvement. + Analyze the examples_with_feedback to identify successful patterns and failure causes. + Generate improved texts to help the module succeed on similar tasks. + Place improved texts at their appropriate level of abstraction and/or specificity. + """ + + current_predictor_instruction = dspy.InputField( + desc="Current instruction guiding the predictor" + ) + current_tools = dspy.InputField( + annotation=list[dspy.Tool], + desc="Available tools with their complete schemas" + ) + examples_with_feedback = dspy.InputField( + desc="Execution examples with feedback showing successes and failures" + ) + + improved_predictor_instruction: str | None = dspy.OutputField( + desc="Improved instruction for the predictor", + default=None + ) + + + + + +class ToolModuleProposer(ProposalFn): + """Proposer for optimizing tool-using module configurations. 
+ + Supports two types of modules: + - Tool modules (1 predictor): Optimizes predictor instruction and tool descriptions + - ReAct modules (2 predictors): Jointly optimizes react instruction, extract instruction, and tool descriptions + + Uses dynamic signature generation to create output fields for each tool and parameter, + enabling the reflection LM to optimize all components cohesively based on execution feedback. + + This joint optimization approach allows the LM to see how instructions and tool descriptions + work together, leading to more coherent improvements than optimizing each component separately. + """ + + def __call__( + self, + candidate: dict[str, str], + reflective_dataset: dict[str, list[ReflectiveExample]], + components_to_update: list[str], + ) -> dict[str, str]: + """Optimize tool-using module components. + + Args: + candidate: Current component name -> JSON config mapping + reflective_dataset: Component name -> list of reflective examples + components_to_update: List of tool-using module component names to update + + Returns: + dict: Mapping of component names to improved JSON configs + """ + + updated_components = {} + + for module_key in components_to_update: + if module_key not in candidate or module_key not in reflective_dataset: + logger.warning(f"Skipping {module_key}: not in candidate={module_key not in candidate}, not in reflective_dataset={module_key not in reflective_dataset}") + continue + + current_module_config = json.loads(candidate[module_key]) + + # Predictor keys: 1 for tool modules, 2 for ReAct modules (extra extract predictor) + predictor_keys = [k for k, v in current_module_config.items() if isinstance(v, str)] + primary_predictor_key = predictor_keys[0] + extract_predictor_key = predictor_keys[1] if module_key.startswith(REACT_MODULE_PREFIX) else None + + # Reconstruct Tool objects from JSON (func is placeholder since it can't be serialized) + current_tools_dict = current_module_config.get("tools", {}) + tools_list = [] + for tool_name, tool_info in current_tools_dict.items(): + tool = dspy.Tool( + func=lambda *args, **kwargs: None, # Placeholder - Tool requires Callable, but only schema is used + name=tool_name, + desc=tool_info.get("desc", ""), + ) + tool.args = tool_info.get("args", {}) + tools_list.append(tool) + + # Build dynamic signature with tool-specific output fields + signature = GenerateImprovedToolModuleDescriptionsFromFeedback + + for tool in tools_list: + tool_name = tool.name + tool_info = current_tools_dict[tool_name] + + signature = signature.append( + f"improved_tool_{tool_name}_desc", + dspy.OutputField( + desc=f"Concise description of tool '{tool_name}'", + default=None + ) + ) + + for arg_name in tool_info["args"].keys(): + signature = signature.append( + f"improved_tool_{tool_name}_arg_{arg_name}_desc", + dspy.OutputField( + desc=f"Concise description of tool '{tool_name}' parameter '{arg_name}'", + default=None + ) + ) + + kwargs = { + "current_predictor_instruction": current_module_config[primary_predictor_key], + "current_tools": tools_list, + "examples_with_feedback": self._format_examples(reflective_dataset[module_key]), + } + # If module has extract predictor, add extract fields + if extract_predictor_key is not None: + signature = signature.append( + "current_extract_instruction", + dspy.InputField(desc="Current instruction for extraction predictor") + ) + signature = signature.append( + "improved_extract_instruction", + dspy.OutputField(desc="Improved instruction for extraction", default=None) + ) + 
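            # Extract-specific fields only exist for ReAct modules; plain tool modules have a single predictor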
kwargs["current_extract_instruction"] = current_module_config[extract_predictor_key] + + propose_descriptions = dspy.Predict(signature) + + result = propose_descriptions(**kwargs) + + # Build improved config (reflection LM returns None to keep original, or new text) + improved_module_config = {} + + if result.improved_predictor_instruction is not None: + improved_module_config[primary_predictor_key] = result.improved_predictor_instruction + + if extract_predictor_key is not None and result.improved_extract_instruction is not None: + improved_module_config[extract_predictor_key] = result.improved_extract_instruction + + improved_module_config["tools"] = {} + for tool_name, tool_info in current_tools_dict.items(): + # Update tool description if LM proposed a change + improved_tool_desc = getattr(result, f"improved_tool_{tool_name}_desc", None) + if improved_tool_desc is not None: + tool_info["desc"] = improved_tool_desc + + # Update arg descriptions if LM proposed changes + for arg_name in tool_info["args"].keys(): + improved_tool_arg_desc = getattr(result, f"improved_tool_{tool_name}_arg_{arg_name}_desc", None) + if improved_tool_arg_desc is not None: + tool_info["args"][arg_name]["description"] = improved_tool_arg_desc + + improved_module_config["tools"][tool_name] = tool_info + + updated_components[module_key] = json.dumps(improved_module_config, indent=2) + + return updated_components + + def _format_examples(self, reflective_dataset: list[ReflectiveExample]) -> str: + """Format reflective examples using GEPA's markdown structure.""" + + def render_value(value, level=3): + if isinstance(value, dict): + s = "" + for key, val in value.items(): + s += f"{'#' * level} {key}\n" + s += render_value(val, min(level + 1, 6)) + if not value: + s += "\n" + return s + if isinstance(value, (list, tuple)): + s = "" + for index, item in enumerate(value): + s += f"{'#' * level} Item {index + 1}\n" + s += render_value(item, min(level + 1, 6)) + if not value: + s += "\n" + return s + return f"{str(value).strip()}\n\n" + + def convert_sample_to_markdown(sample, example_num): + s = f"# Example {example_num}\n" + for key, val in sample.items(): + s += f"## {key}\n" + s += render_value(val, level=3) + return s + + formatted_parts = [convert_sample_to_markdown(example, i + 1) for i, example in enumerate(reflective_dataset)] + return "\n\n".join(formatted_parts) diff --git a/tests/teleprompt/test_gepa_tool_optimization.py b/tests/teleprompt/test_gepa_tool_optimization.py new file mode 100644 index 0000000000..59a77f46ad --- /dev/null +++ b/tests/teleprompt/test_gepa_tool_optimization.py @@ -0,0 +1,496 @@ +"""Tests for GEPA's tool optimization (ReAct modules and custom tool modules). + +Tests the generic tool optimization that works with ANY module using dspy.Tool, +including dspy.ReAct and custom modules. + +Test categories: +1. Detection - Compile-time detection of tool-using modules +2. Application - build_program applies optimized instructions and tool descriptions + +DSPy ReAct Design Note: + DSPy's ReAct uses two predictors: + - react: reasoning/acting loop + - extract: structured output synthesis + + We optimize extract.predict as it's called once with the complete trajectory + and produces all output fields. 
+""" + +import json + +import gepa +from gepa import optimize as gepa_optimize + +import dspy +from dspy.teleprompt.gepa.gepa_utils import REACT_MODULE_PREFIX, TOOL_MODULE_PREFIX, DspyAdapter +from dspy.utils.dummies import DummyLM + + +# Test tool fixtures +def search(query: str) -> str: + """Test search tool.""" + return f"Search: {query}" + + +def calculate(expr: str) -> str: + """Test calculator tool.""" + return str(eval(expr)) + + +def analyze(data: str) -> str: + """Test analyzer tool.""" + return f"Analysis: {data}" + + +def setup_seed_candidate_capture(monkeypatch): + """Capture seed_candidate dict passed to gepa.optimize.""" + captured = {} + + def capture_optimize(seed_candidate, **kwargs): + captured.update(seed_candidate) + return gepa_optimize(seed_candidate=seed_candidate, **kwargs) + + monkeypatch.setattr(gepa, "optimize", capture_optimize) + return captured + + +def create_optimizer(task_responses, reflection_responses): + """Create GEPA optimizer with explicit LM responses. + + Args: + task_responses: List of dicts for task LM (e.g., [{"answer": "test"}]) + reflection_responses: List of dicts for reflection LM + + Returns: + tuple: (optimizer, trainset) + """ + task_lm = DummyLM(task_responses) + reflection_lm = DummyLM(reflection_responses) + + dspy.settings.configure(lm=task_lm) + + optimizer = dspy.GEPA( + metric=lambda example, pred, trace=None, pred_name=None, pred_trace=None: dspy.Prediction(score=0.5, feedback="ok"), + reflection_lm=reflection_lm, + max_metric_calls=2, + enable_tool_optimization=True, + ) + + trainset = [dspy.Example(query="test", answer="test").with_inputs("query")] + return optimizer, trainset + + +def get_predictor_name(program, predictor): + """Find predictor name by object identity in named_predictors(). 
+
+    Args:
+        program: DSPy module
+        predictor: Predictor object to find
+
+    Returns:
+        str: Predictor name (e.g., "pred", "agent.pred")
+    """
+    for name, pred in program.named_predictors():
+        if pred is predictor:
+            return name
+    raise ValueError(f"Predictor not found: {predictor}")
+
+
+def test_detect_single_tool(monkeypatch):
+    """Detect single tool in custom module."""
+    seed_candidate = setup_seed_candidate_capture(monkeypatch)
+
+    class AgentSignature(dspy.Signature):
+        """Answer questions using tools."""
+        query: str = dspy.InputField()
+        tool: dspy.Tool = dspy.InputField()
+        answer: str = dspy.OutputField()
+
+    class Agent(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.tool = dspy.Tool(search, name="search", desc="Search tool")
+            self.pred = dspy.Predict(AgentSignature)
+
+        def forward(self, query):
+            return self.pred(query=query, tool=self.tool)
+
+    program = Agent()
+    optimizer, trainset = create_optimizer(
+        task_responses=[{"answer": "test"}] * 20,  # Repeat for GEPA iterations
+        reflection_responses=[
+            {
+                "improved_predictor_instruction": "optimized",
+                "improved_tool_search_desc": "optimized search desc",
+                "improved_tool_search_arg_query_desc": "optimized query desc"
+            }
+        ] * 20  # Repeat for GEPA iterations
+    )
+    optimizer.compile(program, trainset=trainset, valset=trainset)
+
+    predictor_name = get_predictor_name(program, program.pred)
+    component_key = f"{TOOL_MODULE_PREFIX}:{predictor_name}"
+    assert component_key in seed_candidate
+
+    tool_config = json.loads(seed_candidate[component_key])
+    assert predictor_name in tool_config
+    assert "tools" in tool_config
+
+
+def test_detect_multiple_tools(monkeypatch):
+    """Detect multiple tools in custom module."""
+    seed_candidate = setup_seed_candidate_capture(monkeypatch)
+
+    class AgentSignature(dspy.Signature):
+        """Answer questions using multiple tools."""
+        query: str = dspy.InputField()
+        tools: list[dspy.Tool] = dspy.InputField()
+        answer: str = dspy.OutputField()
+
+    class Agent(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.tools = [
+                dspy.Tool(search, name="search", desc="Search tool"),
+                dspy.Tool(calculate, name="calc", desc="Calculator"),
+            ]
+            self.pred = dspy.Predict(AgentSignature)
+
+        def forward(self, query):
+            return self.pred(query=query, tools=self.tools)
+
+    program = Agent()
+    optimizer, trainset = create_optimizer(
+        task_responses=[{"answer": "test"}] * 20,  # Repeat for GEPA iterations
+        reflection_responses=[
+            {
+                "improved_predictor_instruction": "optimized",
+                "improved_tool_search_desc": "optimized search desc",
+                "improved_tool_search_arg_query_desc": "optimized query desc",
+                "improved_tool_calc_desc": "optimized calc desc",
+                "improved_tool_calc_arg_expr_desc": "optimized expr desc"
+            }
+        ] * 20  # Repeat for GEPA iterations
+    )
+    optimizer.compile(program, trainset=trainset, valset=trainset)
+
+    predictor_name = get_predictor_name(program, program.pred)
+    component_key = f"{TOOL_MODULE_PREFIX}:{predictor_name}"
+    assert component_key in seed_candidate
+
+    tool_config = json.loads(seed_candidate[component_key])
+    assert predictor_name in tool_config
+    assert "tools" in tool_config
+
+
+def test_skip_predictor_without_tools(monkeypatch):
+    """Skip predictors without Tool annotations."""
+    seed_candidate = setup_seed_candidate_capture(monkeypatch)
+
+    class PlainSignature(dspy.Signature):
+        """Answer questions."""
+        query: str = dspy.InputField()
+        answer: str = dspy.OutputField()
+
+    class PlainAgent(dspy.Module):
+        def __init__(self):
+            super().__init__()
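+            # No dspy.Tool annotations anywhere in PlainSignature, so GEPA's tool
+            # detection should leave this predictor on the default instruction path.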
+            self.pred = dspy.Predict(PlainSignature)
+
+        def forward(self, query):
+            return self.pred(query=query)
+
+    program = PlainAgent()
+    optimizer, trainset = create_optimizer(
+        task_responses=[{"answer": "test"}] * 20,  # Repeat for GEPA iterations
+        reflection_responses=[{"improved_instruction": "optimized"}] * 20  # Repeat for GEPA iterations
+    )
+    optimizer.compile(program, trainset=trainset, valset=trainset)
+
+    predictor_name = get_predictor_name(program, program.pred)
+    assert predictor_name in seed_candidate
+
+    # Should be plain string instruction, not JSON config
+    instruction = seed_candidate[predictor_name]
+    assert isinstance(instruction, str)
+
+
+def test_apply_optimized_tool_descriptions():
+    """Apply optimized tool descriptions via build_program."""
+
+    class AgentSignature(dspy.Signature):
+        """Answer using tools."""
+        query: str = dspy.InputField()
+        tool: dspy.Tool = dspy.InputField()
+        answer: str = dspy.OutputField()
+
+    class Agent(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.tool = dspy.Tool(search, name="search", desc="Original description")
+            self.pred = dspy.Predict(AgentSignature)
+
+        def forward(self, query):
+            return self.pred(query=query, tool=self.tool)
+
+    program = Agent()
+    predictor_name = get_predictor_name(program, program.pred)
+    component_key = f"{TOOL_MODULE_PREFIX}:{predictor_name}"
+
+    optimized_candidate = {
+        component_key: json.dumps({
+            predictor_name: "OPTIMIZED: Answer using tools",
+            "tools": {
+                "search": {
+                    "desc": "OPTIMIZED: Search description",
+                    "args": {"query": {"type": "string", "description": "Search query"}},
+                }
+            }
+        })
+    }
+
+    # Apply optimizations
+    adapter = DspyAdapter(
+        student_module=program,
+        metric_fn=lambda example, pred, trace=None: 0.5,
+        feedback_map={},
+        enable_tool_optimization=True,
+    )
+    rebuilt = adapter.build_program(optimized_candidate)
+
+    assert rebuilt.pred.signature.instructions == "OPTIMIZED: Answer using tools"
+    assert rebuilt.tool.desc == "OPTIMIZED: Search description"
+    assert rebuilt.tool.args["query"]["description"] == "Search query"
+
+    # Original unchanged
+    assert program.pred.signature.instructions != "OPTIMIZED: Answer using tools"
+    assert program.tool.desc == "Original description"
+
+
+def test_detect_react_module(monkeypatch):
+    """Detect ReAct module with tools."""
+    seed_candidate = setup_seed_candidate_capture(monkeypatch)
+
+    program = dspy.ReAct("question -> answer", tools=[search])
+    optimizer, trainset = create_optimizer(
+        task_responses=[
+            {"next_thought": "I should search", "next_tool_name": "search", "next_tool_args": {"query": "test"}},
+            {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}},
+            {"reasoning": "Based on search", "answer": "test"},
+        ] * 20,  # Repeat for GEPA iterations
+        reflection_responses=[
+            {
+                "improved_predictor_instruction": "optimized react",
+                "improved_extract_instruction": "optimized extract",
+                "improved_tool_search_desc": "optimized search desc",
+                "improved_tool_search_arg_query_desc": "optimized query desc"
+            }
+        ] * 20  # Repeat for GEPA iterations
+    )
+    optimizer.compile(program, trainset=trainset, valset=trainset)
+
+    # Verify detection - use extract.predict as primary (for tracing)
+    extract_name = get_predictor_name(program, program.extract.predict)
+    component_key = f"{REACT_MODULE_PREFIX}:{extract_name}"
+    assert component_key in seed_candidate
+
+    tool_config = json.loads(seed_candidate[component_key])
+    assert "tools" in tool_config
+
+
+def test_detect_multiple_react_modules(monkeypatch):
"""Detect multiple ReAct modules in workflow.""" + seed_candidate = setup_seed_candidate_capture(monkeypatch) + + class Workflow(dspy.Module): + def __init__(self): + super().__init__() + self.searcher = dspy.ReAct("query -> results", tools=[search]) + self.analyzer = dspy.ReAct("data -> analysis", tools=[analyze]) + + def forward(self, query): + results = self.searcher(query=query) + return self.analyzer(data=results.results) + + program = Workflow() + optimizer, trainset = create_optimizer( + task_responses=[ + {"next_thought": "Searching", "next_tool_name": "search", "next_tool_args": {"query": "test"}}, + {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}}, + {"reasoning": "Found results", "results": "data"}, + {"next_thought": "Analyzing", "next_tool_name": "analyze", "next_tool_args": {"data": "test"}}, + {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}}, + {"reasoning": "Analyzed", "analysis": "result"}, + ] * 20, # Repeat for GEPA iterations + reflection_responses=[ + { + "improved_predictor_instruction": "opt react search", + "improved_extract_instruction": "opt extract search", + "improved_tool_search_desc": "opt search desc", + "improved_tool_search_arg_query_desc": "opt query desc" + }, + { + "improved_predictor_instruction": "opt react analyze", + "improved_extract_instruction": "opt extract analyze", + "improved_tool_analyze_desc": "opt analyze desc", + "improved_tool_analyze_arg_data_desc": "opt data desc" + } + ] * 20 # Repeat for GEPA iterations + ) + optimizer.compile(program, trainset=trainset, valset=trainset) + + # Verify both detected - use extract.predict as primary (for tracing) + searcher_name = get_predictor_name(program, program.searcher.extract.predict) + analyzer_name = get_predictor_name(program, program.analyzer.extract.predict) + + searcher_key = f"{REACT_MODULE_PREFIX}:{searcher_name}" + analyzer_key = f"{REACT_MODULE_PREFIX}:{analyzer_name}" + + assert searcher_key in seed_candidate + assert analyzer_key in seed_candidate + + +def test_apply_optimized_react_descriptions(): + """Apply optimized tool descriptions to ReAct modules.""" + + program = dspy.ReAct("question -> answer", tools=[search]) + + # Create mock optimized candidate - use extract.predict as primary (for tracing) + react_name = get_predictor_name(program, program.react) + extract_predict_name = get_predictor_name(program, program.extract.predict) + + component_key = f"{REACT_MODULE_PREFIX}:{extract_predict_name}" + + optimized_candidate = { + component_key: json.dumps({ + react_name: "OPTIMIZED: React instruction", + extract_predict_name: "OPTIMIZED: Extract instruction", + "tools": { + "search": { + "desc": "OPTIMIZED: Search tool", + "args": {"query": {"type": "string"}}, + } + } + }) + } + + # Apply optimizations + adapter = DspyAdapter( + student_module=program, + metric_fn=lambda example, pred, trace=None: 0.5, + feedback_map={}, + enable_tool_optimization=True, + ) + rebuilt = adapter.build_program(optimized_candidate) + + # Verify instructions updated + assert rebuilt.react.signature.instructions == "OPTIMIZED: React instruction" + assert rebuilt.extract.predict.signature.instructions == "OPTIMIZED: Extract instruction" + + # Verify tool updated + assert rebuilt.tools["search"].desc == "OPTIMIZED: Search tool" + + +def test_detect_nested_react_modules(monkeypatch): + """Detect ReAct modules in nested program structure.""" + seed_candidate = setup_seed_candidate_capture(monkeypatch) + + class Worker(dspy.Module): + def __init__(self): + 
+            super().__init__()
+            self.react = dspy.ReAct("task -> result", tools=[analyze])
+
+        def forward(self, task):
+            return self.react(task=task)
+
+    class Orchestrator(dspy.Module):
+        def __init__(self):
+            super().__init__()
+            self.searcher = dspy.ReAct("query -> results", tools=[search])
+            self.worker = Worker()
+
+        def forward(self, query):
+            results = self.searcher(query=query)
+            return self.worker(task=results.results)
+
+    program = Orchestrator()
+    optimizer, trainset = create_optimizer(
+        task_responses=[
+            {"next_thought": "Search", "next_tool_name": "search", "next_tool_args": {"query": "test"}},
+            {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}},
+            {"reasoning": "Found", "results": "data"},
+            {"next_thought": "Analyze", "next_tool_name": "analyze", "next_tool_args": {"data": "test"}},
+            {"next_thought": "Done", "next_tool_name": "finish", "next_tool_args": {}},
+            {"reasoning": "Analyzed", "result": "final"},
+        ] * 20,  # Repeat for GEPA iterations
+        reflection_responses=[
+            {
+                "improved_predictor_instruction": "opt react search",
+                "improved_extract_instruction": "opt extract search",
+                "improved_tool_search_desc": "opt search desc",
+                "improved_tool_search_arg_query_desc": "opt query desc"
+            },
+            {
+                "improved_predictor_instruction": "opt react analyze",
+                "improved_extract_instruction": "opt extract analyze",
+                "improved_tool_analyze_desc": "opt analyze desc",
+                "improved_tool_analyze_arg_data_desc": "opt data desc"
+            }
+        ] * 20  # Repeat for GEPA iterations
+    )
+    optimizer.compile(program, trainset=trainset, valset=trainset)
+
+    # Verify nested modules detected with full paths - use extract.predict as primary (for tracing)
+    searcher_name = get_predictor_name(program, program.searcher.extract.predict)
+    worker_extract_name = get_predictor_name(program, program.worker.react.extract.predict)
+
+    searcher_key = f"{REACT_MODULE_PREFIX}:{searcher_name}"
+    worker_key = f"{REACT_MODULE_PREFIX}:{worker_extract_name}"
+
+    assert searcher_key in seed_candidate
+    assert worker_key in seed_candidate
+
+    # Verify full paths preserved (not truncated)
+    assert "searcher" in searcher_name  # Contains parent path
+    assert "worker" in worker_extract_name  # Contains nested path
+
+
+def test_selective_optimization_with_none_returns():
+    """Verify selective optimization when reflection LM returns None for some fields."""
+
+    program = dspy.ReAct("question -> answer", tools=[search, calculate])
+
+    react_name = get_predictor_name(program, program.react)
+    extract_name = get_predictor_name(program, program.extract.predict)
+    component_key = f"{REACT_MODULE_PREFIX}:{extract_name}"
+
+    # Mock selective optimization (only react instruction and search tool updated)
+    optimized_candidate = {
+        component_key: json.dumps({
+            react_name: "OPTIMIZED: React instruction",
+            extract_name: program.extract.predict.signature.instructions,
+            "tools": {
+                "search": {
+                    "desc": "OPTIMIZED: Search tool",
+                    "args": {"query": {"type": "string"}},
+                }
+            }
+        })
+    }
+
+    adapter = DspyAdapter(
+        student_module=program,
+        metric_fn=lambda example, pred, trace=None: 0.5,
+        feedback_map={},
+        enable_tool_optimization=True,
+    )
+    rebuilt = adapter.build_program(optimized_candidate)
+
+    # Verify selective updates
+    assert rebuilt.react.signature.instructions == "OPTIMIZED: React instruction"
+    assert rebuilt.extract.predict.signature.instructions == program.extract.predict.signature.instructions
+    assert rebuilt.tools["search"].desc == "OPTIMIZED: Search tool"
+
+    # Original unchanged (calculate not in optimized candidate)
+    assert rebuilt.tools["calculate"].desc == program.tools["calculate"].desc
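+
+
+# Minimal end-to-end usage sketch (illustrative only; not executed as a test).
+# `react_metric` is a hypothetical metric returning dspy.Prediction(score=..., feedback=...):
+#
+#     program = dspy.ReAct("question -> answer", tools=[search])
+#     optimizer = dspy.GEPA(
+#         metric=react_metric,
+#         reflection_lm=dspy.LM("openai/gpt-4o"),
+#         max_metric_calls=150,
+#         enable_tool_optimization=True,
+#     )
+#     optimized = optimizer.compile(program, trainset=trainset, valset=valset)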