Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_agent_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -496,6 +496,10 @@ async def _prepare_request(
model_request_parameters = await _prepare_request_parameters(ctx)

model_settings = ctx.deps.model_settings
# Record metadata on the ModelRequest (the last request in the original history)
self.request.model_request_parameters = model_request_parameters
self.request.model_settings = model_settings
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think we should do this in Model.prepare_request instead, as model classes sometimes still perform their own modifications and we want to make sure we catch the final values that were actually sent to the LLM.


usage = ctx.state.usage
if ctx.deps.usage_limits.count_tokens_before_request:
# Copy to avoid modifying the original usage object with the counted usage
Expand Down
48 changes: 48 additions & 0 deletions pydantic_ai_slim/pydantic_ai/_model_request_parameters.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
from __future__ import annotations as _annotations

from dataclasses import dataclass, field
from functools import cached_property
from typing import TYPE_CHECKING, Any

from . import _utils
from .builtin_tools import AbstractBuiltinTool

if TYPE_CHECKING:
from .tools import ToolDefinition
else: # pragma: no cover
ToolDefinition = Any
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why do we need this?


if TYPE_CHECKING:
from ._output import OutputObjectDefinition
from .output import OutputMode

__all__ = ('ModelRequestParameters',)


@dataclass(repr=False, kw_only=True)
class ModelRequestParameters:
    """Configuration for an agent's request to a model, specifically related to tools and output handling."""

    function_tools: list[ToolDefinition] = field(default_factory=list)
    builtin_tools: list[AbstractBuiltinTool] = field(default_factory=list)

    output_mode: OutputMode = 'text'
    output_object: OutputObjectDefinition | None = None
    output_tools: list[ToolDefinition] = field(default_factory=list)
    prompted_output_template: str | None = None
    allow_text_output: bool = True
    allow_image_output: bool = False

    @cached_property
    def tool_defs(self) -> dict[str, ToolDefinition]:
        """Map tool name -> definition over function tools and output tools.

        Output tools come second, so on a name collision the output tool wins.
        """
        defs: dict[str, ToolDefinition] = {}
        for tool_def in self.function_tools:
            defs[tool_def.name] = tool_def
        for tool_def in self.output_tools:
            defs[tool_def.name] = tool_def
        return defs

    @cached_property
    def prompted_output_instructions(self) -> str | None:
        """Instructions built from the prompted-output template, or `None` when not applicable."""
        # Only meaningful in 'prompted' output mode with both a template and an object schema present.
        if self.output_mode != 'prompted' or not self.prompted_output_template or not self.output_object:
            return None
        # Imported lazily to avoid a circular import with `._output`.
        from ._output import PromptedOutputSchema

        return PromptedOutputSchema.build_instructions(self.prompted_output_template, self.output_object)

    __repr__ = _utils.dataclasses_no_defaults_repr
24 changes: 24 additions & 0 deletions pydantic_ai_slim/pydantic_ai/messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,21 @@
from typing_extensions import deprecated

from . import _otel_messages, _utils
from ._model_request_parameters import ModelRequestParameters
from ._utils import generate_tool_call_id as _generate_tool_call_id, now_utc as _now_utc
from .exceptions import UnexpectedModelBehavior
from .settings import ModelSettings
from .usage import RequestUsage

if TYPE_CHECKING:
from .models.instrumented import InstrumentationSettings

ModelRequestParametersField = ModelRequestParameters | None
Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Even though we don't serialize these, if we don't do this the TypeAdapter will try to inspect all the fields leading to issues with some httpx types included in the models.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the only way we can solve that, this should be a comment.

But I wonder if @Viicos has another idea. He's out today but will be back on Monday.

ModelSettingsField = ModelSettings | None
else: # pragma: no cover
ModelRequestParametersField = Any
ModelSettingsField = Any


AudioMediaType: TypeAlias = Literal['audio/wav', 'audio/mpeg', 'audio/ogg', 'audio/flac', 'audio/aiff', 'audio/aac']
ImageMediaType: TypeAlias = Literal['image/jpeg', 'image/png', 'image/gif', 'image/webp']
Expand Down Expand Up @@ -945,6 +953,22 @@ class ModelRequest:
instructions: str | None = None
"""The instructions for the model."""

model_request_parameters: Annotated[ModelRequestParametersField, pydantic.Field(exclude=True, repr=False)] = field(
default=None, repr=False, compare=False
)
"""Full request parameters captured for this request.

Available for introspection during a run. This field is excluded from serialization.
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

"during a run" implies it'll be removed after the run, but that's not the case

"""

model_settings: Annotated[ModelSettingsField, pydantic.Field(exclude=True, repr=False)] = field(
default=None, repr=False, compare=False
)
"""Effective model settings that were applied to this request.

Available for introspection during a run. This field is excluded from serialization.
"""

kind: Literal['request'] = 'request'
"""Message type identifier, this is available on all parts as a discriminator."""

Expand Down
33 changes: 2 additions & 31 deletions pydantic_ai_slim/pydantic_ai/models/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,12 +19,11 @@
import httpx
from typing_extensions import TypeAliasType, TypedDict

from .. import _utils
from .._json_schema import JsonSchemaTransformer
from .._output import OutputObjectDefinition, PromptedOutputSchema
from .._model_request_parameters import ModelRequestParameters
from .._output import OutputObjectDefinition
from .._parts_manager import ModelResponsePartsManager
from .._run_context import RunContext
from ..builtin_tools import AbstractBuiltinTool
from ..exceptions import UserError
from ..messages import (
BaseToolCallPart,
Expand All @@ -45,7 +44,6 @@
ToolCallPart,
VideoUrl,
)
from ..output import OutputMode
from ..profiles import DEFAULT_PROFILE, ModelProfile, ModelProfileSpec
from ..providers import Provider, infer_provider
from ..settings import ModelSettings, merge_model_settings
Expand Down Expand Up @@ -308,33 +306,6 @@
"""


@dataclass(repr=False, kw_only=True)
class ModelRequestParameters:
    """Configuration for an agent's request to a model, specifically related to tools and output handling."""

    # Tools the agent exposes for the model to call, plus provider-native builtin tools.
    function_tools: list[ToolDefinition] = field(default_factory=list)
    builtin_tools: list[AbstractBuiltinTool] = field(default_factory=list)

    # How the model should produce its final output: plain text by default,
    # optionally a structured object, dedicated output tools, or a prompted template.
    output_mode: OutputMode = 'text'
    output_object: OutputObjectDefinition | None = None
    output_tools: list[ToolDefinition] = field(default_factory=list)
    prompted_output_template: str | None = None
    allow_text_output: bool = True
    allow_image_output: bool = False

    @cached_property
    def tool_defs(self) -> dict[str, ToolDefinition]:
        # Name -> definition over both function and output tools; output tools are merged
        # last, so on a name collision the output tool definition wins.
        return {tool_def.name: tool_def for tool_def in [*self.function_tools, *self.output_tools]}

    @cached_property
    def prompted_output_instructions(self) -> str | None:
        # Instructions only exist in 'prompted' output mode when both a template
        # and an output object schema are available; otherwise None.
        if self.output_mode == 'prompted' and self.prompted_output_template and self.output_object:
            return PromptedOutputSchema.build_instructions(self.prompted_output_template, self.output_object)
        return None

    # Compact repr that omits fields still at their default values.
    __repr__ = _utils.dataclasses_no_defaults_repr


class Model(ABC):
"""Abstract class for a model."""

Expand Down
47 changes: 46 additions & 1 deletion tests/test_messages.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,9 @@
UserPromptPart,
VideoUrl,
)
from pydantic_ai.builtin_tools import ImageGenerationTool
from pydantic_ai.models import ModelRequestParameters, ToolDefinition
from pydantic_ai.settings import ModelSettings

from .conftest import IsDatetime, IsNow, IsStr

Expand Down Expand Up @@ -404,7 +407,7 @@ def test_pre_usage_refactor_messages_deserializable():
content='What is the capital of Mexico?',
timestamp=IsNow(tz=timezone.utc),
)
]
],
),
ModelResponse(
parts=[TextPart(content='Mexico City.')],
Expand Down Expand Up @@ -605,3 +608,45 @@ def test_binary_content_validation_with_optional_identifier():
'identifier': 'foo',
}
)


def test_model_request_tool_tracking_excluded_from_serialization():
    """Request metadata (parameters and settings) is readable on ModelRequest but never serialized."""
    function_tool = ToolDefinition(
        name='test_tool',
        description='A test tool',
        parameters_json_schema={'type': 'object', 'properties': {}},
    )
    output_tool = ToolDefinition(
        name='request_output',
        description='An output tool',
        parameters_json_schema={'type': 'object', 'properties': {}},
    )

    params = ModelRequestParameters(
        function_tools=[function_tool],
        builtin_tools=[ImageGenerationTool()],
        output_tools=[output_tool],
    )
    settings = ModelSettings(max_tokens=256)

    request = ModelRequest(
        parts=[UserPromptPart('test prompt')],
        instructions='test instructions',
        model_request_parameters=params,
        model_settings=settings,
    )

    # The metadata is exposed for introspection on the request...
    assert request.model_request_parameters is params
    assert request.model_settings == settings
    recorded = request.model_request_parameters
    assert recorded is not None
    assert recorded.function_tools == [function_tool]
    assert recorded.builtin_tools == [ImageGenerationTool()]
    assert recorded.output_tools == [output_tool]

    # ...but both fields are excluded when the message history is dumped.
    dumped = ModelMessagesTypeAdapter.dump_python([request], mode='json')
    assert 'model_request_parameters' not in dumped[0]
    assert 'model_settings' not in dumped[0]
Loading