|
13 | 13 | import uuid |
14 | 14 | from typing import TYPE_CHECKING, Any, AsyncGenerator |
15 | 15 |
|
| 16 | +from ..experimental.hooks import AfterToolInvocationEvent, BeforeToolInvocationEvent |
| 17 | +from ..experimental.hooks.registry import get_registry |
16 | 18 | from ..telemetry.metrics import Trace |
17 | 19 | from ..telemetry.tracer import get_tracer |
18 | 20 | from ..tools.executor import run_tools, validate_and_prepare_tools |
@@ -271,46 +273,97 @@ def run_tool(agent: "Agent", tool_use: ToolUse, kwargs: dict[str, Any]) -> ToolG |
271 | 273 | The final tool result or an error response if the tool fails or is not found. |
272 | 274 | """ |
273 | 275 | logger.debug("tool_use=<%s> | streaming", tool_use) |
274 | | - tool_use_id = tool_use["toolUseId"] |
275 | 276 | tool_name = tool_use["name"] |
276 | 277 |
|
277 | 278 | # Get the tool info |
278 | 279 | tool_info = agent.tool_registry.dynamic_tools.get(tool_name) |
279 | 280 | tool_func = tool_info if tool_info is not None else agent.tool_registry.registry.get(tool_name) |
280 | 281 |
|
| 282 | + # Add standard arguments to kwargs for Python tools |
| 283 | + kwargs.update( |
| 284 | + { |
| 285 | + "model": agent.model, |
| 286 | + "system_prompt": agent.system_prompt, |
| 287 | + "messages": agent.messages, |
| 288 | + "tool_config": agent.tool_config, |
| 289 | + } |
| 290 | + ) |
| 291 | + |
| 292 | + before_event = get_registry(agent).invoke_callbacks( |
| 293 | + BeforeToolInvocationEvent( |
| 294 | + agent=agent, |
| 295 | + selected_tool=tool_func, |
| 296 | + tool_use=tool_use, |
| 297 | + kwargs=kwargs, |
| 298 | + ) |
| 299 | + ) |
| 300 | + |
281 | 301 | try: |
| 302 | + selected_tool = before_event.selected_tool |
| 303 | + tool_use = before_event.tool_use |
| 304 | + |
282 | 305 | # Check if tool exists |
283 | | - if not tool_func: |
284 | | - logger.error( |
285 | | - "tool_name=<%s>, available_tools=<%s> | tool not found in registry", |
286 | | - tool_name, |
287 | | - list(agent.tool_registry.registry.keys()), |
288 | | - ) |
289 | | - return { |
290 | | - "toolUseId": tool_use_id, |
| 306 | + if not selected_tool: |
| 307 | + if tool_func == selected_tool: |
| 308 | + logger.error( |
| 309 | + "tool_name=<%s>, available_tools=<%s> | tool not found in registry", |
| 310 | + tool_name, |
| 311 | + list(agent.tool_registry.registry.keys()), |
| 312 | + ) |
| 313 | + else: |
| 314 | + logger.debug( |
| 315 | + "tool_name=<%s>, tool_use_id=<%s> | a hook resulted in a non-existing tool call", |
| 316 | + tool_name, |
| 317 | + str(tool_use.get("toolUseId")), |
| 318 | + ) |
| 319 | + |
| 320 | + result: ToolResult = { |
| 321 | + "toolUseId": str(tool_use.get("toolUseId")), |
291 | 322 | "status": "error", |
292 | 323 | "content": [{"text": f"Unknown tool: {tool_name}"}], |
293 | 324 | } |
294 | | - # Add standard arguments to kwargs for Python tools |
295 | | - kwargs.update( |
296 | | - { |
297 | | - "model": agent.model, |
298 | | - "system_prompt": agent.system_prompt, |
299 | | - "messages": agent.messages, |
300 | | - "tool_config": agent.tool_config, |
301 | | - } |
302 | | - ) |
| 325 | + # for every Before event call, we need to have an AfterEvent call |
| 326 | + after_event = get_registry(agent).invoke_callbacks( |
| 327 | + AfterToolInvocationEvent( |
| 328 | + agent=agent, |
| 329 | + selected_tool=selected_tool, |
| 330 | + tool_use=tool_use, |
| 331 | + kwargs=kwargs, |
| 332 | + result=result, |
| 333 | + ) |
| 334 | + ) |
| 335 | + return after_event.result |
303 | 336 |
|
304 | | - result = yield from tool_func.stream(tool_use, **kwargs) |
305 | | - return result |
| 337 | + result = yield from selected_tool.stream(tool_use, **kwargs) |
| 338 | + after_event = get_registry(agent).invoke_callbacks( |
| 339 | + AfterToolInvocationEvent( |
| 340 | + agent=agent, |
| 341 | + selected_tool=selected_tool, |
| 342 | + tool_use=tool_use, |
| 343 | + kwargs=kwargs, |
| 344 | + result=result, |
| 345 | + ) |
| 346 | + ) |
| 347 | + return after_event.result |
306 | 348 |
|
307 | 349 | except Exception as e: |
308 | 350 | logger.exception("tool_name=<%s> | failed to process tool", tool_name) |
309 | | - return { |
310 | | - "toolUseId": tool_use_id, |
| 351 | + error_result: ToolResult = { |
| 352 | + "toolUseId": str(tool_use.get("toolUseId")), |
311 | 353 | "status": "error", |
312 | 354 | "content": [{"text": f"Error: {str(e)}"}], |
313 | 355 | } |
| 356 | + after_event = get_registry(agent).invoke_callbacks( |
| 357 | + AfterToolInvocationEvent( |
| 358 | + agent=agent, |
| 359 | + selected_tool=selected_tool, |
| 360 | + tool_use=tool_use, |
| 361 | + kwargs=kwargs, |
| 362 | + result=error_result, |
| 363 | + exception=e, |
| 364 | + ) |
| 365 | + ) |
| 366 | + return after_event.result |
314 | 367 |
|
315 | 368 |
|
316 | 369 | async def _handle_tool_execution( |
|
0 commit comments