|
12 | 12 | from dspy.adapters.types.base_type import Type |
13 | 13 | from dspy.adapters.types.tool import Tool |
14 | 14 | from dspy.evaluate import Evaluate |
15 | | -from dspy.predict.react import ReAct |
16 | 15 | from dspy.primitives import Example, Prediction |
17 | 16 | from dspy.teleprompt.bootstrap_trace import TraceData |
18 | 17 |
|
@@ -106,9 +105,6 @@ def __init__( |
106 | 105 |
|
107 | 106 | self.propose_new_texts = self._build_propose_new_texts() |
108 | 107 |
|
109 | | - # Cache predictor names/signatures |
110 | | - self.named_predictors = list(self.student.named_predictors()) |
111 | | - |
112 | 108 | def _build_propose_new_texts(self): |
113 | 109 | """Build proposal function that routes components to appropriate proposers.""" |
114 | 110 | # Init instruction proposer (custom or default) |
@@ -202,79 +198,79 @@ def propose_component_texts( |
202 | 198 | def build_program(self, candidate: dict[str, str]): |
203 | 199 | new_prog = self.student.deepcopy() |
204 | 200 |
|
205 | | - # Apply regular predictor instructions |
206 | | - for name, pred in new_prog.named_predictors(): |
207 | | - if name in candidate: |
208 | | - pred.signature = pred.signature.with_instructions(candidate[name]) |
| 201 | + # Start with plain string instructions from candidate |
| 202 | + improved_predictors = { |
| 203 | + k: v for k, v in candidate.items() |
| 204 | + if not k.startswith((REACT_MODULE_PREFIX, TOOL_MODULE_PREFIX)) |
| 205 | + } |
209 | 206 |
|
210 | | - # Apply ReAct module updates (JSON configs for ReAct modules: react, extract, tools) |
| 207 | + improved_tools = {} |
211 | 208 | if self.enable_tool_optimization: |
212 | | - for _, module in new_prog.named_sub_modules(): |
213 | | - # Only process ReAct modules |
214 | | - if not isinstance(module, ReAct): |
| 209 | + for key, value in candidate.items(): |
| 210 | + if not key.startswith((REACT_MODULE_PREFIX, TOOL_MODULE_PREFIX)): |
215 | 211 | continue |
216 | 212 |
|
217 | | - # Find module key using extract predictor name |
218 | | - extract_predictor = module.extract.predict |
219 | | - module_key = None |
| 213 | + config = json.loads(value) |
| 214 | + |
| 215 | + # Parse module configs and override predictor instructions |
| 216 | + for pred_name, instruction in config.items(): |
| 217 | + if isinstance(instruction, str): |
| 218 | + improved_predictors[pred_name] = instruction |
220 | 219 |
|
221 | | - for name, pred in new_prog.named_predictors(): |
222 | | - if pred is extract_predictor: |
223 | | - module_key = f"{REACT_MODULE_PREFIX}:{name}" |
224 | | - break |
| 220 | + if "tools" in config: |
| 221 | + improved_tools.update(config["tools"]) |
225 | 222 |
|
226 | | - # Check if this module was optimized |
227 | | - if module_key is None or module_key not in candidate: |
| 223 | + # Update predictor instructions |
| 224 | + for name, pred in new_prog.named_predictors(): |
| 225 | + if name in improved_predictors: |
| 226 | + pred.signature = pred.signature.with_instructions(improved_predictors[name]) |
| 227 | + |
| 228 | + # Update tool descriptions |
| 229 | + if improved_tools: |
| 230 | + def collect_tools(obj): |
| 231 | + all_tools = {} |
| 232 | + visited = set() |
| 233 | + |
| 234 | + def traverse(o): |
| 235 | + if id(o) in visited or not hasattr(o, "__dict__"): |
| 236 | + return |
| 237 | + visited.add(id(o)) |
| 238 | + |
| 239 | + for attr_val in o.__dict__.values(): |
| 240 | + if isinstance(attr_val, Tool): |
| 241 | + all_tools[attr_val.name] = attr_val |
| 242 | + elif isinstance(attr_val, list): |
| 243 | + for item in attr_val: |
| 244 | + if isinstance(item, Tool): |
| 245 | + all_tools[item.name] = item |
| 246 | + elif isinstance(attr_val, dict): |
| 247 | + for item in attr_val.values(): |
| 248 | + if isinstance(item, Tool): |
| 249 | + all_tools[item.name] = item |
| 250 | + elif isinstance(attr_val, dspy.Module): |
| 251 | + traverse(attr_val) |
| 252 | + |
| 253 | + traverse(obj) |
| 254 | + return all_tools |
| 255 | + |
| 256 | + all_tools = collect_tools(new_prog) |
| 257 | + |
| 258 | + for tool_name, tool_config in improved_tools.items(): |
| 259 | + if tool_name not in all_tools: |
228 | 260 | continue |
229 | 261 |
|
230 | | - # Deserialize JSON containing optimized module configuration |
231 | | - try: |
232 | | - module_config = json.loads(candidate[module_key]) |
233 | | - logger.debug(f"Applying optimized module config to {module_key}") |
234 | | - |
235 | | - # Find predictor names for this module |
236 | | - react_pred_name = None |
237 | | - extract_pred_name = None |
238 | | - for pred_name, pred in new_prog.named_predictors(): |
239 | | - if pred is module.react: |
240 | | - react_pred_name = pred_name |
241 | | - elif pred is module.extract.predict: |
242 | | - extract_pred_name = pred_name |
243 | | - |
244 | | - # Apply react instruction using actual predictor name as key |
245 | | - if react_pred_name and react_pred_name in module_config: |
246 | | - module.react.signature = module.react.signature.with_instructions(module_config[react_pred_name]) |
247 | | - logger.debug(" Updated react instruction") |
248 | | - |
249 | | - # Apply extract instruction using actual predictor name as key |
250 | | - if extract_pred_name and extract_pred_name in module_config: |
251 | | - module.extract.predict.signature = module.extract.predict.signature.with_instructions(module_config[extract_pred_name]) |
252 | | - logger.debug(" Updated extract instruction") |
253 | | - |
254 | | - # Apply tool descriptions |
255 | | - if "tools" in module_config: |
256 | | - for tool_name, tool_config in module_config["tools"].items(): |
257 | | - tool = module.tools[tool_name] |
258 | | - |
259 | | - # Update tool description |
260 | | - if tool_config.get("desc"): |
261 | | - tool.desc = tool_config["desc"] |
262 | | - logger.debug(f" Updated tool '{tool_name}' description") |
263 | | - |
264 | | - # Update tool arg descriptions |
265 | | - arg_desc = tool_config.get("arg_desc") |
266 | | - if arg_desc: |
267 | | - tool.arg_desc = tool.arg_desc or {} |
268 | | - tool.arg_desc.update(arg_desc) |
269 | | - # Propagate to tool.args |
270 | | - for arg_name, description in arg_desc.items(): |
271 | | - if arg_name in tool.args: |
272 | | - tool.args[arg_name]["description"] = description |
273 | | - logger.debug(f" Updated tool '{tool_name}' arg descriptions: {list(arg_desc.keys())}") |
274 | | - |
275 | | - except json.JSONDecodeError as e: |
276 | | - logger.error(f"Failed to parse JSON config for {module_key}: {e}") |
277 | | - raise |
| 262 | + tool = all_tools[tool_name] |
| 263 | + |
| 264 | + if tool_config.get("desc"): |
| 265 | + tool.desc = tool_config["desc"] |
| 266 | + |
| 267 | + arg_desc = tool_config.get("arg_desc") |
| 268 | + if arg_desc: |
| 269 | + tool.arg_desc = tool.arg_desc or {} |
| 270 | + tool.arg_desc.update(arg_desc) |
| 271 | + for arg_name, description in arg_desc.items(): |
| 272 | + if arg_name in tool.args: |
| 273 | + tool.args[arg_name]["description"] = description |
278 | 274 |
|
279 | 275 | return new_prog |
280 | 276 |
|
|
0 commit comments