diff --git a/docs/deferred-tools.md b/docs/deferred-tools.md index b5dcb38685..4bc487b858 100644 --- a/docs/deferred-tools.md +++ b/docs/deferred-tools.md @@ -47,7 +47,7 @@ PROTECTED_FILES = {'.env'} @agent.tool def update_file(ctx: RunContext, path: str, content: str) -> str: if path in PROTECTED_FILES and not ctx.tool_call_approved: - raise ApprovalRequired + raise ApprovalRequired(metadata={'reason': 'protected'}) # (1)! return f'File {path!r} updated: {content!r}' @@ -77,6 +77,7 @@ DeferredToolRequests( tool_call_id='delete_file', ), ], + metadata={'update_file_dotenv': {'reason': 'protected'}}, ) """ @@ -175,6 +176,8 @@ print(result.all_messages()) """ ``` +1. The `metadata` parameter can attach arbitrary context to deferred tool calls, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`. + _(This example is complete, it can be run "as is")_ ## External Tool Execution @@ -209,13 +212,13 @@ from pydantic_ai import ( @dataclass class TaskResult: - tool_call_id: str + task_id: str result: Any -async def calculate_answer_task(tool_call_id: str, question: str) -> TaskResult: +async def calculate_answer_task(task_id: str, question: str) -> TaskResult: await asyncio.sleep(1) - return TaskResult(tool_call_id=tool_call_id, result=42) + return TaskResult(task_id=task_id, result=42) agent = Agent('openai:gpt-5', output_type=[str, DeferredToolRequests]) @@ -225,12 +228,11 @@ tasks: list[asyncio.Task[TaskResult]] = [] @agent.tool async def calculate_answer(ctx: RunContext, question: str) -> str: - assert ctx.tool_call_id is not None - - task = asyncio.create_task(calculate_answer_task(ctx.tool_call_id, question)) # (1)! + task_id = f'task_{len(tasks)}' # (1)! + task = asyncio.create_task(calculate_answer_task(task_id, question)) tasks.append(task) - raise CallDeferred + raise CallDeferred(metadata={'task_id': task_id}) # (2)! 
async def main(): @@ -252,17 +254,19 @@ async def main(): ) ], approvals=[], + metadata={'pyd_ai_tool_call_id': {'task_id': 'task_0'}}, ) """ - done, _ = await asyncio.wait(tasks) # (2)! + done, _ = await asyncio.wait(tasks) # (3)! task_results = [task.result() for task in done] - task_results_by_tool_call_id = {result.tool_call_id: result.result for result in task_results} + task_results_by_task_id = {result.task_id: result.result for result in task_results} results = DeferredToolResults() for call in requests.calls: try: - result = task_results_by_tool_call_id[call.tool_call_id] + task_id = requests.metadata[call.tool_call_id]['task_id'] + result = task_results_by_task_id[task_id] except KeyError: result = ModelRetry('No result for this tool call was found.') @@ -324,8 +328,9 @@ async def main(): """ ``` -1. In reality, you'd likely use Celery or a similar task queue to run the task in the background. -2. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete. +1. Generate a task ID that can be tracked independently of the tool call ID. +2. The `metadata` parameter passes the `task_id` so it can be matched with results later, accessible in `DeferredToolRequests.metadata` keyed by `tool_call_id`. +3. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete. 
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ diff --git a/docs/toolsets.md b/docs/toolsets.md index 8d970b8e31..1b041b3baa 100644 --- a/docs/toolsets.md +++ b/docs/toolsets.md @@ -362,6 +362,7 @@ DeferredToolRequests( tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit', ), ], + metadata={}, ) """ diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 91cda373a5..b1a0dd1350 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -888,6 +888,7 @@ async def process_tool_calls( # noqa: C901 calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results] deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list) + deferred_metadata: dict[str, dict[str, Any]] = {} if calls_to_run: async for event in _call_tools( @@ -899,6 +900,7 @@ async def process_tool_calls( # noqa: C901 usage_limits=ctx.deps.usage_limits, output_parts=output_parts, output_deferred_calls=deferred_calls, + output_deferred_metadata=deferred_metadata, ): yield event @@ -932,6 +934,7 @@ async def process_tool_calls( # noqa: C901 deferred_tool_requests = _output.DeferredToolRequests( calls=deferred_calls['external'], approvals=deferred_calls['unapproved'], + metadata=deferred_metadata, ) final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None) @@ -949,10 +952,12 @@ async def _call_tools( usage_limits: _usage.UsageLimits, output_parts: list[_messages.ModelRequestPart], output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]], + output_deferred_metadata: dict[str, dict[str, Any]], ) -> AsyncIterator[_messages.HandleResponseEvent]: tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {} user_parts_by_index: dict[int, _messages.UserPromptPart] = {} deferred_calls_by_index: dict[int, 
def _populate_deferred_calls(
    tool_calls: list[_messages.ToolCallPart],
    deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
    deferred_metadata_by_index: dict[int, dict[str, Any] | None],
    output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
    output_deferred_metadata: dict[str, dict[str, Any]],
) -> None:
    """Copy deferred tool calls, and any metadata they carry, into the output containers.

    Deferred calls are visited in ascending index order. Each one is appended to the
    `output_deferred_calls` list for its kind; when the metadata recorded at the same
    index is not `None`, it is stored in `output_deferred_metadata` under the call's
    `tool_call_id`.
    """
    # Indexes are unique dict keys, so sorting the items orders them by index alone.
    for index, kind in sorted(deferred_calls_by_index.items()):
        deferred_call = tool_calls[index]
        output_deferred_calls[kind].append(deferred_call)

        call_metadata = deferred_metadata_by_index[index]
        if call_metadata is None:
            # Nothing was attached to this call, so it gets no entry in the metadata mapping.
            continue
        output_deferred_metadata[deferred_call.tool_call_id] = call_metadata
a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py index c52a34520e..850f001a46 100644 --- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py +++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py @@ -27,11 +27,13 @@ class CallToolParams: @dataclass class _ApprovalRequired: + metadata: dict[str, Any] | None = None kind: Literal['approval_required'] = 'approval_required' @dataclass class _CallDeferred: + metadata: dict[str, Any] | None = None kind: Literal['call_deferred'] = 'call_deferred' @@ -75,10 +77,10 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult: try: result = await coro return _ToolReturn(result=result) - except ApprovalRequired: - return _ApprovalRequired() - except CallDeferred: - return _CallDeferred() + except ApprovalRequired as e: + return _ApprovalRequired(metadata=e.metadata) + except CallDeferred as e: + return _CallDeferred(metadata=e.metadata) except ModelRetry as e: return _ModelRetry(message=e.message) @@ -86,9 +88,9 @@ def _unwrap_call_tool_result(self, result: CallToolResult) -> Any: if isinstance(result, _ToolReturn): return result.result elif isinstance(result, _ApprovalRequired): - raise ApprovalRequired() + raise ApprovalRequired(metadata=result.metadata) elif isinstance(result, _CallDeferred): - raise CallDeferred() + raise CallDeferred(metadata=result.metadata) elif isinstance(result, _ModelRetry): raise ModelRetry(result.message) else: diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index da7ed89891..afeb8c524f 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -70,18 +70,30 @@ class CallDeferred(Exception): """Exception to raise when a tool call should be deferred. See [tools docs](../deferred-tools.md#deferred-tools) for more information. 
class ApprovalRequired(Exception):
    """Exception to raise when a tool call requires human-in-the-loop approval.

    See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.

    Args:
        metadata: Optional dictionary of context to attach to the tool call awaiting approval.
            When provided, it is exposed in `DeferredToolRequests.metadata` under the call's `tool_call_id`.
    """

    def __init__(self, metadata: dict[str, Any] | None = None) -> None:
        # The exception carries no message of its own; it only transports the optional metadata.
        super().__init__()
        self.metadata = metadata
sys.stdout.write(str(debug.format(messages, info))) + raise RuntimeError(f'Unexpected message type: {type(m).__name__}') + async def stream_model_logic( # noqa C901 messages: list[ModelMessage], info: AgentInfo diff --git a/tests/test_streaming.py b/tests/test_streaming.py index 230a19501a..0c6a46f3c0 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1712,9 +1712,7 @@ def my_tool(x: int) -> int: [DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])] ) assert await result.get_output() == snapshot( - DeferredToolRequests( - calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], - ) + DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())]) ) responses = [c async for c, _is_last in result.stream_responses(debounce_by=None)] assert responses == snapshot( @@ -1757,9 +1755,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int: messages = result.all_messages() output = await result.get_output() assert output == snapshot( - DeferredToolRequests( - approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())], - ) + DeferredToolRequests(approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())]) ) assert result.is_complete diff --git a/tests/test_tools.py b/tests/test_tools.py index 3b30056c3a..f65105b4e6 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -1318,9 +1318,7 @@ def my_tool(x: int) -> int: result = agent.run_sync('Hello') assert result.output == snapshot( - DeferredToolRequests( - calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], - ) + DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())]) ) @@ -1398,6 +1396,184 @@ def my_tool(ctx: RunContext[None], x: int) -> int: assert result.output == snapshot('Done!') +def test_call_deferred_with_metadata(): + """Test that CallDeferred exception can 
def test_approval_required_with_metadata():
    """Metadata passed to `ApprovalRequired` is surfaced on the resulting `DeferredToolRequests`."""

    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        # Any request after the first one ends the run with plain text.
        if len(messages) != 1:
            return ModelResponse(parts=[TextPart('Done!')])
        # The first request asks for the tool under a fixed, known tool call ID.
        return ModelResponse(parts=[ToolCallPart('my_tool', {'x': 1}, tool_call_id='my_tool')])

    agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])

    @agent.tool
    def my_tool(ctx: RunContext[None], x: int) -> int:
        if ctx.tool_call_approved:
            return x * 42
        raise ApprovalRequired(
            metadata={
                'reason': 'High compute cost',
                'estimated_time': '5 minutes',
                'cost_usd': 100.0,
            }
        )

    first_run = agent.run_sync('Hello')
    assert first_run.output == snapshot(
        DeferredToolRequests(
            approvals=[ToolCallPart(tool_name='my_tool', args={'x': 1}, tool_call_id=IsStr())],
            metadata={'my_tool': {'reason': 'High compute cost', 'estimated_time': '5 minutes', 'cost_usd': 100.0}},
        )
    )

    # Resume the run from the recorded history, this time approving the pending call.
    approvals = DeferredToolResults(approvals={'my_tool': ToolApproved()})
    second_run = agent.run_sync(
        message_history=first_run.all_messages(),
        deferred_tool_results=approvals,
    )
    assert second_run.output == 'Done!'
def test_call_deferred_without_metadata():
    """Test backward compatibility: CallDeferred without metadata still works."""
    agent = Agent(TestModel(), output_type=[str, DeferredToolRequests])

    @agent.tool_plain
    def my_tool(x: int) -> int:
        raise CallDeferred  # No metadata

    result = agent.run_sync('Hello')
    assert isinstance(result.output, DeferredToolRequests)
    assert len(result.output.calls) == 1

    # A call deferred without metadata gets no entry at all in the metadata mapping
    # (not an empty dict), so with a single deferred call the whole mapping is empty.
    assert result.output.metadata == {}


def test_approval_required_without_metadata():
    """Test backward compatibility: ApprovalRequired without metadata still works."""

    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        if len(messages) == 1:
            return ModelResponse(
                parts=[
                    ToolCallPart('my_tool', {'x': 1}, tool_call_id='my_tool'),
                ]
            )
        else:
            return ModelResponse(
                parts=[
                    TextPart('Done!'),
                ]
            )

    agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])

    @agent.tool
    def my_tool(ctx: RunContext[None], x: int) -> int:
        if not ctx.tool_call_approved:
            raise ApprovalRequired  # No metadata
        return x * 42

    result = agent.run_sync('Hello')
    assert isinstance(result.output, DeferredToolRequests)
    assert len(result.output.approvals) == 1

    # An approval request raised without metadata gets no entry at all in the
    # metadata mapping, so the mapping stays empty.
    assert result.output.metadata == {}

    # Continue with approval
    messages = result.all_messages()
    result = agent.run_sync(
        message_history=messages,
        deferred_tool_results=DeferredToolResults(approvals={'my_tool': ToolApproved()}),
    )
    assert result.output == 'Done!'
def test_mixed_deferred_tools_with_metadata():
    """Test multiple deferred tools with different metadata."""

    def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
        if len(messages) == 1:
            return ModelResponse(
                parts=[
                    ToolCallPart('tool_a', {'x': 1}, tool_call_id='call_a'),
                    ToolCallPart('tool_b', {'y': 2}, tool_call_id='call_b'),
                    ToolCallPart('tool_c', {'z': 3}, tool_call_id='call_c'),
                ]
            )
        else:
            return ModelResponse(parts=[TextPart('Done!')])

    agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])

    @agent.tool
    def tool_a(ctx: RunContext[None], x: int) -> int:
        raise CallDeferred(metadata={'type': 'external', 'priority': 'high'})

    @agent.tool
    def tool_b(ctx: RunContext[None], y: int) -> int:
        if not ctx.tool_call_approved:
            raise ApprovalRequired(metadata={'reason': 'Needs approval', 'level': 'manager'})
        return y * 10

    @agent.tool
    def tool_c(ctx: RunContext[None], z: int) -> int:
        raise CallDeferred  # No metadata

    result = agent.run_sync('Hello')
    assert isinstance(result.output, DeferredToolRequests)

    # Check that we have the right tools deferred, by identity and not only by count
    assert len(result.output.calls) == 2  # tool_a and tool_c
    assert {call.tool_call_id for call in result.output.calls} == {'call_a', 'call_c'}
    assert len(result.output.approvals) == 1  # tool_b
    assert [call.tool_call_id for call in result.output.approvals] == ['call_b']

    # Check metadata: only the calls that supplied metadata get an entry, so
    # `call_c` (deferred without metadata) must be absent from the mapping.
    assert result.output.metadata == {
        'call_a': {'type': 'external', 'priority': 'high'},
        'call_b': {'reason': 'Needs approval', 'level': 'manager'},
    }

    # Continue with results for all three tools
    messages = result.all_messages()
    result = agent.run_sync(
        message_history=messages,
        deferred_tool_results=DeferredToolResults(
            calls={'call_a': 10, 'call_c': 30},
            approvals={'call_b': ToolApproved()},
        ),
    )
    assert result.output == 'Done!'
+ + def test_deferred_tool_with_output_type(): class MyModel(BaseModel): foo: str @@ -1590,7 +1766,7 @@ def buy(fruit: str): ToolCallPart(tool_name='buy', args={'fruit': 'apple'}, tool_call_id='buy_apple'), ToolCallPart(tool_name='buy', args={'fruit': 'banana'}, tool_call_id='buy_banana'), ToolCallPart(tool_name='buy', args={'fruit': 'pear'}, tool_call_id='buy_pear'), - ], + ] ) ) diff --git a/tests/test_ui.py b/tests/test_ui.py index 38f9950ad5..a497d09389 100644 --- a/tests/test_ui.py +++ b/tests/test_ui.py @@ -439,7 +439,7 @@ async def test_run_stream_external_tools(): '', "{}", '', - "DeferredToolRequests(calls=[ToolCallPart(tool_name='external_tool', args={}, tool_call_id='pyd_ai_tool_call_id__external_tool')], approvals=[])", + "DeferredToolRequests(calls=[ToolCallPart(tool_name='external_tool', args={}, tool_call_id='pyd_ai_tool_call_id__external_tool')], approvals=[], metadata={})", '', ] )