diff --git a/docs/deferred-tools.md b/docs/deferred-tools.md
index b5dcb38685..4bc487b858 100644
--- a/docs/deferred-tools.md
+++ b/docs/deferred-tools.md
@@ -47,7 +47,7 @@ PROTECTED_FILES = {'.env'}
@agent.tool
def update_file(ctx: RunContext, path: str, content: str) -> str:
if path in PROTECTED_FILES and not ctx.tool_call_approved:
- raise ApprovalRequired
+ raise ApprovalRequired(metadata={'reason': 'protected'}) # (1)!
return f'File {path!r} updated: {content!r}'
@@ -77,6 +77,7 @@ DeferredToolRequests(
tool_call_id='delete_file',
),
],
+ metadata={'update_file_dotenv': {'reason': 'protected'}},
)
"""
@@ -175,6 +176,8 @@ print(result.all_messages())
"""
```
+1. The `metadata` parameter can be used to attach arbitrary context to a deferred tool call; it is made available in `DeferredToolRequests.metadata`, keyed by `tool_call_id`.
+
_(This example is complete, it can be run "as is")_
## External Tool Execution
@@ -209,13 +212,13 @@ from pydantic_ai import (
@dataclass
class TaskResult:
- tool_call_id: str
+ task_id: str
result: Any
-async def calculate_answer_task(tool_call_id: str, question: str) -> TaskResult:
+async def calculate_answer_task(task_id: str, question: str) -> TaskResult:
await asyncio.sleep(1)
- return TaskResult(tool_call_id=tool_call_id, result=42)
+ return TaskResult(task_id=task_id, result=42)
agent = Agent('openai:gpt-5', output_type=[str, DeferredToolRequests])
@@ -225,12 +228,11 @@ tasks: list[asyncio.Task[TaskResult]] = []
@agent.tool
async def calculate_answer(ctx: RunContext, question: str) -> str:
- assert ctx.tool_call_id is not None
-
- task = asyncio.create_task(calculate_answer_task(ctx.tool_call_id, question)) # (1)!
+ task_id = f'task_{len(tasks)}' # (1)!
+ task = asyncio.create_task(calculate_answer_task(task_id, question))
tasks.append(task)
- raise CallDeferred
+ raise CallDeferred(metadata={'task_id': task_id}) # (2)!
async def main():
@@ -252,17 +254,19 @@ async def main():
)
],
approvals=[],
+ metadata={'pyd_ai_tool_call_id': {'task_id': 'task_0'}},
)
"""
- done, _ = await asyncio.wait(tasks) # (2)!
+ done, _ = await asyncio.wait(tasks) # (3)!
task_results = [task.result() for task in done]
- task_results_by_tool_call_id = {result.tool_call_id: result.result for result in task_results}
+ task_results_by_task_id = {result.task_id: result.result for result in task_results}
results = DeferredToolResults()
for call in requests.calls:
try:
- result = task_results_by_tool_call_id[call.tool_call_id]
+ task_id = requests.metadata[call.tool_call_id]['task_id']
+ result = task_results_by_task_id[task_id]
except KeyError:
result = ModelRetry('No result for this tool call was found.')
@@ -324,8 +328,9 @@ async def main():
"""
```
-1. In reality, you'd likely use Celery or a similar task queue to run the task in the background.
-2. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
+1. Generate a task ID that can be tracked independently of the tool call ID.
+2. The `metadata` parameter carries the `task_id` so the deferred call can be matched with its result later; it is made available in `DeferredToolRequests.metadata`, keyed by `tool_call_id`.
+3. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
diff --git a/docs/toolsets.md b/docs/toolsets.md
index 8d970b8e31..1b041b3baa 100644
--- a/docs/toolsets.md
+++ b/docs/toolsets.md
@@ -362,6 +362,7 @@ DeferredToolRequests(
tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit',
),
],
+ metadata={},
)
"""
diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
index 91cda373a5..b1a0dd1350 100644
--- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py
+++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -888,6 +888,7 @@ async def process_tool_calls( # noqa: C901
calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]
deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
+ deferred_metadata: dict[str, dict[str, Any]] = {}
if calls_to_run:
async for event in _call_tools(
@@ -899,6 +900,7 @@ async def process_tool_calls( # noqa: C901
usage_limits=ctx.deps.usage_limits,
output_parts=output_parts,
output_deferred_calls=deferred_calls,
+ output_deferred_metadata=deferred_metadata,
):
yield event
@@ -932,6 +934,7 @@ async def process_tool_calls( # noqa: C901
deferred_tool_requests = _output.DeferredToolRequests(
calls=deferred_calls['external'],
approvals=deferred_calls['unapproved'],
+ metadata=deferred_metadata,
)
final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
@@ -949,10 +952,12 @@ async def _call_tools(
usage_limits: _usage.UsageLimits,
output_parts: list[_messages.ModelRequestPart],
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+ output_deferred_metadata: dict[str, dict[str, Any]],
) -> AsyncIterator[_messages.HandleResponseEvent]:
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
+ deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}
if usage_limits.tool_calls_limit is not None:
projected_usage = deepcopy(usage)
@@ -987,10 +992,12 @@ async def handle_call_or_result(
tool_part, tool_user_content = (
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
)
- except exceptions.CallDeferred:
+ except exceptions.CallDeferred as e:
deferred_calls_by_index[index] = 'external'
- except exceptions.ApprovalRequired:
+ deferred_metadata_by_index[index] = e.metadata
+ except exceptions.ApprovalRequired as e:
deferred_calls_by_index[index] = 'unapproved'
+ deferred_metadata_by_index[index] = e.metadata
else:
tool_parts_by_index[index] = tool_part
if tool_user_content:
@@ -1028,8 +1035,25 @@ async def handle_call_or_result(
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])
+ _populate_deferred_calls(
+ tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata
+ )
+
+
+def _populate_deferred_calls(
+ tool_calls: list[_messages.ToolCallPart],
+ deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
+ deferred_metadata_by_index: dict[int, dict[str, Any] | None],
+ output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
+ output_deferred_metadata: dict[str, dict[str, Any]],
+) -> None:
+ """Populate deferred calls and metadata from indexed mappings."""
for k in sorted(deferred_calls_by_index):
- output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
+ call = tool_calls[k]
+ output_deferred_calls[deferred_calls_by_index[k]].append(call)
+ metadata = deferred_metadata_by_index[k]
+ if metadata is not None:
+ output_deferred_metadata[call.tool_call_id] = metadata
async def _call_tool(
diff --git a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
index c52a34520e..850f001a46 100644
--- a/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
+++ b/pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
@@ -27,11 +27,13 @@ class CallToolParams:
@dataclass
class _ApprovalRequired:
+ metadata: dict[str, Any] | None = None
kind: Literal['approval_required'] = 'approval_required'
@dataclass
class _CallDeferred:
+ metadata: dict[str, Any] | None = None
kind: Literal['call_deferred'] = 'call_deferred'
@@ -75,10 +77,10 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
try:
result = await coro
return _ToolReturn(result=result)
- except ApprovalRequired:
- return _ApprovalRequired()
- except CallDeferred:
- return _CallDeferred()
+ except ApprovalRequired as e:
+ return _ApprovalRequired(metadata=e.metadata)
+ except CallDeferred as e:
+ return _CallDeferred(metadata=e.metadata)
except ModelRetry as e:
return _ModelRetry(message=e.message)
@@ -86,9 +88,9 @@ def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
if isinstance(result, _ToolReturn):
return result.result
elif isinstance(result, _ApprovalRequired):
- raise ApprovalRequired()
+ raise ApprovalRequired(metadata=result.metadata)
elif isinstance(result, _CallDeferred):
- raise CallDeferred()
+ raise CallDeferred(metadata=result.metadata)
elif isinstance(result, _ModelRetry):
raise ModelRetry(result.message)
else:
diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py
index da7ed89891..afeb8c524f 100644
--- a/pydantic_ai_slim/pydantic_ai/exceptions.py
+++ b/pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -70,18 +70,30 @@ class CallDeferred(Exception):
"""Exception to raise when a tool call should be deferred.
See [tools docs](../deferred-tools.md#deferred-tools) for more information.
+
+ Args:
+ metadata: Optional dictionary of metadata to attach to the deferred tool call.
+ This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""
- pass
+ def __init__(self, metadata: dict[str, Any] | None = None):
+ self.metadata = metadata
+ super().__init__()
class ApprovalRequired(Exception):
"""Exception to raise when a tool call requires human-in-the-loop approval.
See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.
+
+ Args:
+ metadata: Optional dictionary of metadata to attach to the deferred tool call.
+ This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""
- pass
+ def __init__(self, metadata: dict[str, Any] | None = None):
+ self.metadata = metadata
+ super().__init__()
class UserError(RuntimeError):
diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py
index da053a5191..ca72cafbb5 100644
--- a/pydantic_ai_slim/pydantic_ai/tools.py
+++ b/pydantic_ai_slim/pydantic_ai/tools.py
@@ -147,6 +147,8 @@ class DeferredToolRequests:
"""Tool calls that require external execution."""
approvals: list[ToolCallPart] = field(default_factory=list)
"""Tool calls that require human-in-the-loop approval."""
+ metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
+ """Metadata for deferred tool calls, keyed by `tool_call_id`."""
@dataclass(kw_only=True)
diff --git a/tests/test_examples.py b/tests/test_examples.py
index 85bae688d0..51eb5e341d 100644
--- a/tests/test_examples.py
+++ b/tests/test_examples.py
@@ -874,10 +874,14 @@ async def model_logic( # noqa: C901
return ModelResponse(
parts=[TextPart('The answer to the ultimate question of life, the universe, and everything is 42.')]
)
- else:
+ if isinstance(m, ToolReturnPart):
sys.stdout.write(str(debug.format(messages, info)))
raise RuntimeError(f'Unexpected message: {m}')
+ # Fallback for any other message type
+ sys.stdout.write(str(debug.format(messages, info)))
+ raise RuntimeError(f'Unexpected message type: {type(m).__name__}')
+
async def stream_model_logic( # noqa C901
messages: list[ModelMessage], info: AgentInfo
diff --git a/tests/test_streaming.py b/tests/test_streaming.py
index 230a19501a..0c6a46f3c0 100644
--- a/tests/test_streaming.py
+++ b/tests/test_streaming.py
@@ -1712,9 +1712,7 @@ def my_tool(x: int) -> int:
[DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])]
)
assert await result.get_output() == snapshot(
- DeferredToolRequests(
- calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
- )
+ DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])
)
responses = [c async for c, _is_last in result.stream_responses(debounce_by=None)]
assert responses == snapshot(
@@ -1757,9 +1755,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
messages = result.all_messages()
output = await result.get_output()
assert output == snapshot(
- DeferredToolRequests(
- approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())],
- )
+ DeferredToolRequests(approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())])
)
assert result.is_complete
diff --git a/tests/test_tools.py b/tests/test_tools.py
index 3b30056c3a..f65105b4e6 100644
--- a/tests/test_tools.py
+++ b/tests/test_tools.py
@@ -1318,9 +1318,7 @@ def my_tool(x: int) -> int:
result = agent.run_sync('Hello')
assert result.output == snapshot(
- DeferredToolRequests(
- calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
- )
+ DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])
)
@@ -1398,6 +1396,184 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
assert result.output == snapshot('Done!')
+def test_call_deferred_with_metadata():
+ """Test that CallDeferred exception can carry metadata."""
+ agent = Agent(TestModel(), output_type=[str, DeferredToolRequests])
+
+ @agent.tool_plain
+ def my_tool(x: int) -> int:
+ raise CallDeferred(metadata={'task_id': 'task-123', 'estimated_cost': 25.50})
+
+ result = agent.run_sync('Hello')
+ assert result.output == snapshot(
+ DeferredToolRequests(
+ calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
+ metadata={'pyd_ai_tool_call_id__my_tool': {'task_id': 'task-123', 'estimated_cost': 25.5}},
+ )
+ )
+
+
+def test_approval_required_with_metadata():
+ """Test that ApprovalRequired exception can carry metadata."""
+
+ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+ if len(messages) == 1:
+ return ModelResponse(
+ parts=[
+ ToolCallPart('my_tool', {'x': 1}, tool_call_id='my_tool'),
+ ]
+ )
+ else:
+ return ModelResponse(
+ parts=[
+ TextPart('Done!'),
+ ]
+ )
+
+ agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])
+
+ @agent.tool
+ def my_tool(ctx: RunContext[None], x: int) -> int:
+ if not ctx.tool_call_approved:
+ raise ApprovalRequired(
+ metadata={
+ 'reason': 'High compute cost',
+ 'estimated_time': '5 minutes',
+ 'cost_usd': 100.0,
+ }
+ )
+ return x * 42
+
+ result = agent.run_sync('Hello')
+ assert result.output == snapshot(
+ DeferredToolRequests(
+ approvals=[ToolCallPart(tool_name='my_tool', args={'x': 1}, tool_call_id=IsStr())],
+ metadata={'my_tool': {'reason': 'High compute cost', 'estimated_time': '5 minutes', 'cost_usd': 100.0}},
+ )
+ )
+
+ # Continue with approval
+ messages = result.all_messages()
+ result = agent.run_sync(
+ message_history=messages,
+ deferred_tool_results=DeferredToolResults(approvals={'my_tool': ToolApproved()}),
+ )
+ assert result.output == 'Done!'
+
+
+def test_call_deferred_without_metadata():
+ """Test backward compatibility: CallDeferred without metadata still works."""
+ agent = Agent(TestModel(), output_type=[str, DeferredToolRequests])
+
+ @agent.tool_plain
+ def my_tool(x: int) -> int:
+ raise CallDeferred # No metadata
+
+ result = agent.run_sync('Hello')
+ assert isinstance(result.output, DeferredToolRequests)
+ assert len(result.output.calls) == 1
+
+ tool_call_id = result.output.calls[0].tool_call_id
+    # No metadata entry is stored when the exception carries no metadata (hence .get with a default)
+ assert result.output.metadata.get(tool_call_id, {}) == {}
+
+
+def test_approval_required_without_metadata():
+ """Test backward compatibility: ApprovalRequired without metadata still works."""
+
+ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+ if len(messages) == 1:
+ return ModelResponse(
+ parts=[
+ ToolCallPart('my_tool', {'x': 1}, tool_call_id='my_tool'),
+ ]
+ )
+ else:
+ return ModelResponse(
+ parts=[
+ TextPart('Done!'),
+ ]
+ )
+
+ agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])
+
+ @agent.tool
+ def my_tool(ctx: RunContext[None], x: int) -> int:
+ if not ctx.tool_call_approved:
+ raise ApprovalRequired # No metadata
+ return x * 42
+
+ result = agent.run_sync('Hello')
+ assert isinstance(result.output, DeferredToolRequests)
+ assert len(result.output.approvals) == 1
+
+    # No metadata entry is stored when the exception carries no metadata (hence .get with a default)
+ assert result.output.metadata.get('my_tool', {}) == {}
+
+ # Continue with approval
+ messages = result.all_messages()
+ result = agent.run_sync(
+ message_history=messages,
+ deferred_tool_results=DeferredToolResults(approvals={'my_tool': ToolApproved()}),
+ )
+ assert result.output == 'Done!'
+
+
+def test_mixed_deferred_tools_with_metadata():
+ """Test multiple deferred tools with different metadata."""
+
+ def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse:
+ if len(messages) == 1:
+ return ModelResponse(
+ parts=[
+ ToolCallPart('tool_a', {'x': 1}, tool_call_id='call_a'),
+ ToolCallPart('tool_b', {'y': 2}, tool_call_id='call_b'),
+ ToolCallPart('tool_c', {'z': 3}, tool_call_id='call_c'),
+ ]
+ )
+ else:
+ return ModelResponse(parts=[TextPart('Done!')])
+
+ agent = Agent(FunctionModel(llm), output_type=[str, DeferredToolRequests])
+
+ @agent.tool
+ def tool_a(ctx: RunContext[None], x: int) -> int:
+ raise CallDeferred(metadata={'type': 'external', 'priority': 'high'})
+
+ @agent.tool
+ def tool_b(ctx: RunContext[None], y: int) -> int:
+ if not ctx.tool_call_approved:
+ raise ApprovalRequired(metadata={'reason': 'Needs approval', 'level': 'manager'})
+ return y * 10
+
+ @agent.tool
+ def tool_c(ctx: RunContext[None], z: int) -> int:
+ raise CallDeferred # No metadata
+
+ result = agent.run_sync('Hello')
+ assert isinstance(result.output, DeferredToolRequests)
+
+ # Check that we have the right tools deferred
+ assert len(result.output.calls) == 2 # tool_a and tool_c
+ assert len(result.output.approvals) == 1 # tool_b
+
+ # Check metadata
+ assert result.output.metadata['call_a'] == {'type': 'external', 'priority': 'high'}
+ assert result.output.metadata['call_b'] == {'reason': 'Needs approval', 'level': 'manager'}
+ assert result.output.metadata.get('call_c', {}) == {}
+
+ # Continue with results for all three tools
+ messages = result.all_messages()
+ result = agent.run_sync(
+ message_history=messages,
+ deferred_tool_results=DeferredToolResults(
+ calls={'call_a': 10, 'call_c': 30},
+ approvals={'call_b': ToolApproved()},
+ ),
+ )
+ assert result.output == 'Done!'
+
+
def test_deferred_tool_with_output_type():
class MyModel(BaseModel):
foo: str
@@ -1590,7 +1766,7 @@ def buy(fruit: str):
ToolCallPart(tool_name='buy', args={'fruit': 'apple'}, tool_call_id='buy_apple'),
ToolCallPart(tool_name='buy', args={'fruit': 'banana'}, tool_call_id='buy_banana'),
ToolCallPart(tool_name='buy', args={'fruit': 'pear'}, tool_call_id='buy_pear'),
- ],
+ ]
)
)
diff --git a/tests/test_ui.py b/tests/test_ui.py
index 38f9950ad5..a497d09389 100644
--- a/tests/test_ui.py
+++ b/tests/test_ui.py
@@ -439,7 +439,7 @@ async def test_run_stream_external_tools():
'',
"{}",
'',
- "DeferredToolRequests(calls=[ToolCallPart(tool_name='external_tool', args={}, tool_call_id='pyd_ai_tool_call_id__external_tool')], approvals=[])",
+ "DeferredToolRequests(calls=[ToolCallPart(tool_name='external_tool', args={}, tool_call_id='pyd_ai_tool_call_id__external_tool')], approvals=[], metadata={})",
'',
]
)