31 changes: 18 additions & 13 deletions docs/deferred-tools.md
@@ -47,7 +47,7 @@ PROTECTED_FILES = {'.env'}
@agent.tool
def update_file(ctx: RunContext, path: str, content: str) -> str:
if path in PROTECTED_FILES and not ctx.tool_call_approved:
raise ApprovalRequired
raise ApprovalRequired(metadata={'reason': 'protected'}) # (1)!
return f'File {path!r} updated: {content!r}'


@@ -77,6 +77,7 @@ DeferredToolRequests(
tool_call_id='delete_file',
),
],
metadata={'update_file_dotenv': {'reason': 'protected'}},
)
"""

@@ -175,6 +176,8 @@ print(result.all_messages())
"""
```

1. The `metadata` parameter can be used to attach arbitrary context to the deferred tool call; it is surfaced in `DeferredToolRequests.metadata`, keyed by `tool_call_id` (see the sketch after this example).

_(This example is complete, it can be run "as is")_
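
The approval-handling side is collapsed in this diff, so here is a rough sketch (not part of the PR) of how the attached metadata could be read back when deciding which calls to approve. It assumes the `agent` and `result` from the example above, a `DeferredToolResults.approvals` mapping keyed by `tool_call_id`, and the `deferred_tool_results` run argument documented elsewhere in these docs.

```python
# Rough sketch, not part of this PR: consuming the metadata attached by
# `ApprovalRequired(metadata=...)` when deciding approvals.
from pydantic_ai import DeferredToolRequests, DeferredToolResults

requests = result.output
assert isinstance(requests, DeferredToolRequests)

results = DeferredToolResults()
for call in requests.approvals:
    info = requests.metadata.get(call.tool_call_id, {})
    # Auto-deny calls flagged as touching protected files; approve everything else.
    # (`approvals` accepting plain booleans is assumed here.)
    results.approvals[call.tool_call_id] = info.get('reason') != 'protected'

# Resume the run with the decisions (the `deferred_tool_results` argument is
# assumed from the surrounding docs).
result = agent.run_sync(message_history=result.all_messages(), deferred_tool_results=results)
```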

## External Tool Execution
@@ -209,13 +212,13 @@ from pydantic_ai import (

@dataclass
class TaskResult:
tool_call_id: str
task_id: str
result: Any


async def calculate_answer_task(tool_call_id: str, question: str) -> TaskResult:
async def calculate_answer_task(task_id: str, question: str) -> TaskResult:
await asyncio.sleep(1)
return TaskResult(tool_call_id=tool_call_id, result=42)
return TaskResult(task_id=task_id, result=42)


agent = Agent('openai:gpt-5', output_type=[str, DeferredToolRequests])
@@ -225,12 +228,11 @@ tasks: list[asyncio.Task[TaskResult]] = []

@agent.tool
async def calculate_answer(ctx: RunContext, question: str) -> str:
assert ctx.tool_call_id is not None

task = asyncio.create_task(calculate_answer_task(ctx.tool_call_id, question)) # (1)!
task_id = f'task_{len(tasks)}' # (1)!
task = asyncio.create_task(calculate_answer_task(task_id, question))
tasks.append(task)

raise CallDeferred
raise CallDeferred(metadata={'task_id': task_id}) # (2)!


async def main():
@@ -252,17 +254,19 @@ async def main():
)
],
approvals=[],
metadata={'pyd_ai_tool_call_id': {'task_id': 'task_0'}},
)
"""

done, _ = await asyncio.wait(tasks) # (2)!
done, _ = await asyncio.wait(tasks) # (3)!
task_results = [task.result() for task in done]
task_results_by_tool_call_id = {result.tool_call_id: result.result for result in task_results}
task_results_by_task_id = {result.task_id: result.result for result in task_results}

results = DeferredToolResults()
for call in requests.calls:
try:
result = task_results_by_tool_call_id[call.tool_call_id]
task_id = requests.metadata[call.tool_call_id]['task_id']
result = task_results_by_task_id[task_id]
except KeyError:
result = ModelRetry('No result for this tool call was found.')

@@ -324,8 +328,9 @@ async def main():
"""
```

1. In reality, you'd likely use Celery or a similar task queue to run the task in the background.
2. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete.
1. Generate a task ID that can be tracked independently of the tool call ID.
2. The `metadata` parameter carries the `task_id` so the task's result can be matched back to the tool call later; it is surfaced in `DeferredToolRequests.metadata`, keyed by `tool_call_id`.
3. In reality, this would typically happen in a separate process that polls for the task status or is notified when all pending tasks are complete; see the sketch after this example.

_(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_
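
Callout 3 describes resuming from a separate process. As a rough sketch (not part of this PR), a worker could rebuild the `DeferredToolResults` from the stored metadata once the background tasks have finished and then resume the run. Persistence of the message history, the requests, and the task results is left to the caller, and the `deferred_tool_results` run argument is assumed from the surrounding docs.

```python
# Rough sketch, not part of this PR: resuming in a separate worker process once
# the background tasks are done. The caller is assumed to have persisted the
# message history, the DeferredToolRequests, and the completed task results.
from typing import Any

from pydantic_ai import Agent, DeferredToolRequests, DeferredToolResults, ModelRetry
from pydantic_ai.messages import ModelMessage


async def resume_run(
    agent: Agent,
    messages: list[ModelMessage],
    requests: DeferredToolRequests,
    task_results_by_task_id: dict[str, Any],
):
    results = DeferredToolResults()
    for call in requests.calls:
        try:
            # Look up the task via the metadata attached by `CallDeferred(metadata=...)`.
            task_id = requests.metadata[call.tool_call_id]['task_id']
            result = task_results_by_task_id[task_id]
        except KeyError:
            result = ModelRetry('No result for this tool call was found.')
        results.calls[call.tool_call_id] = result

    # The `deferred_tool_results` argument is assumed from the surrounding docs.
    return await agent.run(message_history=messages, deferred_tool_results=results)
```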

1 change: 1 addition & 0 deletions docs/toolsets.md
@@ -362,6 +362,7 @@ DeferredToolRequests(
tool_call_id='pyd_ai_tool_call_id__temperature_fahrenheit',
),
],
metadata={},
)
"""

30 changes: 27 additions & 3 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
@@ -888,6 +888,7 @@ async def process_tool_calls( # noqa: C901
calls_to_run = [call for call in calls_to_run if call.tool_call_id in calls_to_run_results]

deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]] = defaultdict(list)
deferred_metadata: dict[str, dict[str, Any]] = {}

if calls_to_run:
async for event in _call_tools(
@@ -899,6 +900,7 @@ async def process_tool_calls( # noqa: C901
usage_limits=ctx.deps.usage_limits,
output_parts=output_parts,
output_deferred_calls=deferred_calls,
output_deferred_metadata=deferred_metadata,
):
yield event

@@ -932,6 +934,7 @@ async def process_tool_calls( # noqa: C901
deferred_tool_requests = _output.DeferredToolRequests(
calls=deferred_calls['external'],
approvals=deferred_calls['unapproved'],
metadata=deferred_metadata,
)

final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_requests), None, None)
@@ -949,10 +952,12 @@ async def _call_tools(
usage_limits: _usage.UsageLimits,
output_parts: list[_messages.ModelRequestPart],
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
output_deferred_metadata: dict[str, dict[str, Any]],
) -> AsyncIterator[_messages.HandleResponseEvent]:
tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {}
user_parts_by_index: dict[int, _messages.UserPromptPart] = {}
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']] = {}
deferred_metadata_by_index: dict[int, dict[str, Any] | None] = {}

if usage_limits.tool_calls_limit is not None:
projected_usage = deepcopy(usage)
@@ -987,10 +992,12 @@ async def handle_call_or_result(
tool_part, tool_user_content = (
(await coro_or_task) if inspect.isawaitable(coro_or_task) else coro_or_task.result()
)
except exceptions.CallDeferred:
except exceptions.CallDeferred as e:
deferred_calls_by_index[index] = 'external'
except exceptions.ApprovalRequired:
deferred_metadata_by_index[index] = e.metadata
except exceptions.ApprovalRequired as e:
deferred_calls_by_index[index] = 'unapproved'
deferred_metadata_by_index[index] = e.metadata
else:
tool_parts_by_index[index] = tool_part
if tool_user_content:
@@ -1028,8 +1035,25 @@ async def handle_call_or_result(
output_parts.extend([tool_parts_by_index[k] for k in sorted(tool_parts_by_index)])
output_parts.extend([user_parts_by_index[k] for k in sorted(user_parts_by_index)])

_populate_deferred_calls(
tool_calls, deferred_calls_by_index, deferred_metadata_by_index, output_deferred_calls, output_deferred_metadata
)


def _populate_deferred_calls(
tool_calls: list[_messages.ToolCallPart],
deferred_calls_by_index: dict[int, Literal['external', 'unapproved']],
deferred_metadata_by_index: dict[int, dict[str, Any] | None],
output_deferred_calls: dict[Literal['external', 'unapproved'], list[_messages.ToolCallPart]],
output_deferred_metadata: dict[str, dict[str, Any]],
) -> None:
"""Populate deferred calls and metadata from indexed mappings."""
for k in sorted(deferred_calls_by_index):
output_deferred_calls[deferred_calls_by_index[k]].append(tool_calls[k])
call = tool_calls[k]
output_deferred_calls[deferred_calls_by_index[k]].append(call)
metadata = deferred_metadata_by_index[k]
if metadata is not None:
output_deferred_metadata[call.tool_call_id] = metadata


async def _call_tool(
14 changes: 8 additions & 6 deletions pydantic_ai_slim/pydantic_ai/durable_exec/temporal/_toolset.py
@@ -27,11 +27,13 @@ class CallToolParams:

@dataclass
class _ApprovalRequired:
metadata: dict[str, Any] | None = None
kind: Literal['approval_required'] = 'approval_required'


@dataclass
class _CallDeferred:
metadata: dict[str, Any] | None = None
kind: Literal['call_deferred'] = 'call_deferred'


@@ -75,20 +77,20 @@ async def _wrap_call_tool_result(self, coro: Awaitable[Any]) -> CallToolResult:
try:
result = await coro
return _ToolReturn(result=result)
except ApprovalRequired:
return _ApprovalRequired()
except CallDeferred:
return _CallDeferred()
except ApprovalRequired as e:
return _ApprovalRequired(metadata=e.metadata)
except CallDeferred as e:
return _CallDeferred(metadata=e.metadata)
except ModelRetry as e:
return _ModelRetry(message=e.message)

def _unwrap_call_tool_result(self, result: CallToolResult) -> Any:
if isinstance(result, _ToolReturn):
return result.result
elif isinstance(result, _ApprovalRequired):
raise ApprovalRequired()
raise ApprovalRequired(metadata=result.metadata)
elif isinstance(result, _CallDeferred):
raise CallDeferred()
raise CallDeferred(metadata=result.metadata)
elif isinstance(result, _ModelRetry):
raise ModelRetry(result.message)
else:
16 changes: 14 additions & 2 deletions pydantic_ai_slim/pydantic_ai/exceptions.py
@@ -70,18 +70,30 @@ class CallDeferred(Exception):
"""Exception to raise when a tool call should be deferred.

See [tools docs](../deferred-tools.md#deferred-tools) for more information.

Args:
metadata: Optional dictionary of metadata to attach to the deferred tool call.
This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""

pass
def __init__(self, metadata: dict[str, Any] | None = None):
self.metadata = metadata
super().__init__()


class ApprovalRequired(Exception):
"""Exception to raise when a tool call requires human-in-the-loop approval.

See [tools docs](../deferred-tools.md#human-in-the-loop-tool-approval) for more information.

Args:
metadata: Optional dictionary of metadata to attach to the deferred tool call.
This metadata will be available in `DeferredToolRequests.metadata` keyed by `tool_call_id`.
"""

pass
def __init__(self, metadata: dict[str, Any] | None = None):
self.metadata = metadata
super().__init__()


class UserError(RuntimeError):
2 changes: 2 additions & 0 deletions pydantic_ai_slim/pydantic_ai/tools.py
@@ -147,6 +147,8 @@ class DeferredToolRequests:
"""Tool calls that require external execution."""
approvals: list[ToolCallPart] = field(default_factory=list)
"""Tool calls that require human-in-the-loop approval."""
metadata: dict[str, dict[str, Any]] = field(default_factory=dict)
"""Metadata for deferred tool calls, keyed by `tool_call_id`."""


@dataclass(kw_only=True)
6 changes: 5 additions & 1 deletion tests/test_examples.py
@@ -874,10 +874,14 @@ async def model_logic( # noqa: C901
return ModelResponse(
parts=[TextPart('The answer to the ultimate question of life, the universe, and everything is 42.')]
)
else:
if isinstance(m, ToolReturnPart):
sys.stdout.write(str(debug.format(messages, info)))
raise RuntimeError(f'Unexpected message: {m}')

# Fallback for any other message type
sys.stdout.write(str(debug.format(messages, info)))
raise RuntimeError(f'Unexpected message type: {type(m).__name__}')


async def stream_model_logic( # noqa C901
messages: list[ModelMessage], info: AgentInfo
8 changes: 2 additions & 6 deletions tests/test_streaming.py
@@ -1712,9 +1712,7 @@ def my_tool(x: int) -> int:
[DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])]
)
assert await result.get_output() == snapshot(
DeferredToolRequests(
calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())],
)
DeferredToolRequests(calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())])
)
responses = [c async for c, _is_last in result.stream_responses(debounce_by=None)]
assert responses == snapshot(
@@ -1757,9 +1755,7 @@ def my_tool(ctx: RunContext[None], x: int) -> int:
messages = result.all_messages()
output = await result.get_output()
assert output == snapshot(
DeferredToolRequests(
approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())],
)
DeferredToolRequests(approvals=[ToolCallPart(tool_name='my_tool', args='{"x": 1}', tool_call_id=IsStr())])
)
assert result.is_complete
