5 changes: 5 additions & 0 deletions pydantic_ai_slim/pydantic_ai/ui/_adapter.py
@@ -143,6 +143,11 @@ def load_messages(cls, messages: Sequence[MessageT]) -> list[ModelMessage]:
"""Transform protocol-specific messages into Pydantic AI messages."""
raise NotImplementedError

@classmethod
def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[MessageT]:
"""Transform Pydantic AI messages into protocol-specific messages."""
raise NotImplementedError

@abstractmethod
def build_event_stream(self) -> UIEventStream[RunInputT, EventT, AgentDepsT, OutputDataT]:
"""Build a protocol-specific event stream transformer."""
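
For orientation, here is a minimal sketch of how a concrete adapter could override the new dump_messages hook alongside the existing load_messages. DictProtocolAdapter and its dict-based message type are hypothetical names used purely for illustration; a real adapter would subclass UIAdapter with the appropriate type parameters.

from collections.abc import Sequence
from typing import Any

from pydantic_ai.messages import ModelMessage, ModelRequest


class DictProtocolAdapter:  # a real adapter would subclass UIAdapter[...]
    @classmethod
    def load_messages(cls, messages: Sequence[dict[str, Any]]) -> list[ModelMessage]:
        """Protocol messages -> Pydantic AI messages (pre-existing direction)."""
        raise NotImplementedError

    @classmethod
    def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[dict[str, Any]]:
        """Pydantic AI messages -> protocol messages (the direction added by this PR)."""
        return [
            {
                'role': 'user' if isinstance(m, ModelRequest) else 'assistant',
                'parts': [str(p) for p in m.parts],
            }
            for m in messages
        ]
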
242 changes: 239 additions & 3 deletions pydantic_ai_slim/pydantic_ai/ui/vercel_ai/_adapter.py
@@ -2,10 +2,12 @@

from __future__ import annotations

import json
import uuid
from collections.abc import Sequence
from dataclasses import dataclass
from functools import cached_property
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, cast

from pydantic import TypeAdapter
from typing_extensions import assert_never
@@ -15,10 +17,13 @@
BinaryContent,
BuiltinToolCallPart,
BuiltinToolReturnPart,
CachePoint,
DocumentUrl,
FilePart,
ImageUrl,
ModelMessage,
ModelRequest,
ModelResponse,
RetryPromptPart,
SystemPromptPart,
TextPart,
@@ -35,6 +40,9 @@
from ._event_stream import VercelAIEventStream
from .request_types import (
DataUIPart,
DynamicToolInputAvailablePart,
DynamicToolOutputAvailablePart,
DynamicToolOutputErrorPart,
DynamicToolUIPart,
FileUIPart,
ReasoningUIPart,
@@ -43,10 +51,12 @@
SourceUrlUIPart,
StepStartUIPart,
TextUIPart,
ToolInputAvailablePart,
ToolOutputAvailablePart,
ToolOutputErrorPart,
ToolUIPart,
UIMessage,
UIMessagePart,
)
from .response_types import BaseChunk

@@ -122,7 +132,16 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
if isinstance(part, TextUIPart):
builder.add(TextPart(content=part.text))
elif isinstance(part, ReasoningUIPart):
builder.add(ThinkingPart(content=part.text))
pydantic_ai_meta = (part.provider_metadata or {}).get('pydantic_ai', {})
builder.add(
ThinkingPart(
content=part.text,
id=pydantic_ai_meta.get('id'),
signature=pydantic_ai_meta.get('signature'),
provider_name=pydantic_ai_meta.get('provider_name'),
provider_details=pydantic_ai_meta.get('provider_details'),
)
)
elif isinstance(part, FileUIPart):
try:
file = BinaryContent.from_data_uri(part.url)
@@ -141,7 +160,20 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
builtin_tool = part.provider_executed

tool_call_id = part.tool_call_id
args = part.input

args: str | dict[str, Any] | None = part.input

if isinstance(args, str):
try:
parsed = json.loads(args)
if isinstance(parsed, dict):
args = cast(dict[str, Any], parsed)
except json.JSONDecodeError:
pass
elif isinstance(args, dict) or args is None:
pass
else:
assert_never(args)

if builtin_tool:
call_part = BuiltinToolCallPart(tool_name=tool_name, tool_call_id=tool_call_id, args=args)
@@ -197,3 +229,207 @@ def load_messages(cls, messages: Sequence[UIMessage]) -> list[ModelMessage]: #
assert_never(msg.role)

return builder.messages

@staticmethod
def _dump_request_message(msg: ModelRequest) -> tuple[list[UIMessagePart], list[UIMessagePart]]:
"""Convert a ModelRequest into a UIMessage."""
system_ui_parts: list[UIMessagePart] = []
user_ui_parts: list[UIMessagePart] = []

for part in msg.parts:
if isinstance(part, SystemPromptPart):
system_ui_parts.append(TextUIPart(text=part.content, state='done'))
elif isinstance(part, UserPromptPart):
user_ui_parts.extend(_convert_user_prompt_part(part))
elif isinstance(part, ToolReturnPart):
# Tool returns are merged into the tool call in the assistant message
pass
elif isinstance(part, RetryPromptPart):
Review comment (Collaborator): RetryPromptPart can occur without a tool_name, and we should send it like a text message.

if part.tool_name:
# Tool-related retries are handled when processing ToolCallPart in ModelResponse
pass
else:
# Non-tool retries (e.g., output validation errors) become user text
user_ui_parts.append(TextUIPart(text=part.model_response(), state='done'))
else:
assert_never(part)

return system_ui_parts, user_ui_parts

@staticmethod
def _dump_response_message( # noqa: C901
msg: ModelResponse,
tool_results: dict[str, ToolReturnPart | RetryPromptPart],
) -> list[UIMessagePart]:
"""Convert a ModelResponse into a UIMessage."""
ui_parts: list[UIMessagePart] = []

# For builtin tools, returns can be in the same ModelResponse as calls
local_builtin_returns: dict[str, BuiltinToolReturnPart] = {
part.tool_call_id: part for part in msg.parts if isinstance(part, BuiltinToolReturnPart)
}

for part in msg.parts:
if isinstance(part, BuiltinToolReturnPart):
continue
elif isinstance(part, TextPart):
# Combine consecutive text parts
if ui_parts and isinstance(ui_parts[-1], TextUIPart):
ui_parts[-1].text += part.content
else:
ui_parts.append(TextUIPart(text=part.content, state='done'))
elif isinstance(part, ThinkingPart):
thinking_metadata: dict[str, Any] = {}
if part.id is not None:
thinking_metadata['id'] = part.id
if part.signature is not None:
thinking_metadata['signature'] = part.signature
if part.provider_name is not None:
thinking_metadata['provider_name'] = part.provider_name
if part.provider_details is not None:
thinking_metadata['provider_details'] = part.provider_details

provider_metadata = {'pydantic_ai': thinking_metadata} if thinking_metadata else None
ui_parts.append(ReasoningUIPart(text=part.content, state='done', provider_metadata=provider_metadata))
elif isinstance(part, FilePart):
ui_parts.append(
FileUIPart(
url=part.content.data_uri,
media_type=part.content.media_type,
)
)
elif isinstance(part, BuiltinToolCallPart):
call_provider_metadata = (
{'pydantic_ai': {'provider_name': part.provider_name}} if part.provider_name else None
)

if builtin_return := local_builtin_returns.get(part.tool_call_id):
content = builtin_return.model_response_str()
ui_parts.append(
ToolOutputAvailablePart(
type=f'tool-{part.tool_name}',
tool_call_id=part.tool_call_id,
input=part.args_as_json_str(),
output=content,
state='output-available',
provider_executed=True,
call_provider_metadata=call_provider_metadata,
)
)
else: # pragma: no cover
ui_parts.append(
ToolInputAvailablePart(
type=f'tool-{part.tool_name}',
tool_call_id=part.tool_call_id,
input=part.args_as_json_str(),
state='input-available',
provider_executed=True,
call_provider_metadata=call_provider_metadata,
)
)
elif isinstance(part, ToolCallPart):
tool_result = tool_results.get(part.tool_call_id)

if isinstance(tool_result, ToolReturnPart):
content = tool_result.model_response_str()
ui_parts.append(
DynamicToolOutputAvailablePart(
tool_name=part.tool_name,
tool_call_id=part.tool_call_id,
input=part.args_as_json_str(),
output=content,
state='output-available',
)
)
elif isinstance(tool_result, RetryPromptPart):
error_text = tool_result.model_response()
ui_parts.append(
DynamicToolOutputErrorPart(
tool_name=part.tool_name,
tool_call_id=part.tool_call_id,
input=part.args_as_json_str(),
error_text=error_text,
state='output-error',
)
)
else:
ui_parts.append(
DynamicToolInputAvailablePart(
tool_name=part.tool_name,
tool_call_id=part.tool_call_id,
input=part.args_as_json_str(),
state='input-available',
)
)
else:
assert_never(part)

return ui_parts

@classmethod
def dump_messages(
cls,
messages: Sequence[ModelMessage],
) -> list[UIMessage]:
"""Transform Pydantic AI messages into Vercel AI messages.

Args:
messages: A sequence of ModelMessage objects to convert

Returns:
A list of UIMessage objects in Vercel AI format
"""
tool_results: dict[str, ToolReturnPart | RetryPromptPart] = {}

for msg in messages:
if isinstance(msg, ModelRequest):
for part in msg.parts:
if isinstance(part, ToolReturnPart):
tool_results[part.tool_call_id] = part
elif isinstance(part, RetryPromptPart) and part.tool_name:
tool_results[part.tool_call_id] = part

result: list[UIMessage] = []

for msg in messages:
if isinstance(msg, ModelRequest):
system_ui_parts, user_ui_parts = cls._dump_request_message(msg)
if system_ui_parts:
result.append(UIMessage(id=str(uuid.uuid4()), role='system', parts=system_ui_parts))

if user_ui_parts:
result.append(UIMessage(id=str(uuid.uuid4()), role='user', parts=user_ui_parts))

elif isinstance( # pragma: no branch
msg, ModelResponse
):
ui_parts: list[UIMessagePart] = cls._dump_response_message(msg, tool_results)
if ui_parts: # pragma: no branch
result.append(UIMessage(id=str(uuid.uuid4()), role='assistant', parts=ui_parts))
else:
assert_never(msg)

return result


def _convert_user_prompt_part(part: UserPromptPart) -> list[UIMessagePart]:
"""Convert a UserPromptPart to a list of UI message parts."""
ui_parts: list[UIMessagePart] = []

if isinstance(part.content, str):
ui_parts.append(TextUIPart(text=part.content, state='done'))
else:
for item in part.content:
if isinstance(item, str):
ui_parts.append(TextUIPart(text=item, state='done'))
elif isinstance(item, BinaryContent):
ui_parts.append(FileUIPart(url=item.data_uri, media_type=item.media_type))
elif isinstance(item, ImageUrl | AudioUrl | VideoUrl | DocumentUrl):
ui_parts.append(FileUIPart(url=item.url, media_type=item.media_type))
elif isinstance(item, CachePoint):
# CachePoint is metadata for prompt caching, skip for UI conversion
pass
else:
assert_never(item)

return ui_parts
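
To show the new conversion direction end to end, here is a hedged usage sketch. The VercelAIAdapter name and the pydantic_ai.ui.vercel_ai import path are assumptions based on this module's location and may differ; treat this as an illustration, not a verbatim API reference.

from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, UserPromptPart
from pydantic_ai.ui.vercel_ai import VercelAIAdapter  # assumed import path

history = [
    ModelRequest(parts=[UserPromptPart(content='What is the capital of France?')]),
    ModelResponse(parts=[TextPart(content='Paris.')]),
]

# Pydantic AI messages -> Vercel AI UIMessages (the direction added by this PR).
ui_messages = VercelAIAdapter.dump_messages(history)
assert [m.role for m in ui_messages] == ['user', 'assistant']

# Vercel AI UIMessages -> Pydantic AI messages (the pre-existing direction).
restored = VercelAIAdapter.load_messages(ui_messages)
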
13 changes: 13 additions & 0 deletions tests/test_ui.py
@@ -87,6 +87,10 @@ class DummyUIAdapter(UIAdapter[DummyUIRunInput, ModelMessage, str, AgentDepsT, OutputDataT]
def build_run_input(cls, body: bytes) -> DummyUIRunInput:
return DummyUIRunInput.model_validate_json(body)

@classmethod
def dump_messages(cls, messages: Sequence[ModelMessage]) -> list[ModelMessage]:
return list(messages)

@classmethod
def load_messages(cls, messages: Sequence[ModelMessage]) -> list[ModelMessage]:
return list(messages)
@@ -676,3 +680,12 @@ async def send(data: MutableMapping[str, Any]) -> None:
{'type': 'http.response.body', 'body': b'', 'more_body': False},
]
)


def test_dummy_adapter_dump_messages():
"""Test that DummyUIAdapter.dump_messages returns messages as-is."""
from pydantic_ai.messages import UserPromptPart

messages = [ModelRequest(parts=[UserPromptPart(content='Hello')])]
result = DummyUIAdapter.dump_messages(messages)
assert result == messages
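

A natural follow-up test, sketched under the same assumption that the Vercel adapter is importable as pydantic_ai.ui.vercel_ai.VercelAIAdapter, would check that ThinkingPart metadata survives a dump_messages/load_messages round trip:

def test_vercel_thinking_metadata_round_trip():
    """Sketch: ThinkingPart signature should survive dump_messages -> load_messages."""
    from pydantic_ai.messages import ModelResponse, ThinkingPart
    from pydantic_ai.ui.vercel_ai import VercelAIAdapter  # assumed import path

    original = [ModelResponse(parts=[ThinkingPart(content='thinking...', signature='sig-123')])]

    ui_messages = VercelAIAdapter.dump_messages(original)
    restored = VercelAIAdapter.load_messages(ui_messages)

    part = restored[0].parts[0]
    assert isinstance(part, ThinkingPart)
    assert part.signature == 'sig-123'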