From d5ff4a1f68c33b701f3c7e26d32f579e13abc547 Mon Sep 17 00:00:00 2001 From: Mert <62549656+fswair@users.noreply.github.com> Date: Thu, 10 Jul 2025 18:51:46 +0300 Subject: [PATCH 01/89] Add `StructuredDict` for structured outputs with custom JSON schema (#2157) Co-authored-by: Douwe Maan --- docs/api/output.md | 1 + docs/output.md | 34 ++++++++++- pydantic_ai_slim/pydantic_ai/__init__.py | 3 +- pydantic_ai_slim/pydantic_ai/_output.py | 16 ++++-- pydantic_ai_slim/pydantic_ai/_utils.py | 7 ++- pydantic_ai_slim/pydantic_ai/output.py | 66 ++++++++++++++++++++- tests/test_agent.py | 73 +++++++++++++++++++++++- tests/test_examples.py | 4 ++ tests/typed_agent.py | 14 ++++- 9 files changed, 204 insertions(+), 14 deletions(-) diff --git a/docs/api/output.md b/docs/api/output.md index 43027fbc01..135ff597bc 100644 --- a/docs/api/output.md +++ b/docs/api/output.md @@ -9,3 +9,4 @@ - NativeOutput - PromptedOutput - TextOutput + - StructuredDict diff --git a/docs/output.md b/docs/output.md index f32e403d71..caa0c14b0f 100644 --- a/docs/output.md +++ b/docs/output.md @@ -31,7 +31,7 @@ _(This example is complete, it can be run "as is")_ ## Output data {#structured-output} -The [`Agent`][pydantic_ai.Agent] class constructor takes an `output_type` argument that takes one or more types or [output functions](#output-functions). It supports simple scalar types, list and dict types, dataclasses and Pydantic models, as well as type unions -- generally everything supported as type hints in a Pydantic model. You can also pass a list of multiple choices. +The [`Agent`][pydantic_ai.Agent] class constructor takes an `output_type` argument that takes one or more types or [output functions](#output-functions). It supports simple scalar types, list and dict types (including `TypedDict`s and [`StructuredDict`s](#structured-dict)), dataclasses and Pydantic models, as well as type unions -- generally everything supported as type hints in a Pydantic model. You can also pass a list of multiple choices. By default, Pydantic AI leverages the model's tool calling capability to make it return structured data. When multiple output types are specified (in a union or list), each member is registered with the model as a separate output tool in order to reduce the complexity of the schema and maximise the chances a model will respond correctly. This has been shown to work well across a wide range of models. If you'd like to change the names of the output tools, use a model's native structured output feature, or pass the output schema to the model in its [instructions](agents.md#instructions), you can use an [output mode](#output-modes) marker class. @@ -117,7 +117,6 @@ print(result.output) _(This example is complete, it can be run "as is")_ - ### Output functions Instead of plain text or structured data, you may want the output of your agent run to be the result of a function called with arguments provided by the model, for example to further process or validate the data provided through the arguments (with the option to tell the model to try again), or to hand off to another agent. 
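For the output-functions intro above, a minimal sketch of what passing a plain function as `output_type` looks like; the `normalize_name` function, its retry rule, and the prompt are illustrative assumptions, not part of this PR:

```python
from pydantic_ai import Agent, ModelRetry


def normalize_name(name: str) -> str:
    """Output function: the model supplies `name`; we post-process or reject it."""
    if not name.strip():
        # Telling the model to try again is done by raising ModelRetry.
        raise ModelRetry('Please provide a non-empty name.')
    return name.title()


agent = Agent('openai:gpt-4o', output_type=normalize_name)
result = agent.run_sync('Who created the Pydantic library?')
print(result.output)
```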
@@ -387,6 +386,37 @@ print(repr(result.output)) _(This example is complete, it can be run "as is")_ +### Custom JSON schema {#structured-dict} + +If it's not feasible to define your desired structured output object using a Pydantic `BaseModel`, dataclass, or `TypedDict`, for example when you get a JSON schema from an external source or generate it dynamically, you can use the [`StructuredDict()`][pydantic_ai.output.StructuredDict] helper function to generate a `dict[str, Any]` subclass with a JSON schema attached that Pydantic AI will pass to the model. + +Note that Pydantic AI will not perform any validation of the received JSON object and it's up to the model to correctly interpret the schema and any constraints expressed in it, like required fields or integer value ranges. + +The output type will be a `dict[str, Any]` and it's up to your code to defensively read from it in case the model made a mistake. You can use an [output validator](#output-validator-functions) to reflect validation errors back to the model and get it to try again. + +Along with the JSON schema, you can optionally pass `name` and `description` arguments to provide additional context to the model: + +```python +from pydantic_ai import Agent, StructuredDict + +HumanDict = StructuredDict( + { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + }, + name="Human", + description="A human with a name and age", +) + +agent = Agent('openai:gpt-4o', output_type=HumanDict) +result = agent.run_sync("Create a person") +#> {'name': 'John Doe', 'age': 30} +``` + ### Output validators {#output-validator-functions} Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. PydanticAI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator. 
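Since the custom JSON schema section above notes that `StructuredDict` output is not validated by Pydantic, here is a minimal sketch of pairing it with an output validator to reflect problems back to the model; the `HumanDict` schema and `check_human` validator are illustrative, not from this PR:

```python
from typing import Any

from pydantic_ai import Agent, ModelRetry, StructuredDict

HumanDict = StructuredDict(
    {
        'type': 'object',
        'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}},
        'required': ['name', 'age'],
    },
    name='Human',
    description='A human with a name and age',
)

agent = Agent('openai:gpt-4o', output_type=HumanDict)


@agent.output_validator
def check_human(output: dict[str, Any]) -> dict[str, Any]:
    # StructuredDict values arrive unvalidated, so check them defensively and
    # hand any problem back to the model as a retry.
    if not isinstance(output.get('age'), int) or output['age'] < 0:
        raise ModelRetry('`age` must be a non-negative integer.')
    return output
```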
diff --git a/pydantic_ai_slim/pydantic_ai/__init__.py b/pydantic_ai_slim/pydantic_ai/__init__.py index aa50774d0e..cd902ea97f 100644 --- a/pydantic_ai_slim/pydantic_ai/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/__init__.py @@ -12,7 +12,7 @@ ) from .format_prompt import format_as_xml from .messages import AudioUrl, BinaryContent, DocumentUrl, ImageUrl, VideoUrl -from .output import NativeOutput, PromptedOutput, TextOutput, ToolOutput +from .output import NativeOutput, PromptedOutput, StructuredDict, TextOutput, ToolOutput from .tools import RunContext, Tool __all__ = ( @@ -46,6 +46,7 @@ 'NativeOutput', 'PromptedOutput', 'TextOutput', + 'StructuredDict', # format_prompt 'format_as_xml', ) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 17f686f4b3..1922f03804 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -264,10 +264,16 @@ def _build_tools( output = output.output + description = description or default_description + if strict is None: + strict = default_strict + + processor = ObjectOutputProcessor(output=output, description=description, strict=strict) + if name is None: name = default_name if multiple: - name += f'_{output.__name__}' + name += f'_{processor.object_def.name}' i = 1 original_name = name @@ -275,11 +281,6 @@ def _build_tools( i += 1 name = f'{original_name}_{i}' - description = description or default_description - if strict is None: - strict = default_strict - - processor = ObjectOutputProcessor(output=output, description=description, strict=strict) tools[name] = OutputTool(name=name, processor=processor, multiple=multiple) return tools @@ -616,6 +617,9 @@ def __init__( # including `response_data_typed_dict` as a title here doesn't add anything and could confuse the LLM json_schema.pop('title') + if name is None and (json_schema_title := json_schema.get('title', None)): + name = json_schema_title + if json_schema_description := json_schema.pop('description', None): if description is None: description = json_schema_description diff --git a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index 8034756eca..d3f42a7ee9 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -60,7 +60,12 @@ def is_model_like(type_: Any) -> bool: return ( isinstance(type_, type) and not isinstance(type_, GenericAlias) - and (issubclass(type_, BaseModel) or is_dataclass(type_) or is_typeddict(type_)) # pyright: ignore[reportUnknownArgumentType] + and ( + issubclass(type_, BaseModel) + or is_dataclass(type_) # pyright: ignore[reportUnknownArgumentType] + or is_typeddict(type_) # pyright: ignore[reportUnknownArgumentType] + or getattr(type_, '__is_model_like__', False) # pyright: ignore[reportUnknownArgumentType] + ) ) diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 246823292a..9dc7d2ef6b 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -2,10 +2,14 @@ from collections.abc import Awaitable, Sequence from dataclasses import dataclass -from typing import Callable, Generic, Literal, Union +from typing import Any, Callable, Generic, Literal, Union +from pydantic import GetCoreSchemaHandler, GetJsonSchemaHandler +from pydantic.json_schema import JsonSchemaValue +from pydantic_core import core_schema from typing_extensions import TypeAliasType, TypeVar +from . 
import _utils from .tools import RunContext __all__ = ( @@ -14,6 +18,7 @@ 'NativeOutput', 'PromptedOutput', 'TextOutput', + 'StructuredDict', # types 'OutputDataT', 'OutputMode', @@ -266,6 +271,65 @@ def split_into_words(text: str) -> list[str]: """The function that will be called to process the model's plain text output. The function must take a single string argument.""" +def StructuredDict( + json_schema: JsonSchemaValue, name: str | None = None, description: str | None = None +) -> type[JsonSchemaValue]: + """Returns a `dict[str, Any]` subclass with a JSON schema attached that will be used for structured output. + + Args: + json_schema: A JSON schema of type `object` defining the structure of the dictionary content. + name: Optional name of the structured output. If not provided, the `title` field of the JSON schema will be used if it's present. + description: Optional description of the structured output. If not provided, the `description` field of the JSON schema will be used if it's present. + + Example: + ```python {title="structured_dict.py"} + from pydantic_ai import Agent, StructuredDict + + + schema = { + "type": "object", + "properties": { + "name": {"type": "string"}, + "age": {"type": "integer"} + }, + "required": ["name", "age"] + } + + agent = Agent('openai:gpt-4o', output_type=StructuredDict(schema)) + result = agent.run_sync("Create a person") + print(result.output) + #> {'name': 'John Doe', 'age': 30} + ``` + """ + json_schema = _utils.check_object_json_schema(json_schema) + + if name: + json_schema['title'] = name + + if description: + json_schema['description'] = description + + class _StructuredDict(JsonSchemaValue): + __is_model_like__ = True + + @classmethod + def __get_pydantic_core_schema__( + cls, source_type: Any, handler: GetCoreSchemaHandler + ) -> core_schema.CoreSchema: + return core_schema.dict_schema( + keys_schema=core_schema.str_schema(), + values_schema=core_schema.any_schema(), + ) + + @classmethod + def __get_pydantic_json_schema__( + cls, core_schema: core_schema.CoreSchema, handler: GetJsonSchemaHandler + ) -> JsonSchemaValue: + return json_schema + + return _StructuredDict + + OutputSpec = TypeAliasType( 'OutputSpec', Union[ diff --git a/tests/test_agent.py b/tests/test_agent.py index 9c7c363f6e..cc2985e198 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -41,7 +41,7 @@ ) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import ToolOutput +from pydantic_ai.output import StructuredDict, ToolOutput from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition @@ -1266,6 +1266,77 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: ) +def test_output_type_structured_dict(): + PersonDict = StructuredDict( + { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + 'age': {'type': 'integer'}, + }, + 'required': ['name', 'age'], + }, + name='Person', + description='A person', + ) + AnimalDict = StructuredDict( + { + 'type': 'object', + 'properties': { + 'name': {'type': 'string'}, + 'species': {'type': 'string'}, + }, + 'required': ['name', 'species'], + }, + name='Animal', + description='An animal', + ) + + output_tools = None + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + nonlocal output_tools + output_tools = info.output_tools + + args_json = '{"name": "John Doe", "age": 30}' + 
return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + agent = Agent( + FunctionModel(call_tool), + output_type=[PersonDict, AnimalDict], + ) + + result = agent.run_sync('Generate a person') + + assert result.output == snapshot({'name': 'John Doe', 'age': 30}) + assert output_tools == snapshot( + [ + ToolDefinition( + name='final_result_Person', + parameters_json_schema={ + 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}}, + 'required': ['name', 'age'], + 'title': 'Person', + 'type': 'object', + }, + description='A person', + ), + ToolDefinition( + name='final_result_Animal', + parameters_json_schema={ + 'properties': {'name': {'type': 'string'}, 'species': {'type': 'string'}}, + 'required': ['name', 'species'], + 'title': 'Animal', + 'type': 'object', + }, + description='An animal', + ), + ] + ) + + def test_default_structured_output_mode(): def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: return ModelResponse(parts=[TextPart(content='hello')]) # pragma: no cover diff --git a/tests/test_examples.py b/tests/test_examples.py index b9e232fbc6..c5c274c145 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -444,6 +444,10 @@ async def list_tools() -> list[None]: 'What is a Ford Explorer?': '{"result": {"kind": "Vehicle", "data": {"name": "Ford Explorer", "wheels": 4}}}', 'What is a MacBook?': '{"result": {"kind": "Device", "data": {"name": "MacBook", "kind": "laptop"}}}', 'Write a creative story about space exploration': 'In the year 2157, Captain Maya Chen piloted her spacecraft through the vast expanse of the Andromeda Galaxy. As she discovered a planet with crystalline mountains that sang in harmony with the cosmic winds, she realized that space exploration was not just about finding new worlds, but about finding new ways to understand the universe and our place within it.', + 'Create a person': ToolCallPart( + tool_name='final_result', + args={'name': 'John Doe', 'age': 30}, + ), } tool_responses: dict[tuple[str, str], str] = { diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 941bbf9877..6ea4c4c223 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -4,13 +4,13 @@ from collections.abc import Awaitable from dataclasses import dataclass from decimal import Decimal -from typing import Callable, TypeAlias, Union +from typing import Any, Callable, TypeAlias, Union from typing_extensions import assert_type from pydantic_ai import Agent, ModelRetry, RunContext, Tool from pydantic_ai.agent import AgentRunResult -from pydantic_ai.output import TextOutput, ToolOutput +from pydantic_ai.output import StructuredDict, TextOutput, ToolOutput from pydantic_ai.tools import ToolDefinition # Define here so we can check `if MYPY` below. 
This will not be executed, MYPY will always set it to True @@ -170,6 +170,16 @@ def run_sync3() -> None: union_agent2: Agent[None, MyUnion] = Agent(output_type=MyUnion) # type: ignore[call-overload] assert_type(union_agent2, Agent[None, MyUnion]) +structured_dict = StructuredDict( + { + 'type': 'object', + 'properties': {'name': {'type': 'string'}, 'age': {'type': 'integer'}}, + 'required': ['name', 'age'], + } +) +structured_dict_agent = Agent(output_type=structured_dict) +assert_type(structured_dict_agent, Agent[None, dict[str, Any]]) + def foobar_ctx(ctx: RunContext[int], x: str, y: int) -> Decimal: return Decimal(x) + y From ecbd08bd45c7de4ffc9ed07f6b3ba96fc6e13125 Mon Sep 17 00:00:00 2001 From: Andrey Golovizin Date: Thu, 10 Jul 2025 18:05:50 +0000 Subject: [PATCH 02/89] Fix type annotations for `Agent.iter()` (#2168) --- pydantic_ai_slim/pydantic_ai/agent.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 0686090f11..9c87fee517 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -500,7 +500,7 @@ async def main(): @overload def iter( self, - user_prompt: str | Sequence[_messages.UserContent] | None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: None = None, message_history: list[_messages.ModelMessage] | None = None, @@ -516,7 +516,7 @@ def iter( @overload def iter( self, - user_prompt: str | Sequence[_messages.UserContent] | None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, output_type: OutputSpec[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, @@ -533,7 +533,7 @@ def iter( @deprecated('`result_type` is deprecated, use `output_type` instead.') def iter( self, - user_prompt: str | Sequence[_messages.UserContent] | None, + user_prompt: str | Sequence[_messages.UserContent] | None = None, *, result_type: type[RunOutputDataT], message_history: list[_messages.ModelMessage] | None = None, From b14ce919fe391de671864873a0ed70e5287b5534 Mon Sep 17 00:00:00 2001 From: Itay Date: Thu, 10 Jul 2025 21:06:21 +0300 Subject: [PATCH 03/89] Fix chat-app example doc - python code appear twice (#2169) --- docs/examples/chat-app.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/examples/chat-app.md b/docs/examples/chat-app.md index ff200ec6a3..b72cc32527 100644 --- a/docs/examples/chat-app.md +++ b/docs/examples/chat-app.md @@ -33,7 +33,7 @@ Python code that runs the chat app: Simple HTML page to render the app: -```snippet {path="/examples/pydantic_ai_examples/chat_app.py"}``` +```snippet {path="/examples/pydantic_ai_examples/chat_app.html"}``` TypeScript to handle rendering the messages, to keep this simple (and at the risk of offending frontend developers) the typescript code is passed to the browser as plain text and transpiled in the browser. 
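The `Agent.iter()` overload fix in PATCH 02/89 above gives `user_prompt` a `None` default; a minimal sketch of the call pattern that unblocks (the `continue_run` helper and `previous_messages` argument are illustrative assumptions):

```python
from pydantic_ai import Agent
from pydantic_ai.messages import ModelMessage

agent = Agent('openai:gpt-4o')


async def continue_run(previous_messages: list[ModelMessage]) -> None:
    # `user_prompt` can now be omitted, e.g. when the latest prompt is
    # already the last entry in `message_history`.
    async with agent.iter(message_history=previous_messages) as agent_run:
        async for node in agent_run:
            print(node)
```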
From 6d0e850ea14e419f3383c7f0444f3d7c6dfd38df Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 10 Jul 2025 11:17:26 -0700 Subject: [PATCH 04/89] fasta2a as `known-third-party` (#2176) --- pyproject.toml | 12 +++++++----- 1 file changed, 7 insertions(+), 5 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index d122c82a38..162448cc56 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -125,7 +125,6 @@ extend-select = [ "TID251", ] flake8-quotes = { inline-quotes = "single", multiline-quotes = "double" } -isort = { combine-as-imports = true, known-first-party = ["pydantic_ai"] } mccabe = { max-complexity = 15 } ignore = [ "D100", # ignore missing docstring in module @@ -135,6 +134,12 @@ ignore = [ "D107", # ignore missing docstring in __init__ methods ] +[tool.ruff.lint.isort] +combine-as-imports = true +known-first-party = ["pydantic_ai"] +# weird issue with ruff thinking fasta2a is still editable +known-third-party = ["fasta2a"] + [tool.ruff.lint.pydocstyle] convention = "google" @@ -187,10 +192,7 @@ files = "tests/typed_agent.py,tests/typed_graph.py" strict = true [tool.pytest.ini_options] -testpaths = [ - "tests", - "docs/.hooks" -] +testpaths = ["tests", "docs/.hooks"] xfail_strict = true filterwarnings = [ "error", From 78f08f851cf993d4b5c3b0169b408ec6c7c60077 Mon Sep 17 00:00:00 2001 From: Thiago Padilha Date: Thu, 10 Jul 2025 18:17:35 +0000 Subject: [PATCH 05/89] Handle DeepSeek reasoning_content in streamed responses (#2174) --- pydantic_ai_slim/pydantic_ai/models/openai.py | 6 + .../test_deepseek_model_thinking_stream.yaml | 258 ++++++++++++++++++ tests/models/test_deepseek.py | 44 ++- 3 files changed, 307 insertions(+), 1 deletion(-) create mode 100644 tests/models/cassettes/test_deepseek/test_deepseek_model_thinking_stream.yaml diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index c95b889ed4..795f015cb5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -994,6 +994,12 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if content is not None: yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) + # Handle reasoning part of the response, present in DeepSeek models + if reasoning_content := getattr(choice.delta, 'reasoning_content', None): + yield self._parts_manager.handle_thinking_delta( + vendor_part_id='reasoning_content', content=reasoning_content + ) + for dtc in choice.delta.tool_calls or []: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=dtc.index, diff --git a/tests/models/cassettes/test_deepseek/test_deepseek_model_thinking_stream.yaml b/tests/models/cassettes/test_deepseek/test_deepseek_model_thinking_stream.yaml new file mode 100644 index 0000000000..3e06bef37a --- /dev/null +++ b/tests/models/cassettes/test_deepseek/test_deepseek_model_thinking_stream.yaml @@ -0,0 +1,258 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '130' + content-type: + - application/json + host: + - api.deepseek.com + method: POST + parsed_body: + messages: + - content: Hello + role: user + model: deepseek-reasoner + stream: true + stream_options: + include_usage: true + uri: https://api.deepseek.com/chat/completions + response: + body: + string: "data: 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"role\":\"assistant\",\"content\":null,\"reasoning_content\":\"\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"H\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"mm\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\",\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + the\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + user\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + just\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + said\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + \\\"\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"Hello\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"\\\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + It\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'s\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + a\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + simple\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + greeting\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + but\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + I\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
wonder\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + if\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + there\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'s\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + more\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + to\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + it\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + \ \\n\\n\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"The\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + message\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + is\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + very\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + brief\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\",\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + so\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + I\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + don\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'t\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
have\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + much\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + context\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + to\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + work\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + with\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + Maybe\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + they\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'re\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + just\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + testing\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + if\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + I\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'m\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + responsive\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\",\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + or\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + perhaps\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
they\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'re\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + new\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + to\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + chatting\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + with\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + AI\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + \ \\n\\n\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"I\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + should\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + keep\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + my\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + reply\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + warm\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + and\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + inviting\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + to\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + encourage\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
further\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + conversation\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + A\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + smile\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"y\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + face\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + would\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + help\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + make\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + it\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + friendly\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + Since\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + they\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + didn\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'t\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + specify\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + a\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
need\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\",\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + I\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'ll\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + leave\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + it\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + open\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"-ended\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + by\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + asking\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + how\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + I\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + can\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + help\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + them\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + today\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + \ \\n\\n\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"The\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
tone\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + should\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + be\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + cheerful\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + but\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + professional\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + -\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + not\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + too\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + stiff\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\",\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + not\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + too\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + casual\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + \\\"\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"Hello\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + there\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"!\\\"\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
feels\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + right\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + for\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + a\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + start\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + Adding\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + \\\"\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"What\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + can\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + I\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + do\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + for\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + you\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + today\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"?\\\"\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + turns\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + it\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + into\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
an\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + invitation\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + rather\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + than\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + just\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + mirror\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"ing\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + their\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + greeting\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + \ \\n\\n\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"I\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'ll\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + avoid\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + assumptions\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + about\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + their\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + gender\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\",\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
location\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\",\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + or\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + intent\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + since\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + there\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'s\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + zero\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + information\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + If\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + they\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'re\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + just\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + being\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + polite\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\",\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + they\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + might\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + 
not\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + reply\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + further\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + -\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + and\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + that\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\"'s\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + okay\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\" + too\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":null,\"reasoning_content\":\".\"},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + 
{\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"Hello\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + there\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"!\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + \U0001F60A\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + How\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + can\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + I\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + help\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + you\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\" + 
today\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"?\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":null}],\"usage\":null}\n\ndata: + {\"id\":\"33be18fc-3842-486c-8c29-dd8e578f7f20\",\"object\":\"chat.completion.chunk\",\"created\":1752169304,\"model\":\"deepseek-reasoner\",\"system_fingerprint\":\"fp_393bca965e_prod0623_fp8_kvcache\",\"choices\":[{\"index\":0,\"delta\":{\"content\":\"\",\"reasoning_content\":null},\"logprobs\":null,\"finish_reason\":\"stop\"}],\"usage\":{\"prompt_tokens\":6,\"completion_tokens\":212,\"total_tokens\":218,\"prompt_tokens_details\":{\"cached_tokens\":0},\"completion_tokens_details\":{\"reasoning_tokens\":198},\"prompt_cache_hit_tokens\":0,\"prompt_cache_miss_tokens\":6}}\n\ndata: + [DONE]\n\n" + headers: + access-control-allow-credentials: + - 'true' + cache-control: + - no-cache + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + vary: + - origin, access-control-request-method, access-control-request-headers + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_deepseek.py b/tests/models/test_deepseek.py index 480031cce2..64da89d097 100644 --- a/tests/models/test_deepseek.py +++ b/tests/models/test_deepseek.py @@ -1,10 +1,24 @@ from __future__ import annotations as _annotations +from typing import Any + import pytest +from dirty_equals import IsListOrTuple from inline_snapshot import snapshot from pydantic_ai import Agent -from pydantic_ai.messages import ModelRequest, ModelResponse, TextPart, ThinkingPart, UserPromptPart +from pydantic_ai.messages import ( + FinalResultEvent, + ModelRequest, + ModelResponse, + PartDeltaEvent, + PartStartEvent, + TextPart, + TextPartDelta, + ThinkingPart, + ThinkingPartDelta, + UserPromptPart, +) from pydantic_ai.usage import Usage from ..conftest import IsDatetime, IsStr, try_import @@ -48,3 +62,31 @@ async def test_deepseek_model_thinking_part(allow_model_requests: None, deepseek ), ] ) + + +async def test_deepseek_model_thinking_stream(allow_model_requests: None, deepseek_api_key: str): + deepseek_model = OpenAIModel('deepseek-reasoner', provider=DeepSeekProvider(api_key=deepseek_api_key)) + agent = Agent(model=deepseek_model) + + event_parts: list[Any] = [] + async with agent.iter(user_prompt='Hello') as agent_run: + async for node in agent_run: + if Agent.is_model_request_node(node) or Agent.is_call_tools_node(node): + async with node.stream(agent_run.ctx) as request_stream: + async for event in request_stream: + event_parts.append(event) + + assert event_parts == snapshot( + IsListOrTuple( + positions={ + 0: PartStartEvent(index=0, part=ThinkingPart(content='H')), + 1: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta='mm')), + 2: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=',')), + 198: PartStartEvent(index=1, part=TextPart(content='Hello')), + 199: FinalResultEvent(tool_name=None, tool_call_id=None), + 200: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' there')), + 201: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='!')), + }, + length=211, + ) + ) From 
9b1a89bc10c04e77aa1618d062de019d80f37003 Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Thu, 10 Jul 2025 21:23:16 +0100 Subject: [PATCH 06/89] chore: skip Gemini tests in local runs (#2181) --- docs/graph.md | 2 +- docs/models/google.md | 2 +- tests/test_examples.py | 4 +++- 3 files changed, 5 insertions(+), 3 deletions(-) diff --git a/docs/graph.md b/docs/graph.md index ee465ed73b..a7f35ac403 100644 --- a/docs/graph.md +++ b/docs/graph.md @@ -352,7 +352,7 @@ stateDiagram-v2 Feedback --> [*] ``` -```python {title="genai_email_feedback.py" py="3.10"} +```python {title="genai_email_feedback.py" py="3.10" test="ci_only"} from __future__ import annotations as _annotations from dataclasses import dataclass, field diff --git a/docs/models/google.md b/docs/models/google.md index c7de50f930..4cebbc779b 100644 --- a/docs/models/google.md +++ b/docs/models/google.md @@ -58,7 +58,7 @@ To use Vertex AI, you may need to set up [application default credentials](https If you have the [`gcloud` CLI](https://cloud.google.com/sdk/gcloud) installed and configured, you can use: -```python +```python {test="ci_only"} from pydantic_ai import Agent from pydantic_ai.models.google import GoogleModel from pydantic_ai.providers.google import GoogleProvider diff --git a/tests/test_examples.py b/tests/test_examples.py index c5c274c145..c1dda22a49 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -205,7 +205,9 @@ def print(self, *args: Any, **kwargs: Any) -> None: eval_example.lint_ruff(example) if opt_test.startswith('skip'): - print(opt_test[4:].lstrip(' -') or 'running code skipped') + pytest.skip(opt_test[4:].lstrip(' -') or 'running code skipped') + elif opt_test.startswith('ci_only') and os.environ.get('GITHUB_ACTIONS', '').lower() != 'true': + pytest.skip(opt_test[7:].lstrip(' -') or 'running code skipped in local tests') # pragma: no cover else: test_globals: dict[str, str] = {'__name__': dunder_name} From a5f4b11ce88d001b4b5670b4ed0897c34a8b70d5 Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Thu, 10 Jul 2025 22:03:35 +0100 Subject: [PATCH 07/89] chore: support pytest-cov and local reporting (#2179) --- pyproject.toml | 14 ++++++++++++++ 1 file changed, 14 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index 162448cc56..9ba1763adb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -224,6 +224,20 @@ omit = [ "pydantic_ai_slim/pydantic_ai/ext/aci.py", # aci-sdk requires Python 3.10+ so cannot be added as an (optional) dependency ] branch = true +# Disable include-ignored warnings as --source is enabled automatically, causing a self conflict as per: +# https://github.com/pytest-dev/pytest-cov/issues/532 +# https://github.com/pytest-dev/pytest-cov/issues/369 +# This prevents coverage being generated by pytest-cov, which has direct editor support in VS Code, +# making it super useful to check coverage while writing tests. +disable_warnings = ["include-ignored"] + +[tool.coverage.paths] +# Allow CI run assets to be downloaded and replicated locally.
+source = [ + ".", + "/home/runner/work/pydantic-ai/pydantic-ai", + "/System/Volumes/Data/home/runner/work/pydantic-ai/pydantic-ai" +] # https://coverage.readthedocs.io/en/latest/config.html#report [tool.coverage.report] From 484022f5c02c3f79aaf9a282fea778c3d58c1dec Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Thu, 10 Jul 2025 22:14:59 +0100 Subject: [PATCH 08/89] chore: disable python 3.9 coverage (#2178) --- .github/workflows/ci.yml | 1 + tests/test_direct.py | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 30c8c45846..ca6620e03d 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -204,6 +204,7 @@ jobs: enable-cache: true - run: uv sync --package pydantic-ai-slim --only-dev + - run: rm coverage/.coverage.*-py3.9-* # Exclude 3.9 coverage as it gets the wrong line numbers, causing invalid failures. - run: uv run coverage combine coverage - run: uv run coverage html --show-contexts --title "PydanticAI coverage for ${{ github.sha }}" diff --git a/tests/test_direct.py b/tests/test_direct.py index baf58e3345..e9a131ea33 100644 --- a/tests/test_direct.py +++ b/tests/test_direct.py @@ -143,7 +143,7 @@ def test_model_request_stream_sync_without_context_manager(): with pytest.raises(RuntimeError, match=expected_error_msg): for _ in stream_cm: - break + break # pragma: no cover def test_model_request_stream_sync_exception_in_stream(): From 2b7899b6af5fb420a7988941df52a6f4f65788c8 Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Fri, 11 Jul 2025 09:34:06 +0200 Subject: [PATCH 09/89] docs: remove note about using GeminiModel instead of GoogleModel (#2184) --- docs/models/google.md | 8 -------- 1 file changed, 8 deletions(-) diff --git a/docs/models/google.md b/docs/models/google.md index 4cebbc779b..2cc35d9c0a 100644 --- a/docs/models/google.md +++ b/docs/models/google.md @@ -13,14 +13,6 @@ pip/uv-add "pydantic-ai-slim[google]" --- -!!! warning "Explicit instantiation required" - You **cannot** currently use `Agent('google-gla:gemini-1.5-flash')` or `Agent('google-vertex:gemini-1.5-flash')` directly with `GoogleModel`. The model name inference will select [`GeminiModel`](../models/gemini.md) instead of `GoogleModel`. - - To use `GoogleModel`, you **must** explicitly instantiate a [`GoogleProvider`][pydantic_ai.providers.google.GoogleProvider] and pass it to - [`GoogleModel`][pydantic_ai.models.google.GoogleModel], then pass the model to [`Agent`][pydantic_ai.Agent]. - ---- - ## Configuration `GoogleModel` lets you use Google's Gemini models through their [Generative Language API](https://ai.google.dev/api/all-methods) (`generativelanguage.googleapis.com`) or [Vertex AI API](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/models) (`*-aiplatform.googleapis.com`). 
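The `[tool.coverage.paths]` block added above is what lets `.coverage.*` data files produced on the GitHub Actions runners be combined and reported against a local checkout; it is also what the `coverage combine coverage` and `coverage html` CI steps shown in the workflow rely on. A minimal sketch of the local workflow this enables is below. It is not part of any patch in this series; it assumes the CI coverage artifacts have been downloaded into a local `coverage/` directory and that coverage.py picks up its configuration from `pyproject.toml`.

```python
# Minimal sketch: combine downloaded CI coverage data locally and build an HTML report.
# Assumes the .coverage.* files live in ./coverage/ and that [tool.coverage.paths]
# maps the GitHub runner paths back to this checkout.
import coverage

cov = coverage.Coverage()  # reads [tool.coverage.*] settings from pyproject.toml
cov.combine(['coverage'])  # merges the data files, applying the path aliases above
cov.html_report(title='PydanticAI coverage (local)')  # writes the htmlcov/ report
```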
From 01b6b2c9b3d1a29384135137f6c411e1f069a9c6 Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Sun, 13 Jul 2025 19:58:08 -0700 Subject: [PATCH 10/89] Remove unneeded `pragma: lax no cover` (#2177) --- pydantic_ai_slim/pydantic_ai/_output.py | 10 +++++----- pydantic_ai_slim/pydantic_ai/exceptions.py | 4 ++-- pydantic_ai_slim/pydantic_ai/models/anthropic.py | 2 +- pydantic_ai_slim/pydantic_ai/models/bedrock.py | 2 +- pydantic_ai_slim/pydantic_ai/models/cohere.py | 2 +- pydantic_ai_slim/pydantic_ai/models/gemini.py | 4 ++-- pydantic_ai_slim/pydantic_ai/models/google.py | 2 +- pydantic_ai_slim/pydantic_ai/models/groq.py | 2 +- pydantic_ai_slim/pydantic_ai/models/instrumented.py | 2 +- pydantic_ai_slim/pydantic_ai/models/mistral.py | 4 ++-- pydantic_ai_slim/pydantic_ai/models/openai.py | 4 ++-- pydantic_ai_slim/pydantic_ai/providers/google.py | 4 ++-- .../pydantic_ai/providers/google_vertex.py | 2 +- pydantic_ai_slim/pydantic_ai/result.py | 2 +- 14 files changed, 23 insertions(+), 23 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index 1922f03804..bd882bd6d0 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -664,7 +664,7 @@ async def process( ) raise ToolRetryError(m) from e else: - raise # pragma: lax no cover + raise if k := self.outer_typed_dict_key: output = output[k] @@ -679,7 +679,7 @@ async def process( ) raise ToolRetryError(m) from r else: - raise # pragma: lax no cover + raise return output @@ -849,7 +849,7 @@ async def process( ) raise ToolRetryError(m) from r else: - raise # pragma: lax no cover + raise # pragma: no cover return cast(OutputDataT, output) @@ -908,7 +908,7 @@ async def process( ) raise ToolRetryError(m) from e else: - raise # pragma: lax no cover + raise # pragma: no cover except ModelRetry as r: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -918,7 +918,7 @@ async def process( ) raise ToolRetryError(m) from r else: - raise # pragma: lax no cover + raise # pragma: no cover else: return output diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 0783478258..3f57faaf8d 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -4,9 +4,9 @@ import sys if sys.version_info < (3, 11): - from exceptiongroup import ExceptionGroup # pragma: lax no cover + from exceptiongroup import ExceptionGroup else: - ExceptionGroup = ExceptionGroup # pragma: lax no cover + ExceptionGroup = ExceptionGroup __all__ = ( 'ModelRetry', diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 57e1524c14..6aac2931af 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -256,7 +256,7 @@ async def _messages_create( except APIStatusError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e - raise # pragma: lax no cover + raise # pragma: no cover def _process_response(self, response: BetaMessage) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index 8b4ceaec84..f16f9d1119 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -663,4 +663,4 @@ async def 
__anext__(self) -> T: if type(e.__cause__) is StopIteration: raise StopAsyncIteration else: - raise e # pragma: lax no cover + raise e # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/models/cohere.py b/pydantic_ai_slim/pydantic_ai/models/cohere.py index b8acce375e..e59c228f24 100644 --- a/pydantic_ai_slim/pydantic_ai/models/cohere.py +++ b/pydantic_ai_slim/pydantic_ai/models/cohere.py @@ -183,7 +183,7 @@ async def _chat( except ApiError as e: if (status_code := e.status_code) and status_code >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e - raise # pragma: lax no cover + raise # pragma: no cover def _process_response(self, response: ChatResponse) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index e9b1add322..99aa99a301 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -253,7 +253,7 @@ async def _make_request( if gemini_labels := model_settings.get('gemini_labels'): if self._system == 'google-vertex': - request_data['labels'] = gemini_labels # pragma: lax no cover + request_data['labels'] = gemini_labels headers = {'Content-Type': 'application/json', 'User-Agent': get_user_agent()} url = f'/{self._model_name}:{"streamGenerateContent" if streamed else "generateContent"}' @@ -415,7 +415,7 @@ def _settings_to_generation_config(model_settings: GeminiModelSettings) -> _Gemi if (frequency_penalty := model_settings.get('frequency_penalty')) is not None: config['frequency_penalty'] = frequency_penalty if (thinkingConfig := model_settings.get('gemini_thinking_config')) is not None: - config['thinking_config'] = thinkingConfig # pragma: lax no cover + config['thinking_config'] = thinkingConfig return config diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index b7fe331de6..3755cc16e8 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -166,7 +166,7 @@ def __init__( self._model_name = model_name if isinstance(provider, str): - provider = GoogleProvider(vertexai=provider == 'google-vertex') # pragma: lax no cover + provider = GoogleProvider(vertexai=provider == 'google-vertex') self._provider = provider self._system = provider.name diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index be110771c6..bfdb1d3792 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -248,7 +248,7 @@ async def _completions_create( except APIStatusError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e - raise # pragma: lax no cover + raise # pragma: no cover def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index 0bc2c13418..f40340998b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -138,7 +138,7 @@ def __init__( **tokens_histogram_kwargs, explicit_bucket_boundaries_advisory=TOKEN_HISTOGRAM_BOUNDARIES, ) - except TypeError: # pragma: lax no 
cover + except TypeError: # Older OTel/logfire versions don't support explicit_bucket_boundaries_advisory self.tokens_histogram = self.meter.create_histogram( **tokens_histogram_kwargs, # pyright: ignore diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 4a29c0b7d5..05b90e3142 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -75,7 +75,7 @@ from mistralai.models.usermessage import UserMessage as MistralUserMessage from mistralai.types.basemodel import Unset as MistralUnset from mistralai.utils.eventstreaming import EventStreamAsync as MistralEventStreamAsync -except ImportError as e: # pragma: lax no cover +except ImportError as e: # pragma: no cover raise ImportError( 'Please install `mistral` to use the Mistral model, ' 'you can use the `mistral` optional group — `pip install "pydantic-ai-slim[mistral]"`' @@ -217,7 +217,7 @@ async def _completions_create( except SDKError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e - raise # pragma: lax no cover + raise # pragma: no cover assert response, 'A unexpected empty response from Mistral.' return response diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 795f015cb5..b968ac61fd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -345,7 +345,7 @@ async def _completions_create( except APIStatusError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e - raise # pragma: lax no cover + raise # pragma: no cover def _process_response(self, response: chat.ChatCompletion) -> ModelResponse: """Process a non-streamed response, and prepare a message to return.""" @@ -781,7 +781,7 @@ async def _responses_create( except APIStatusError as e: if (status_code := e.status_code) >= 400: raise ModelHTTPError(status_code=status_code, model_name=self.model_name, body=e.body) from e - raise # pragma: lax no cover + raise # pragma: no cover def _get_reasoning(self, model_settings: OpenAIResponsesModelSettings) -> Reasoning | NotGiven: reasoning_effort = model_settings.get('openai_reasoning_effort', None) diff --git a/pydantic_ai_slim/pydantic_ai/providers/google.py b/pydantic_ai_slim/pydantic_ai/providers/google.py index bd49b2b350..fc876fcff4 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/google.py +++ b/pydantic_ai_slim/pydantic_ai/providers/google.py @@ -86,7 +86,7 @@ def __init__( # NOTE: We are keeping GEMINI_API_KEY for backwards compatibility. 
api_key = api_key or os.getenv('GOOGLE_API_KEY') or os.getenv('GEMINI_API_KEY') - if vertexai is None: # pragma: lax no cover + if vertexai is None: vertexai = bool(location or project or credentials) if not vertexai: @@ -114,7 +114,7 @@ def __init__( http_options={'headers': {'User-Agent': get_user_agent()}}, ) else: - self._client = client # pragma: lax no cover + self._client = client VertexAILocation = Literal[ diff --git a/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py b/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py index ec436d7643..cc09eb5c03 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py +++ b/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py @@ -50,7 +50,7 @@ def client(self) -> httpx.AsyncClient: return self._client def model_profile(self, model_name: str) -> ModelProfile | None: - return google_model_profile(model_name) # pragma: lax no cover + return google_model_profile(model_name) @overload def __init__( diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 40b59f36a5..0b5c04fa84 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -320,7 +320,7 @@ async def stream(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[Outp yield await self.validate_structured_output(structured_message, allow_partial=not is_last) except ValidationError: if is_last: - raise # pragma: lax no cover + raise # pragma: no cover async def stream_text(self, *, delta: bool = False, debounce_by: float | None = 0.1) -> AsyncIterator[str]: """Stream the text result as an async iterable. From 83c8e01e3d490cd7f77cb5b0e2cb3833f83b12e8 Mon Sep 17 00:00:00 2001 From: Aditya Vardhan <76904033+adtyavrdhn@users.noreply.github.com> Date: Tue, 15 Jul 2025 14:56:56 +0530 Subject: [PATCH 11/89] Fixes for excluding content (#2180) --- pydantic_ai_slim/pydantic_ai/messages.py | 6 +- tests/models/test_instrumented.py | 9 +- tests/test_logfire.py | 108 +++++++++++++++++++++++ 3 files changed, 119 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index f31dc1c93f..90816c87f0 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -411,9 +411,9 @@ class UserPromptPart: """Part type identifier, this is available on all parts as a discriminator.""" def otel_event(self, settings: InstrumentationSettings) -> Event: - content: str | list[dict[str, Any] | str] + content: str | list[dict[str, Any] | str] | dict[str, Any] if isinstance(self.content, str): - content = self.content + content = self.content if settings.include_content else {'kind': 'text'} else: content = [] for part in self.content: @@ -743,7 +743,7 @@ def new_event_body(): 'type': 'function', # TODO https://github.com/pydantic/pydantic-ai/issues/888 'function': { 'name': part.tool_name, - 'arguments': part.args, + **({'arguments': part.args} if settings.include_content else {}), }, } ) diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index c0befafeff..b952bf7166 100644 --- a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -863,6 +863,7 @@ def test_messages_without_content(document_content: BinaryContent): ModelRequest(parts=[ToolReturnPart('tool', 'tool_return_content', 'tool_call_1')]), ModelRequest(parts=[RetryPromptPart('retry_prompt', tool_name='tool', tool_call_id='tool_call_2')]), 
ModelRequest(parts=[UserPromptPart(content=['user_prompt2', document_content])]), + ModelRequest(parts=[UserPromptPart('simple text prompt')]), ] settings = InstrumentationSettings(include_content=False) assert [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(messages)] == snapshot( @@ -896,7 +897,7 @@ def test_messages_without_content(document_content: BinaryContent): { 'id': IsStr(), 'type': 'function', - 'function': {'name': 'my_tool', 'arguments': {'a': 13, 'b': 4}}, + 'function': {'name': 'my_tool'}, } ], 'gen_ai.message.index': 3, @@ -922,5 +923,11 @@ def test_messages_without_content(document_content: BinaryContent): 'gen_ai.message.index': 6, 'event.name': 'gen_ai.user.message', }, + { + 'content': {'kind': 'text'}, + 'role': 'user', + 'gen_ai.message.index': 7, + 'event.name': 'gen_ai.user.message', + }, ] ) diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 5815b71870..691b85d9b1 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -415,6 +415,114 @@ class MyOutput: ) +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +def test_instructions_with_structured_output_exclude_content(get_logfire_summary: Callable[[], LogfireSummary]) -> None: + @dataclass + class MyOutput: + content: str + + settings: InstrumentationSettings = InstrumentationSettings(include_content=False) + + my_agent = Agent(model=TestModel(), instructions='Here are some instructions', instrument=settings) + + result = my_agent.run_sync('Hello', output_type=MyOutput) + assert result.output == snapshot(MyOutput(content='a')) + + summary = get_logfire_summary() + assert summary.attributes[0] == snapshot( + { + 'model_name': 'test', + 'agent_name': 'my_agent', + 'logfire.msg': 'my_agent run', + 'logfire.span_type': 'span', + 'gen_ai.usage.input_tokens': 51, + 'gen_ai.usage.output_tokens': 5, + 'all_messages_events': IsJson( + snapshot( + [ + { + 'content': 'Here are some instructions', + 'role': 'system', + 'event.name': 'gen_ai.system.message', + }, + { + 'content': {'kind': 'text'}, + 'role': 'user', + 'gen_ai.message.index': 0, + 'event.name': 'gen_ai.user.message', + }, + { + 'role': 'assistant', + 'tool_calls': [ + { + 'id': IsStr(), + 'type': 'function', + 'function': {'name': 'final_result'}, + } + ], + 'gen_ai.message.index': 1, + 'event.name': 'gen_ai.assistant.message', + }, + { + 'role': 'tool', + 'id': IsStr(), + 'name': 'final_result', + 'gen_ai.message.index': 2, + 'event.name': 'gen_ai.tool.message', + }, + ] + ) + ), + 'final_result': '{"content": "a"}', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': {'all_messages_events': {'type': 'array'}, 'final_result': {'type': 'object'}}, + } + ) + ), + } + ) + chat_span_attributes = summary.attributes[1] + assert chat_span_attributes['events'] == snapshot( + IsJson( + snapshot( + [ + { + 'content': 'Here are some instructions', + 'role': 'system', + 'gen_ai.system': 'test', + 'event.name': 'gen_ai.system.message', + }, + { + 'event.name': 'gen_ai.user.message', + 'content': {'kind': 'text'}, + 'role': 'user', + 'gen_ai.message.index': 0, + 'gen_ai.system': 'test', + }, + { + 'event.name': 'gen_ai.choice', + 'index': 0, + 'message': { + 'role': 'assistant', + 'tool_calls': [ + { + 'id': IsStr(), + 'type': 'function', + 'function': {'name': 'final_result'}, + } + ], + }, + 'gen_ai.system': 'test', + }, + ] + ) + ) + ) + + def test_instrument_all(): model = TestModel() agent = Agent() From 6179a1fe578ee87dffe2359c238c5e2f096b02c3 Mon Sep 17 
00:00:00 2001 From: Mahmoud Mabrouk Date: Tue, 15 Jul 2025 12:23:56 +0200 Subject: [PATCH 12/89] docs: Add Agenta integration documentation (#2192) --- docs/logfire.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/logfire.md b/docs/logfire.md index e8bd12cb58..f0f4a76926 100644 --- a/docs/logfire.md +++ b/docs/logfire.md @@ -255,6 +255,7 @@ The following providers have dedicated documentation on Pydantic AI: - [Patronus AI](https://docs.patronus.ai/docs/percival/pydantic) - [Opik](https://www.comet.com/docs/opik/tracing/integrations/pydantic-ai) - [mlflow](https://mlflow.org/docs/latest/genai/tracing/integrations/listing/pydantic_ai) +- [Agenta](https://docs.agenta.ai/observability/integrations/pydanticai) ## Advanced usage From c325b7c9d80784d467b7b75207d864ba9e11fad7 Mon Sep 17 00:00:00 2001 From: Victorien <65306057+Viicos@users.noreply.github.com> Date: Tue, 15 Jul 2025 12:25:34 +0200 Subject: [PATCH 13/89] Mention agents instead of models in MCP servers documentation (#2160) --- docs/mcp/server.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/mcp/server.md b/docs/mcp/server.md index 04c03e0114..9c0fda72f4 100644 --- a/docs/mcp/server.md +++ b/docs/mcp/server.md @@ -1,6 +1,6 @@ # Server -PydanticAI models can also be used within MCP Servers. +PydanticAI agents can also be used within MCP Servers. ## MCP Server From 11307484e2e5dbc804e66e2990ca4919439fbb27 Mon Sep 17 00:00:00 2001 From: Bas Nijholt Date: Tue, 15 Jul 2025 03:41:04 -0700 Subject: [PATCH 14/89] `duckduckgo-search` is renamed to `ddgs` (#2172) Co-authored-by: Marcelo Trylesinski --- docs/common-tools.md | 2 +- docs/install.md | 2 +- docs/tools.md | 2 +- .../pydantic_ai/common_tools/duckduckgo.py | 7 ++- pydantic_ai_slim/pyproject.toml | 3 +- uv.lock | 46 +++++++++++++------ 6 files changed, 41 insertions(+), 21 deletions(-) diff --git a/docs/common-tools.md b/docs/common-tools.md index 72e97f93e8..3cb4196184 100644 --- a/docs/common-tools.md +++ b/docs/common-tools.md @@ -5,7 +5,7 @@ PydanticAI ships with native tools that can be used to enhance your agent's capa ## DuckDuckGo Search Tool The DuckDuckGo search tool allows you to search the web for information. It is built on top of the -[DuckDuckGo API](https://github.com/deedy5/duckduckgo_search). +[DuckDuckGo API](https://github.com/deedy5/ddgs). ### Installation diff --git a/docs/install.md b/docs/install.md index 6d621ada5f..9b4b473559 100644 --- a/docs/install.md +++ b/docs/install.md @@ -54,7 +54,7 @@ pip/uv-add "pydantic-ai-slim[openai]" * `groq` — installs `groq` [PyPI ↗](https://pypi.org/project/groq){:target="_blank"} * `mistral` — installs `mistralai` [PyPI ↗](https://pypi.org/project/mistralai){:target="_blank"} * `cohere` - installs `cohere` [PyPI ↗](https://pypi.org/project/cohere){:target="_blank"} -* `duckduckgo` - installs `duckduckgo-search` [PyPI ↗](https://pypi.org/project/duckduckgo-search){:target="_blank"} +* `duckduckgo` - installs `ddgs` [PyPI ↗](https://pypi.org/project/ddgs){:target="_blank"} * `tavily` - installs `tavily-python` [PyPI ↗](https://pypi.org/project/tavily-python){:target="_blank"} See the [models](models/index.md) documentation for information on which optional dependencies are required for each model. 
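For context, the switch from `duckduckgo-search` to `ddgs` only changes the installed dependency; the tool's public entry point is unchanged. A minimal sketch of registering the search tool on an agent (the model name and query are illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.common_tools.duckduckgo import duckduckgo_search_tool

# Requires the `duckduckgo` optional group, which now pulls in `ddgs`:
#   pip install "pydantic-ai-slim[duckduckgo]"
agent = Agent(
    'openai:gpt-4o',
    tools=[duckduckgo_search_tool()],
    system_prompt='Search DuckDuckGo for the given query and return the results.',
)

result = agent.run_sync('What are the latest releases of Pydantic AI?')
print(result.output)
```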
diff --git a/docs/tools.md b/docs/tools.md index 740a2174d9..44133f5759 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -732,7 +732,7 @@ If you'd like to use a tool from LangChain's [community tool library](https://py You will need to install the `langchain-community` package and any others required by the tool in question. -Here is how you can use the LangChain `DuckDuckGoSearchRun` tool, which requires the `duckduckgo-search` package: +Here is how you can use the LangChain `DuckDuckGoSearchRun` tool, which requires the `ddgs` package: ```python {test="skip"} from langchain_community.tools import DuckDuckGoSearchRun diff --git a/pydantic_ai_slim/pydantic_ai/common_tools/duckduckgo.py b/pydantic_ai_slim/pydantic_ai/common_tools/duckduckgo.py index 5541710afa..6e5e6b9ec0 100644 --- a/pydantic_ai_slim/pydantic_ai/common_tools/duckduckgo.py +++ b/pydantic_ai_slim/pydantic_ai/common_tools/duckduckgo.py @@ -9,10 +9,13 @@ from pydantic_ai.tools import Tool try: - from duckduckgo_search import DDGS + try: + from ddgs import DDGS + except ImportError: # Fallback for older versions of ddgs + from duckduckgo_search import DDGS except ImportError as _import_error: raise ImportError( - 'Please install `duckduckgo-search` to use the DuckDuckGo search tool, ' + 'Please install `ddgs` to use the DuckDuckGo search tool, ' 'you can use the `duckduckgo` optional group — `pip install "pydantic-ai-slim[duckduckgo]"`' ) from _import_error diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index c923e375b2..6371e0d7c4 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -70,7 +70,7 @@ groq = ["groq>=0.19.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.37.24"] # Tools -duckduckgo = ["duckduckgo-search>=7.0.0"] +duckduckgo = ["ddgs>=9.0.0"] tavily = ["tavily-python>=0.5.0"] # CLI cli = ["rich>=13", "prompt-toolkit>=3", "argcomplete>=3.5.0"] @@ -88,6 +88,7 @@ dev = [ "devtools>=0.12.2", "coverage[toml]>=7.6.2", "dirty-equals>=0.9.0", + "duckduckgo-search>=7.0.0", "inline-snapshot>=0.19.3", "pytest>=8.3.3", "pytest-examples>=0.0.14", diff --git a/uv.lock b/uv.lock index f4ab182ef9..1c540c3f97 100644 --- a/uv.lock +++ b/uv.lock @@ -827,6 +827,20 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/9d/3a/e39436efe51894243ff145a37c4f9a030839b97779ebcc4f13b3ba21c54e/cssselect2-0.7.0-py3-none-any.whl", hash = "sha256:fd23a65bfd444595913f02fc71f6b286c29261e354c41d722ca7a261a49b5969", size = 15586, upload-time = "2022-09-19T12:55:07.56Z" }, ] +[[package]] +name = "ddgs" +version = "9.0.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "lxml" }, + { name = "primp" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/1f/08/0e84549a1d7d5950573f73d7bc5d36f2a00f92ad8e644b59066afd430a6f/ddgs-9.0.0.tar.gz", hash = "sha256:53b47c74a8060457cb02cbb64acdf59655d799ce8e0934e945bcd878fcab3a7f", size = 21795, upload-time = "2025-07-06T15:43:50.862Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/e5/05/bd3ed9a28212b313f5678533152c4d79fbc386e44245ca5eed426d75f019/ddgs-9.0.0-py3-none-any.whl", hash = "sha256:5dd11d666d6caf1cfdbd341579637bb670c4b2f41df5413b76705519d8e7a22c", size = 17944, upload-time = "2025-07-06T15:43:49.564Z" }, +] + [[package]] name = "defusedxml" version = "0.7.1" @@ -928,16 +942,16 @@ wheels = [ [[package]] name = "duckduckgo-search" -version = "7.5.0" +version = "8.1.1" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, 
{ name = "lxml" }, { name = "primp" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/17/a8/18404f6525aefa80290afa920ed76fbab16472f19015fdb957b7113f3a9e/duckduckgo_search-7.5.0.tar.gz", hash = "sha256:3e28dc5ec9188efa3a7c8532aa05aaf03bb34b79370855760abd55e6051ff79b", size = 24657, upload-time = "2025-02-24T14:50:49.356Z" } +sdist = { url = "https://files.pythonhosted.org/packages/10/ef/07791a05751e6cc9de1dd49fb12730259ee109b18e6d097e25e6c32d5617/duckduckgo_search-8.1.1.tar.gz", hash = "sha256:9da91c9eb26a17e016ea1da26235d40404b46b0565ea86d75a9f78cc9441f935", size = 22868, upload-time = "2025-07-06T15:30:59.73Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/75/21/fc2c821a2c92c021f8f8adf9fb36235d1b49525b7cd953e85624296aab94/duckduckgo_search-7.5.0-py3-none-any.whl", hash = "sha256:6a2d3f12ae29b3e076cd43be61f5f73cd95261e0a0f318fe0ad3648d7a5dff03", size = 20238, upload-time = "2025-02-24T14:50:48.179Z" }, + { url = "https://files.pythonhosted.org/packages/db/72/c027b3b488b1010cf71670032fcf7e681d44b81829d484bb04e31a949a8d/duckduckgo_search-8.1.1-py3-none-any.whl", hash = "sha256:f48adbb06626ee05918f7e0cef3a45639e9939805c4fc179e68c48a12f1b5062", size = 18932, upload-time = "2025-07-06T15:30:58.339Z" }, ] [[package]] @@ -2757,18 +2771,18 @@ wheels = [ [[package]] name = "primp" -version = "0.14.0" +version = "0.15.0" source = { registry = "https://pypi.org/simple" } -sdist = { url = "https://files.pythonhosted.org/packages/95/1e/a063129aed2320b463fd35c5d918d5754e59011698aaf7cf297a610b3380/primp-0.14.0.tar.gz", hash = "sha256:b6f23b2b694118a9d0443b3760698b90afb6f867f8447e71972530f48297992e", size = 112406, upload-time = "2025-02-23T21:36:46.489Z" } +sdist = { url = "https://files.pythonhosted.org/packages/56/0b/a87556189da4de1fc6360ca1aa05e8335509633f836cdd06dd17f0743300/primp-0.15.0.tar.gz", hash = "sha256:1af8ea4b15f57571ff7fc5e282a82c5eb69bc695e19b8ddeeda324397965b30a", size = 113022, upload-time = "2025-04-17T11:41:05.315Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/7f/12/eba13ddbeb5c6df6cf7511aedb5fa4bcb99c0754e88056260dd44aa53929/primp-0.14.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:bd2dfb57feeba21a77a1128b6c6f17856605c4e73edcc05764fb134de4ff014f", size = 3173837, upload-time = "2025-02-23T21:36:40.891Z" }, - { url = "https://files.pythonhosted.org/packages/77/65/3cd25b4f4d0cd9de4f1d95858dcddd7ed082587524294c179c847de18951/primp-0.14.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:31eecb5316f9bd732a7994530b85eb698bf6500d2f6c5c3382dac0353f77084e", size = 2947192, upload-time = "2025-02-23T21:36:38.595Z" }, - { url = "https://files.pythonhosted.org/packages/13/77/f85bc3e31befa9b9bac54bab61beb34ff84a70d20f02b7dcd8abc120120a/primp-0.14.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:11229e65aa5755fdfb535cc03fd64259a06764ad7c22e650fb3bea51400f1d09", size = 3276730, upload-time = "2025-02-23T21:36:36.292Z" }, - { url = "https://files.pythonhosted.org/packages/44/36/bc95049264ee668a5cdaadf77ef711aaa9cb0c4c0a246b27bba9a2f0114c/primp-0.14.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:8f56ca2cd63f9ac75b33bf48129b7e79ade29cf280bc253b17b052afb27d2b9e", size = 3247684, upload-time = "2025-02-23T21:36:31.649Z" }, - { url = "https://files.pythonhosted.org/packages/31/d9/632a70c80dcdd0bb9293cdc7e7543d35e5912325631c3e9f3b7c7d842941/primp-0.14.0-cp38-abi3-manylinux_2_34_armv7l.whl", hash = "sha256:3fb204f67a4b58dc53f3452143121317b474437812662ac0149d332a77ecbe1a", size = 3007835, upload-time = 
"2025-02-23T21:36:34.05Z" }, - { url = "https://files.pythonhosted.org/packages/dc/ba/07b04b9d404f20ec78449c5974c988a5adf7d4d245a605466486f70d35c3/primp-0.14.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:b0b21e6a599f580137774623009c7f895afab49d6c3d6c9a28344fd2586ebe8a", size = 3413956, upload-time = "2025-02-23T21:36:43.288Z" }, - { url = "https://files.pythonhosted.org/packages/d7/d3/3bee499b4594fce1f8ccede785e517162407fbea1d452c4fb55fe3fb5e81/primp-0.14.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:6549766ece3c7be19e1c16fa9029d3e50fa73628149d88601fcd964af8b44a8d", size = 3595850, upload-time = "2025-02-23T21:36:45.064Z" }, - { url = "https://files.pythonhosted.org/packages/6a/20/042c8ae21d185f2efe61780dfbc01464c982f59626b746d5436c2e4c1e08/primp-0.14.0-cp38-abi3-win_amd64.whl", hash = "sha256:d3ae1ba954ec8d07abb527ccce7bb36633525c86496950ba0178e44a0ea5c891", size = 3143077, upload-time = "2025-02-23T21:36:48.12Z" }, + { url = "https://files.pythonhosted.org/packages/f5/5a/146ac964b99ea7657ad67eb66f770be6577dfe9200cb28f9a95baffd6c3f/primp-0.15.0-cp38-abi3-macosx_10_12_x86_64.whl", hash = "sha256:1b281f4ca41a0c6612d4c6e68b96e28acfe786d226a427cd944baa8d7acd644f", size = 3178914, upload-time = "2025-04-17T11:40:59.558Z" }, + { url = "https://files.pythonhosted.org/packages/bc/8a/cc2321e32db3ce64d6e32950d5bcbea01861db97bfb20b5394affc45b387/primp-0.15.0-cp38-abi3-macosx_11_0_arm64.whl", hash = "sha256:489cbab55cd793ceb8f90bb7423c6ea64ebb53208ffcf7a044138e3c66d77299", size = 2955079, upload-time = "2025-04-17T11:40:57.398Z" }, + { url = "https://files.pythonhosted.org/packages/c3/7b/cbd5d999a07ff2a21465975d4eb477ae6f69765e8fe8c9087dab250180d8/primp-0.15.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:c18b45c23f94016215f62d2334552224236217aaeb716871ce0e4dcfa08eb161", size = 3281018, upload-time = "2025-04-17T11:40:55.308Z" }, + { url = "https://files.pythonhosted.org/packages/1b/6e/a6221c612e61303aec2bcac3f0a02e8b67aee8c0db7bdc174aeb8010f975/primp-0.15.0-cp38-abi3-manylinux_2_34_aarch64.whl", hash = "sha256:e985a9cba2e3f96a323722e5440aa9eccaac3178e74b884778e926b5249df080", size = 3255229, upload-time = "2025-04-17T11:40:47.811Z" }, + { url = "https://files.pythonhosted.org/packages/3b/54/bfeef5aca613dc660a69d0760a26c6b8747d8fdb5a7f20cb2cee53c9862f/primp-0.15.0-cp38-abi3-manylinux_2_34_armv7l.whl", hash = "sha256:6b84a6ffa083e34668ff0037221d399c24d939b5629cd38223af860de9e17a83", size = 3014522, upload-time = "2025-04-17T11:40:50.191Z" }, + { url = "https://files.pythonhosted.org/packages/ac/96/84078e09f16a1dad208f2fe0f8a81be2cf36e024675b0f9eec0c2f6e2182/primp-0.15.0-cp38-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:592f6079646bdf5abbbfc3b0a28dac8de943f8907a250ce09398cda5eaebd260", size = 3418567, upload-time = "2025-04-17T11:41:01.595Z" }, + { url = "https://files.pythonhosted.org/packages/6c/80/8a7a9587d3eb85be3d0b64319f2f690c90eb7953e3f73a9ddd9e46c8dc42/primp-0.15.0-cp38-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:5a728e5a05f37db6189eb413d22c78bd143fa59dd6a8a26dacd43332b3971fe8", size = 3606279, upload-time = "2025-04-17T11:41:03.61Z" }, + { url = "https://files.pythonhosted.org/packages/0c/dd/f0183ed0145e58cf9d286c1b2c14f63ccee987a4ff79ac85acc31b5d86bd/primp-0.15.0-cp38-abi3-win_amd64.whl", hash = "sha256:aeb6bd20b06dfc92cfe4436939c18de88a58c640752cf7f30d9e4ae893cdec32", size = 3149967, upload-time = "2025-04-17T11:41:07.067Z" }, ] [[package]] @@ -3082,7 +3096,7 @@ cohere = [ { name = "cohere", marker = "sys_platform != 'emscripten'" }, ] 
duckduckgo = [ - { name = "duckduckgo-search" }, + { name = "ddgs" }, ] evals = [ { name = "pydantic-evals" }, @@ -3123,6 +3137,7 @@ dev = [ { name = "diff-cover", version = "9.2.0", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version < '3.9.17'" }, { name = "diff-cover", version = "9.2.3", source = { registry = "https://pypi.org/simple" }, marker = "python_full_version >= '3.9.17'" }, { name = "dirty-equals" }, + { name = "duckduckgo-search" }, { name = "inline-snapshot" }, { name = "pytest" }, { name = "pytest-examples" }, @@ -3139,7 +3154,7 @@ requires-dist = [ { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.37.24" }, { name = "cohere", marker = "sys_platform != 'emscripten' and extra == 'cohere'", specifier = ">=5.13.11" }, - { name = "duckduckgo-search", marker = "extra == 'duckduckgo'", specifier = ">=7.0.0" }, + { name = "ddgs", marker = "extra == 'duckduckgo'", specifier = ">=9.0.0" }, { name = "eval-type-backport", specifier = ">=0.2.0" }, { name = "exceptiongroup", marker = "python_full_version < '3.11'" }, { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, @@ -3173,6 +3188,7 @@ dev = [ { name = "devtools", specifier = ">=0.12.2" }, { name = "diff-cover", specifier = ">=9.2.0" }, { name = "dirty-equals", specifier = ">=0.9.0" }, + { name = "duckduckgo-search", specifier = ">=7.0.0" }, { name = "inline-snapshot", specifier = ">=0.19.3" }, { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-examples", specifier = ">=0.0.14" }, From 41640aa493a43b308ebe0dad33a70fd88b6f5eb3 Mon Sep 17 00:00:00 2001 From: Aditya Vardhan <76904033+adtyavrdhn@users.noreply.github.com> Date: Tue, 15 Jul 2025 16:22:02 +0530 Subject: [PATCH 15/89] Add base64 encoding to `tool_return_ta` (#2186) --- pydantic_ai_slim/pydantic_ai/messages.py | 4 ++- tests/test_agent.py | 41 ++++++++++++++++++++++++ 2 files changed, 44 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 90816c87f0..e61f0e9a11 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -433,7 +433,9 @@ def otel_event(self, settings: InstrumentationSettings) -> Event: __repr__ = _utils.dataclasses_no_defaults_repr -tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter(Any, config=pydantic.ConfigDict(defer_build=True)) +tool_return_ta: pydantic.TypeAdapter[Any] = pydantic.TypeAdapter( + Any, config=pydantic.ConfigDict(defer_build=True, ser_json_bytes='base64', val_json_bytes='base64') +) @dataclass(repr=False) diff --git a/tests/test_agent.py b/tests/test_agent.py index cc2985e198..47a9c01a12 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2934,6 +2934,47 @@ def test_binary_content_all_messages_json(): assert messages == result.all_messages() +def test_tool_return_part_binary_content_serialization(): + """Test that ToolReturnPart can properly serialize BinaryContent.""" + png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82' + binary_content = BinaryContent(png_data, media_type='image/png') + + tool_return = ToolReturnPart(tool_name='test_tool', content=binary_content, tool_call_id='test_call_123') + + response_str = tool_return.model_response_str() + + assert '"kind":"binary"' in response_str + assert 
'"media_type":"image/png"' in response_str + assert '"data":"' in response_str + + response_obj = tool_return.model_response_object() + assert response_obj['return_value']['kind'] == 'binary' + assert response_obj['return_value']['media_type'] == 'image/png' + assert 'data' in response_obj['return_value'] + + +def test_tool_returning_binary_content_directly(): + """Test that a tool returning BinaryContent directly works correctly.""" + + def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart('get_image', {})]) + else: + return ModelResponse(parts=[TextPart('Image received')]) + + agent = Agent(FunctionModel(llm)) + + @agent.tool_plain + def get_image() -> BinaryContent: + """Return a simple image.""" + png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82' + return BinaryContent(png_data, media_type='image/png') + + # This should work without the serialization error + result = agent.run_sync('Get an image') + assert result.output == 'Image received' + + def test_instructions_raise_error_when_system_prompt_is_set(): agent = Agent('test', instructions='An instructions!') From 46aa248729506dc44db4daa31494cb7105b6a264 Mon Sep 17 00:00:00 2001 From: Jad Haddad Date: Tue, 15 Jul 2025 12:54:05 +0200 Subject: [PATCH 16/89] Bugfix: avoid race condition when refreshing google token (#2100) Co-authored-by: Marcelo Trylesinski --- .../pydantic_ai/providers/google_vertex.py | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py b/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py index cc09eb5c03..77f3cc1a84 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py +++ b/pydantic_ai_slim/pydantic_ai/providers/google_vertex.py @@ -116,6 +116,8 @@ def __init__( class _VertexAIAuth(httpx.Auth): """Auth class for Vertex AI API.""" + _refresh_lock: anyio.Lock = anyio.Lock() + credentials: BaseCredentials | ServiceAccountCredentials | None def __init__( @@ -169,10 +171,13 @@ async def _get_credentials(self) -> BaseCredentials | ServiceAccountCredentials: return creds async def _refresh_token(self) -> str: # pragma: no cover - assert self.credentials is not None - await anyio.to_thread.run_sync(self.credentials.refresh, Request()) # type: ignore[reportUnknownMemberType] - assert isinstance(self.credentials.token, str), f'Expected token to be a string, got {self.credentials.token}' # type: ignore[reportUnknownMemberType] - return self.credentials.token + async with self._refresh_lock: + assert self.credentials is not None + await anyio.to_thread.run_sync(self.credentials.refresh, Request()) # type: ignore[reportUnknownMemberType] + assert isinstance(self.credentials.token, str), ( # type: ignore[reportUnknownMemberType] + f'Expected token to be a string, got {self.credentials.token}' # type: ignore[reportUnknownMemberType] + ) + return self.credentials.token async def _async_google_auth() -> tuple[BaseCredentials, str | None]: From c94cc036ced68a9518157141076f35b8e50e98b5 Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Wed, 16 Jul 2025 12:12:36 +0100 Subject: [PATCH 17/89] fix: a2a docs dependency (#2216) --- pyproject.toml | 1 + uv.lock | 2 +- 2 files changed, 2 insertions(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index 9ba1763adb..18c7853c70 100644 --- 
a/pyproject.toml +++ b/pyproject.toml @@ -64,6 +64,7 @@ Changelog = "https://github.com/pydantic/pydantic-ai/releases" pai = "pydantic_ai._cli:cli_exit" # TODO remove this when clai has been out for a while [tool.uv.sources] +pydantic-ai = { workspace = true } pydantic-ai-slim = { workspace = true } pydantic-evals = { workspace = true } pydantic-graph = { workspace = true } diff --git a/uv.lock b/uv.lock index 1c540c3f97..2b2c5c2c3b 100644 --- a/uv.lock +++ b/uv.lock @@ -3017,7 +3017,7 @@ docs = [ { name = "mkdocs-llmstxt", specifier = ">=0.2.0" }, { name = "mkdocs-material", extras = ["imaging"], specifier = ">=9.5.45" }, { name = "mkdocstrings-python", specifier = ">=1.12.2" }, - { name = "pydantic-ai", extras = ["a2a"] }, + { name = "pydantic-ai", extras = ["a2a"], editable = "." }, ] docs-upload = [ { name = "algoliasearch", specifier = ">=4.12.0" }, From f9f1c03028381ea562c67e7bcf07001e4583df8b Mon Sep 17 00:00:00 2001 From: Nahian-Al Hasan Date: Wed, 16 Jul 2025 21:51:14 +1000 Subject: [PATCH 18/89] feat: Add output function tracing (#2191) Co-authored-by: Alex Hall --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 7 +- pydantic_ai_slim/pydantic_ai/_output.py | 124 +++- pydantic_ai_slim/pydantic_ai/agent.py | 1 + pydantic_ai_slim/pydantic_ai/result.py | 19 +- tests/test_logfire.py | 716 +++++++++++++++++++ 5 files changed, 852 insertions(+), 15 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 4515d18bc9..fda19acda4 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -341,6 +341,7 @@ async def stream( ctx.deps.output_schema, ctx.deps.output_validators, build_run_context(ctx), + _output.build_trace_context(ctx), ctx.deps.usage_limits, ) yield agent_stream @@ -529,7 +530,8 @@ async def _handle_tool_calls( if isinstance(output_schema, _output.ToolOutputSchema): for call, output_tool in output_schema.find_tool(tool_calls): try: - result_data = await output_tool.process(call, run_context) + trace_context = _output.build_trace_context(ctx) + result_data = await output_tool.process(call, run_context, trace_context) result_data = await _validate_output(result_data, ctx, call) except _output.ToolRetryError as e: # TODO: Should only increment retry stuff once per node execution, not for each tool call @@ -586,7 +588,8 @@ async def _handle_text_response( try: if isinstance(output_schema, _output.TextOutputSchema): run_context = build_run_context(ctx) - result_data = await output_schema.process(text, run_context) + trace_context = _output.build_trace_context(ctx) + result_data = await output_schema.process(text, run_context, trace_context) else: m = _messages.RetryPromptPart( content='Plain text responses are not permitted, please include your response in a tool call', diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index bd882bd6d0..c3199dd95c 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,5 +1,6 @@ from __future__ import annotations as _annotations +import dataclasses import inspect import json from abc import ABC, abstractmethod @@ -7,10 +8,13 @@ from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload +from opentelemetry.trace import Tracer from pydantic import TypeAdapter, ValidationError from pydantic_core import SchemaValidator from typing_extensions import TypedDict, TypeVar, 
assert_never +from pydantic_graph.nodes import GraphRunContext + from . import _function_schema, _utils, messages as _messages from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, UserError @@ -29,6 +33,8 @@ from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition if TYPE_CHECKING: + from pydantic_ai._agent_graph import DepsT, GraphAgentDeps, GraphAgentState + from .profiles import ModelProfile T = TypeVar('T') @@ -66,6 +72,71 @@ DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' +@dataclass(frozen=True) +class TraceContext: + """A context for tracing output processing.""" + + tracer: Tracer + include_content: bool + call: _messages.ToolCallPart | None = None + + def with_call(self, call: _messages.ToolCallPart): + return dataclasses.replace(self, call=call) + + async def execute_function_with_span( + self, + function_schema: _function_schema.FunctionSchema, + run_context: RunContext[AgentDepsT], + args: dict[str, Any] | Any, + call: _messages.ToolCallPart, + include_tool_call_id: bool = True, + ) -> Any: + """Execute a function call within a traced span, automatically recording the response.""" + # Set up span attributes + attributes = { + 'gen_ai.tool.name': call.tool_name, + 'logfire.msg': f'running output function: {call.tool_name}', + } + if include_tool_call_id: + attributes['gen_ai.tool.call.id'] = call.tool_call_id + if self.include_content: + attributes['tool_arguments'] = call.args_as_json_str() + attributes['logfire.json_schema'] = json.dumps( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + + # Execute function within span + with self.tracer.start_as_current_span('running output function', attributes=attributes) as span: + output = await function_schema.call(args, run_context) + + # Record response if content inclusion is enabled + if self.include_content and span.is_recording(): + from .models.instrumented import InstrumentedModel + + span.set_attribute( + 'tool_response', + output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)), + ) + + return output + + +def build_trace_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> TraceContext: + """Build a `TraceContext` from the current agent graph run context.""" + return TraceContext( + tracer=ctx.deps.tracer, + include_content=( + ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content + ), + ) + + class ToolRetryError(Exception): """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" @@ -96,6 +167,7 @@ async def validate( result: The result data after Pydantic validation the message content. tool_call: The original tool call message, `None` if there was no tool call. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. Returns: Result of either the validated result data (ok) or a retry message (Err). 
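Taken together, `TraceContext` and `execute_function_with_span` mean every output function call is wrapped in a `running output function` span, with `tool_arguments`/`tool_response` recorded only when content inclusion is enabled. A minimal sketch of how this surfaces at the agent level, using `TestModel` so no real provider is needed (whether spans are actually exported depends on your OpenTelemetry/Logfire configuration):

```python
from pydantic import BaseModel

from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.models.test import TestModel


class WeatherInfo(BaseModel):
    temperature: float
    description: str


def get_weather_info(city: str) -> WeatherInfo:
    # Output function: the model's tool-call arguments are passed in and the
    # return value becomes the run output.
    return WeatherInfo(temperature=28.7, description='sunny')


# With include_content=False the span still carries gen_ai.tool.name and the
# tool call id, but not tool_arguments / tool_response.
settings = InstrumentationSettings(include_content=False)
agent = Agent(TestModel(), instrument=settings)

result = agent.run_sync('Weather in Mexico City?', output_type=get_weather_info)
print(result.output)
```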
@@ -349,6 +421,7 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -371,6 +444,7 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -379,6 +453,7 @@ async def process( Args: text: The output text to validate. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -389,7 +464,7 @@ async def process( return cast(OutputDataT, text) return await self.processor.process( - text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -417,6 +492,7 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -425,6 +501,7 @@ async def process( Args: text: The output text to validate. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -432,7 +509,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ return await self.processor.process( - text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -468,6 +545,7 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -476,6 +554,7 @@ async def process( Args: text: The output text to validate. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -485,7 +564,7 @@ async def process( text = _utils.strip_markdown_fences(text) return await self.processor.process( - text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -568,6 +647,7 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -637,6 +717,7 @@ async def process( self, data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -645,6 +726,7 @@ async def process( Args: data: The output data to validate. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. 
wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -670,8 +752,18 @@ async def process( output = output[k] if self._function_schema: + # Wraps the output function call in an OpenTelemetry span. + if trace_context.call: + call = trace_context.call + include_tool_call_id = True + else: + function_name = getattr(self._function_schema.function, '__name__', 'output_function') + call = _messages.ToolCallPart(tool_name=function_name, args=data) + include_tool_call_id = False try: - output = await self._function_schema.call(output, run_context) + output = await trace_context.execute_function_with_span( + self._function_schema, run_context, output, call, include_tool_call_id + ) except ModelRetry as r: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -784,11 +876,12 @@ async def process( self, data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: union_object = await self._union_processor.process( - data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) result = union_object.result @@ -804,7 +897,7 @@ async def process( raise return await processor.process( - data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -835,13 +928,20 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: args = {self._str_argument_name: data} - + # Wraps the output function call in an OpenTelemetry span. + # Note: PlainTextOutputProcessor is used for text responses (not tool calls), + # so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id + function_name = getattr(self._function_schema.function, '__name__', 'text_output_function') + call = _messages.ToolCallPart(tool_name=function_name, args=args) try: - output = await self._function_schema.call(args, run_context) + output = await trace_context.execute_function_with_span( + self._function_schema, run_context, args, call, include_tool_call_id=False + ) except ModelRetry as r: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -881,6 +981,7 @@ async def process( self, tool_call: _messages.ToolCallPart, run_context: RunContext[AgentDepsT], + trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -889,6 +990,7 @@ async def process( Args: tool_call: The tool call from the LLM to validate. run_context: The current run context. + trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. 
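The plain-text path gets the same treatment: when a `TextOutput` function processes the model's text response, `PlainTextOutputProcessor` now wraps the call in a span as well, just without a `gen_ai.tool.call.id` since there is no tool call. A rough sketch, again using `TestModel` purely for illustration:

```python
from pydantic_ai import Agent
from pydantic_ai.models.instrumented import InstrumentationSettings
from pydantic_ai.models.test import TestModel
from pydantic_ai.output import TextOutput


def upcase_text(text: str) -> str:
    # Called with the model's plain-text response; the span is named
    # 'running output function: upcase_text' and has no tool call id.
    return text.upper()


# include_content=True also records tool_arguments and tool_response on the span.
settings = InstrumentationSettings(include_content=True)
agent = Agent(TestModel(custom_output_text='hello world'), instrument=settings)

result = agent.run_sync('Say hello', output_type=TextOutput(upcase_text))
print(result.output)
#> HELLO WORLD
```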
@@ -897,7 +999,11 @@ async def process( """ try: output = await self.processor.process( - tool_call.args, run_context, allow_partial=allow_partial, wrap_validation_errors=False + tool_call.args, + run_context, + trace_context.with_call(tool_call), + allow_partial=allow_partial, + wrap_validation_errors=False, ) except ValidationError as e: if wrap_validation_errors: diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 9c87fee517..3ff881294c 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -1089,6 +1089,7 @@ async def on_complete() -> None: streamed_response, graph_ctx.deps.output_schema, _agent_graph.build_run_context(graph_ctx), + _output.build_trace_context(graph_ctx), graph_ctx.deps.output_validators, final_result_details.tool_name, on_complete, diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 0b5c04fa84..f700482662 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -19,6 +19,7 @@ PlainTextOutputSchema, TextOutputSchema, ToolOutputSchema, + TraceContext, ) from ._run_context import AgentDepsT, RunContext from .messages import AgentStreamEvent, FinalResultEvent @@ -46,6 +47,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _output_schema: OutputSchema[OutputDataT] _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] + _trace_ctx: TraceContext _usage_limits: UsageLimits | None _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) @@ -105,13 +107,17 @@ async def _validate_response( call, output_tool = match result_data = await output_tool.process( - call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + call, + self._run_ctx, + self._trace_ctx, + allow_partial=allow_partial, + wrap_validation_errors=False, ) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( - text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover @@ -177,6 +183,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _stream_response: models.StreamedResponse _output_schema: OutputSchema[OutputDataT] _run_ctx: RunContext[AgentDepsT] + _trace_ctx: TraceContext _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] @@ -423,13 +430,17 @@ async def validate_structured_output( call, output_tool = match result_data = await output_tool.process( - call, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + call, + self._run_ctx, + self._trace_ctx, + allow_partial=allow_partial, + wrap_validation_errors=False, ) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( - text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False + text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover diff --git a/tests/test_logfire.py 
b/tests/test_logfire.py index 691b85d9b1..97ba871cc9 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -6,12 +6,17 @@ import pytest from dirty_equals import IsInt, IsJson, IsList from inline_snapshot import snapshot +from pydantic import BaseModel from typing_extensions import NotRequired, TypedDict from pydantic_ai import Agent from pydantic_ai._utils import get_traceparent +from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart +from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel from pydantic_ai.models.test import TestModel +from pydantic_ai.output import PromptedOutput, TextOutput +from pydantic_ai.tools import RunContext from .conftest import IsStr @@ -705,3 +710,714 @@ async def add_numbers(x: int, y: int) -> int: 'logfire.span_type': 'span', } ) + + +class WeatherInfo(BaseModel): + temperature: float + description: str + + +def get_weather_info(city: str) -> WeatherInfo: + return WeatherInfo(temperature=28.7, description='sunny') + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_function_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=get_weather_info) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_function_with_run_context_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def get_weather_with_ctx(ctx: RunContext[None], city: str) -> WeatherInfo: + assert ctx is not None + return WeatherInfo(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": 
"Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=get_weather_with_ctx) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_function_with_retry_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def get_weather_with_retry(city: str) -> WeatherInfo: + if city != 'Mexico City': + from pydantic_ai import ModelRetry + + raise ModelRetry('City not found, I only know Mexico City') + return WeatherInfo(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + args_json = '{"city": "New York City"}' + else: + args_json = '{"city": "Mexico City"}' + + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('New York City', output_type=get_weather_with_retry) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + output_function_attributes = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + [ + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "New York City"}', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + }, + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.json_schema': IsJson( + 
snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + }, + ] + ) + else: + assert output_function_attributes == snapshot( + [ + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + }, + { + 'gen_ai.tool.name': 'final_result', + 'logfire.msg': 'running output function: final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.span_type': 'span', + }, + ] + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_function_with_custom_tool_name_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + from pydantic_ai.output import ToolOutput + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=ToolOutput(get_weather_info, name='get_weather')) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes with custom tool name + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'get_weather' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'get_weather', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: get_weather', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'get_weather', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: get_weather', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_bound_instance_method_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(self, city: str): + return self + + weather = Weather(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), 
instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=weather.get_weather) + assert result.output == Weather(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_bound_instance_method_with_run_context_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + class Weather(BaseModel): + temperature: float + description: str + + def get_weather(self, ctx: RunContext[None], city: str): + assert ctx is not None + return self + + weather = Weather(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=weather.get_weather) + assert result.output == Weather(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_async_function_logfire_attributes( + get_logfire_summary: Callable[[], 
LogfireSummary], + include_content: bool, +) -> None: + """Test logfire attributes for async output function types.""" + + async def get_weather_async(city: str) -> WeatherInfo: + return WeatherInfo(temperature=28.7, description='sunny') + + def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + args_json = '{"city": "Mexico City"}' + return ModelResponse(parts=[ToolCallPart(info.output_tools[0].name, args_json)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('Mexico City', output_type=get_weather_async) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'final_result' + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"city": "Mexico City"}', + 'logfire.msg': 'running output function: final_result', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'final_result', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running output function: final_result', + 'logfire.span_type': 'span', + } + ) + + +def upcase_text(text: str) -> str: + """Convert text to uppercase.""" + return text.upper() + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_text_output_function_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + """Test logfire attributes for TextOutput functions (PlainTextOutputProcessor).""" + + def call_text_response(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: + # Return a plain text response (not a tool call) + from pydantic_ai.messages import TextPart + + return ModelResponse(parts=[TextPart(content='hello world')]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_text_response), instrument=instrumentation_settings) + + result = my_agent.run_sync('Say hello', output_type=TextOutput(upcase_text)) + assert result.output == 'HELLO WORLD' + + summary = get_logfire_summary() + + # Find the text output function span attributes + [text_function_attributes] = [ + attributes + for attributes in summary.attributes.values() + if 'running output function: upcase_text' in attributes.get('logfire.msg', '') + ] + + if include_content: + assert text_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'upcase_text', + 'tool_arguments': '{"text":"hello world"}', + 'logfire.msg': 'running output function: upcase_text', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': 
'HELLO WORLD', + } + ) + else: + assert text_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'upcase_text', + 'logfire.msg': 'running output function: upcase_text', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_prompted_output_function_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + """Test that spans are created for PromptedOutput functions with appropriate attributes.""" + + def upcase_text(text: str) -> str: + return text.upper() + + call_count = 0 + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + nonlocal call_count + call_count += 1 + # Simulate the model returning JSON that will be parsed and used to call the function + return ModelResponse(parts=[TextPart(content='{"text": "hello world"}')]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + agent = Agent( + model=FunctionModel(call_tool), instrument=instrumentation_settings, output_type=PromptedOutput(upcase_text) + ) + + result = agent.run_sync('test') + + # Check that the function was called and returned the expected result + assert result.output == 'HELLO WORLD' + assert call_count == 1 + + summary = get_logfire_summary() + + # Find the output function span attributes + [output_function_attributes] = [ + attributes + for attributes in summary.attributes.values() + if attributes.get('logfire.msg', '').startswith('running output function: upcase_text') + ] + + if include_content: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'upcase_text', + 'tool_arguments': '{"text": "hello world"}', + 'logfire.msg': 'running output function: upcase_text', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': 'HELLO WORLD', + } + ) + else: + assert output_function_attributes == snapshot( + { + 'gen_ai.tool.name': 'upcase_text', + 'logfire.msg': 'running output function: upcase_text', + 'logfire.span_type': 'span', + } + ) + + +@pytest.mark.skipif(not logfire_installed, reason='logfire not installed') +@pytest.mark.parametrize('include_content', [True, False]) +def test_output_type_text_output_function_with_retry_logfire_attributes( + get_logfire_summary: Callable[[], LogfireSummary], + include_content: bool, +) -> None: + def get_weather_with_retry(ctx: RunContext[None], city: str) -> WeatherInfo: + assert ctx is not None + if city != 'Mexico City': + from pydantic_ai import ModelRetry + + raise ModelRetry('City not found, I only know Mexico City') + return WeatherInfo(temperature=28.7, description='sunny') + + def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + assert info.output_tools is not None + + if len(messages) == 1: + city = 'New York City' + else: + city = 'Mexico City' + + return ModelResponse(parts=[TextPart(content=city)]) + + instrumentation_settings = InstrumentationSettings(include_content=include_content) + my_agent = Agent(model=FunctionModel(call_tool), instrument=instrumentation_settings) + + result = my_agent.run_sync('New York City', output_type=TextOutput(get_weather_with_retry)) + assert result.output == WeatherInfo(temperature=28.7, description='sunny') + + summary = get_logfire_summary() + + 
text_function_attributes = [ + attributes + for attributes in summary.attributes.values() + if 'running output function: get_weather_with_retry' in attributes.get('logfire.msg', '') + ] + + if include_content: + assert text_function_attributes == snapshot( + [ + { + 'gen_ai.tool.name': 'get_weather_with_retry', + 'tool_arguments': '{"city":"New York City"}', + 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + }, + { + 'gen_ai.tool.name': 'get_weather_with_retry', + 'tool_arguments': '{"city":"Mexico City"}', + 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': '{"temperature": 28.7, "description": "sunny"}', + }, + ] + ) + else: + assert text_function_attributes == snapshot( + [ + { + 'gen_ai.tool.name': 'get_weather_with_retry', + 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + }, + { + 'gen_ai.tool.name': 'get_weather_with_retry', + 'logfire.msg': 'running output function: get_weather_with_retry', + 'logfire.span_type': 'span', + }, + ] + ) From 4d755d2b5799258feffaef80ea892f838298441b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C3=A9lina?= Date: Wed, 16 Jul 2025 16:02:55 +0100 Subject: [PATCH 19/89] Add Hugging Face as a provider (#1911) Co-authored-by: Marcelo Trylesinski Co-authored-by: burtenshaw --- docs/api/models/huggingface.md | 7 + docs/api/providers.md | 2 + docs/models/huggingface.md | 95 ++ mkdocs.yml | 1 + .../pydantic_ai/models/__init__.py | 14 +- .../pydantic_ai/models/huggingface.py | 463 ++++++++ .../pydantic_ai/providers/__init__.py | 4 + .../pydantic_ai/providers/huggingface.py | 88 ++ pydantic_ai_slim/pyproject.toml | 1 + pyproject.toml | 2 +- tests/conftest.py | 14 + .../test_hf_model_instructions.yaml | 121 +++ .../test_hf_model_thinking_part.yaml | 291 +++++ .../test_image_as_binary_content_input.yaml | 106 ++ .../test_image_url_input.yaml | 105 ++ ...ion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml | 122 +++ ..._tokens[deepseek-ai-DeepSeek-R1-0528].yaml | 128 +++ ...ns[meta-llama-Llama-3.3-70B-Instruct].yaml | 142 +++ .../test_request_simple_success_with_vcr.yaml | 126 +++ .../test_request_simple_usage.yaml | 122 +++ .../test_simple_completion.yaml | 122 +++ .../test_stream_completion.yaml | 319 ++++++ tests/models/test_huggingface.py | 999 ++++++++++++++++++ tests/models/test_model_names.py | 3 + tests/providers/test_huggingface.py | 142 +++ tests/test_cli.py | 1 + tests/test_examples.py | 1 + uv.lock | 62 +- 28 files changed, 3583 insertions(+), 20 deletions(-) create mode 100644 docs/api/models/huggingface.md create mode 100644 docs/models/huggingface.md create mode 100644 pydantic_ai_slim/pydantic_ai/models/huggingface.py create mode 100644 pydantic_ai_slim/pydantic_ai/providers/huggingface.py create mode 100644 tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_hf_model_thinking_part.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_image_as_binary_content_input.yaml create mode 100644 
tests/models/cassettes/test_huggingface/test_image_url_input.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_max_completion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml create mode 100644 tests/models/cassettes/test_huggingface/test_max_completion_tokens[deepseek-ai-DeepSeek-R1-0528].yaml create mode 100644 tests/models/cassettes/test_huggingface/test_max_completion_tokens[meta-llama-Llama-3.3-70B-Instruct].yaml create mode 100644 tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_request_simple_usage.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_simple_completion.yaml create mode 100644 tests/models/cassettes/test_huggingface/test_stream_completion.yaml create mode 100644 tests/models/test_huggingface.py create mode 100644 tests/providers/test_huggingface.py diff --git a/docs/api/models/huggingface.md b/docs/api/models/huggingface.md new file mode 100644 index 0000000000..72e78c4a3e --- /dev/null +++ b/docs/api/models/huggingface.md @@ -0,0 +1,7 @@ +# `pydantic_ai.models.huggingface` + +## Setup + +For details on how to set up authentication with this model, see [model configuration for Hugging Face](../../models/huggingface.md). + +::: pydantic_ai.models.huggingface diff --git a/docs/api/providers.md b/docs/api/providers.md index ec684520ce..7b2ddc1c12 100644 --- a/docs/api/providers.md +++ b/docs/api/providers.md @@ -31,3 +31,5 @@ ::: pydantic_ai.providers.github.GitHubProvider ::: pydantic_ai.providers.openrouter.OpenRouterProvider + +::: pydantic_ai.providers.huggingface.HuggingFaceProvider diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md new file mode 100644 index 0000000000..61f8eef35f --- /dev/null +++ b/docs/models/huggingface.md @@ -0,0 +1,95 @@ +# Hugging Face + +[Hugging Face](https://huggingface.co/) is an AI platform with all major open source models, datasets, MCPs, and demos. You can use [Inference Providers](https://huggingface.co/docs/inference-providers) to run open source models like DeepSeek R1 on scalable serverless infrastructure. + +## Install + +To use `HuggingFaceModel`, you need to either install `pydantic-ai`, or install `pydantic-ai-slim` with the `huggingface` optional group: + +```bash +pip/uv-add "pydantic-ai-slim[huggingface]" +``` + +## Configuration + +To use [Hugging Face](https://huggingface.co/) inference, you'll need to set up an account which will give you [free tier](https://huggingface.co/docs/inference-providers/pricing) allowance on [Inference Providers](https://huggingface.co/docs/inference-providers). To setup inference, follow these steps: + +1. Go to [Hugging Face](https://huggingface.co/join) and sign up for an account. +2. Create a new access token in [Hugging Face](https://huggingface.co/settings/tokens). +3. Set the `HF_TOKEN` environment variable to the token you just created. + +Once you have a Hugging Face access token, you can set it as an environment variable: + +```bash +export HF_TOKEN='hf_token' +``` + +## Usage + +You can then use [`HuggingFaceModel`][pydantic_ai.models.huggingface.HuggingFaceModel] by name: + +```python +from pydantic_ai import Agent + +agent = Agent('huggingface:Qwen/Qwen3-235B-A22B') +... +``` + +Or initialise the model directly with just the model name: + +```python +from pydantic_ai import Agent +from pydantic_ai.models.huggingface import HuggingFaceModel + +model = HuggingFaceModel('Qwen/Qwen3-235B-A22B') +agent = Agent(model) +... 
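+
+# A minimal usage sketch (hypothetical prompt; assumes `HF_TOKEN` is set in the environment):
+# result = agent.run_sync('What is the capital of France?')
+# print(result.output)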
+```
+
+By default, the [`HuggingFaceModel`][pydantic_ai.models.huggingface.HuggingFaceModel] uses the
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider], which automatically selects the
+first inference provider (Cerebras, Together AI, Cohere, etc.) available for the model, according to your
+preferred order in https://hf.co/settings/inference-providers.
+
+## Configure the provider
+
+If you want to pass parameters in code to the provider, you can programmatically instantiate the
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] and pass it to the model:
+
+```python
+from pydantic_ai import Agent
+from pydantic_ai.models.huggingface import HuggingFaceModel
+from pydantic_ai.providers.huggingface import HuggingFaceProvider
+
+model = HuggingFaceModel('Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(api_key='hf_token', provider_name='nebius'))
+agent = Agent(model)
+...
+```
+
+## Custom Hugging Face client
+
+[`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] also accepts a custom
+[`AsyncInferenceClient`][huggingface_hub.AsyncInferenceClient] client via the `hf_client` parameter, so you can customise
+the `headers`, `bill_to` (billing to an HF organization you're a member of), `base_url`, etc. as defined in the
+[Hugging Face Hub python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client).
+
+```python
+from huggingface_hub import AsyncInferenceClient
+
+from pydantic_ai import Agent
+from pydantic_ai.models.huggingface import HuggingFaceModel
+from pydantic_ai.providers.huggingface import HuggingFaceProvider
+
+client = AsyncInferenceClient(
+    bill_to='openai',
+    api_key='hf_token',
+    provider='fireworks-ai',
+)
+
+model = HuggingFaceModel(
+    'Qwen/Qwen3-235B-A22B',
+    provider=HuggingFaceProvider(hf_client=client),
+)
+agent = Agent(model)
+...
+``` diff --git a/mkdocs.yml b/mkdocs.yml index 44b1548f1a..a950d52c0c 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -83,6 +83,7 @@ nav: - api/models/gemini.md - api/models/google.md - api/models/groq.md + - api/models/huggingface.md - api/models/instrumented.md - api/models/mistral.md - api/models/test.md diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 79f6031687..811c128379 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -227,6 +227,14 @@ 'heroku:claude-3-7-sonnet', 'heroku:claude-4-sonnet', 'heroku:claude-3-haiku', + 'huggingface:Qwen/QwQ-32B', + 'huggingface:Qwen/Qwen2.5-72B-Instruct', + 'huggingface:Qwen/Qwen3-235B-A22B', + 'huggingface:Qwen/Qwen3-32B', + 'huggingface:deepseek-ai/DeepSeek-R1', + 'huggingface:meta-llama/Llama-3.3-70B-Instruct', + 'huggingface:meta-llama/Llama-4-Maverick-17B-128E-Instruct', + 'huggingface:meta-llama/Llama-4-Scout-17B-16E-Instruct', 'mistral:codestral-latest', 'mistral:mistral-large-latest', 'mistral:mistral-moderation-latest', @@ -560,7 +568,7 @@ def override_allow_model_requests(allow_model_requests: bool) -> Iterator[None]: ALLOW_MODEL_REQUESTS = old_value # pyright: ignore[reportConstantRedefinition] -def infer_model(model: Model | KnownModelName | str) -> Model: +def infer_model(model: Model | KnownModelName | str) -> Model: # noqa: C901 """Infer the model from the name.""" if isinstance(model, Model): return model @@ -624,6 +632,10 @@ def infer_model(model: Model | KnownModelName | str) -> Model: from .bedrock import BedrockConverseModel return BedrockConverseModel(model_name, provider=provider) + elif provider == 'huggingface': + from .huggingface import HuggingFaceModel + + return HuggingFaceModel(model_name, provider=provider) else: raise UserError(f'Unknown model: {model}') # pragma: no cover diff --git a/pydantic_ai_slim/pydantic_ai/models/huggingface.py b/pydantic_ai_slim/pydantic_ai/models/huggingface.py new file mode 100644 index 0000000000..41d53ca62a --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/models/huggingface.py @@ -0,0 +1,463 @@ +from __future__ import annotations as _annotations + +import base64 +from collections.abc import AsyncIterable, AsyncIterator +from contextlib import asynccontextmanager +from dataclasses import dataclass, field +from datetime import datetime, timezone +from typing import Literal, Union, cast, overload + +from typing_extensions import assert_never + +from pydantic_ai._thinking_part import split_content_into_text_and_thinking +from pydantic_ai.providers import Provider, infer_provider + +from .. import ModelHTTPError, UnexpectedModelBehavior, _utils, usage +from .._utils import guard_tool_call_id as _guard_tool_call_id, now_utc as _now_utc +from ..messages import ( + AudioUrl, + BinaryContent, + DocumentUrl, + ImageUrl, + ModelMessage, + ModelRequest, + ModelResponse, + ModelResponsePart, + ModelResponseStreamEvent, + RetryPromptPart, + SystemPromptPart, + TextPart, + ThinkingPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + VideoUrl, +) +from ..settings import ModelSettings +from ..tools import ToolDefinition +from . 
import Model, ModelRequestParameters, StreamedResponse, check_allow_model_requests + +try: + import aiohttp + from huggingface_hub import ( + AsyncInferenceClient, + ChatCompletionInputMessage, + ChatCompletionInputMessageChunk, + ChatCompletionInputTool, + ChatCompletionInputToolCall, + ChatCompletionInputURL, + ChatCompletionOutput, + ChatCompletionOutputMessage, + ChatCompletionStreamOutput, + ) + from huggingface_hub.errors import HfHubHTTPError + +except ImportError as _import_error: + raise ImportError( + 'Please install `huggingface_hub` to use Hugging Face Inference Providers, ' + 'you can use the `huggingface` optional group — `pip install "pydantic-ai-slim[huggingface]"`' + ) from _import_error + +__all__ = ( + 'HuggingFaceModel', + 'HuggingFaceModelSettings', +) + + +HFSystemPromptRole = Literal['system', 'user'] + +LatestHuggingFaceModelNames = Literal[ + 'deepseek-ai/DeepSeek-R1', + 'meta-llama/Llama-3.3-70B-Instruct', + 'meta-llama/Llama-4-Maverick-17B-128E-Instruct', + 'meta-llama/Llama-4-Scout-17B-16E-Instruct', + 'Qwen/QwQ-32B', + 'Qwen/Qwen2.5-72B-Instruct', + 'Qwen/Qwen3-235B-A22B', + 'Qwen/Qwen3-32B', +] +"""Latest Hugging Face models.""" + + +HuggingFaceModelName = Union[str, LatestHuggingFaceModelNames] +"""Possible Hugging Face model names. + +You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending). +""" + + +class HuggingFaceModelSettings(ModelSettings, total=False): + """Settings used for a Hugging Face model request.""" + + # ALL FIELDS MUST BE `huggingface_` PREFIXED SO YOU CAN MERGE THEM WITH OTHER MODELS. + # This class is a placeholder for any future huggingface-specific settings + + +@dataclass(init=False) +class HuggingFaceModel(Model): + """A model that uses Hugging Face Inference Providers. + + Internally, this uses the [HF Python client](https://github.com/huggingface/huggingface_hub) to interact with the API. + + Apart from `__init__`, all methods are private or match those of the base class. + """ + + client: AsyncInferenceClient = field(repr=False) + + _model_name: str = field(repr=False) + _system: str = field(default='huggingface', repr=False) + + def __init__( + self, + model_name: str, + *, + provider: Literal['huggingface'] | Provider[AsyncInferenceClient] = 'huggingface', + ): + """Initialize a Hugging Face model. + + Args: + model_name: The name of the Model to use. You can browse available models [here](https://huggingface.co/models?pipeline_tag=text-generation&inference_provider=all&sort=trending). + provider: The provider to use for Hugging Face Inference Providers. Can be either the string 'huggingface' or an + instance of `Provider[AsyncInferenceClient]`. If not provided, the other parameters will be used. 
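+                For example, `HuggingFaceModel('Qwen/Qwen3-32B', provider=HuggingFaceProvider(provider_name='nebius'))`
+                uses an explicitly configured provider.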
+ """ + self._model_name = model_name + self._provider = provider + if isinstance(provider, str): + provider = infer_provider(provider) + self.client = provider.client + + async def request( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> ModelResponse: + check_allow_model_requests() + response = await self._completions_create( + messages, False, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters + ) + model_response = self._process_response(response) + model_response.usage.requests = 1 + return model_response + + @asynccontextmanager + async def request_stream( + self, + messages: list[ModelMessage], + model_settings: ModelSettings | None, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterator[StreamedResponse]: + check_allow_model_requests() + response = await self._completions_create( + messages, True, cast(HuggingFaceModelSettings, model_settings or {}), model_request_parameters + ) + yield await self._process_streamed_response(response) + + @property + def model_name(self) -> HuggingFaceModelName: + """The model name.""" + return self._model_name + + @property + def system(self) -> str: + """The system / model provider.""" + return self._system + + @overload + async def _completions_create( + self, + messages: list[ModelMessage], + stream: Literal[True], + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> AsyncIterable[ChatCompletionStreamOutput]: ... + + @overload + async def _completions_create( + self, + messages: list[ModelMessage], + stream: Literal[False], + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionOutput: ... 
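+
+    # The overloads above only narrow the return type for static type checkers
+    # (streaming vs. non-streaming); the implementation below handles both cases.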
+ + async def _completions_create( + self, + messages: list[ModelMessage], + stream: bool, + model_settings: HuggingFaceModelSettings, + model_request_parameters: ModelRequestParameters, + ) -> ChatCompletionOutput | AsyncIterable[ChatCompletionStreamOutput]: + tools = self._get_tools(model_request_parameters) + + if not tools: + tool_choice: Literal['none', 'required', 'auto'] | None = None + elif not model_request_parameters.allow_text_output: + tool_choice = 'required' + else: + tool_choice = 'auto' + + hf_messages = await self._map_messages(messages) + + try: + return await self.client.chat.completions.create( # type: ignore + model=self._model_name, + messages=hf_messages, # type: ignore + tools=tools, + tool_choice=tool_choice or None, + stream=stream, + stop=model_settings.get('stop_sequences', None), + temperature=model_settings.get('temperature', None), + top_p=model_settings.get('top_p', None), + seed=model_settings.get('seed', None), + presence_penalty=model_settings.get('presence_penalty', None), + frequency_penalty=model_settings.get('frequency_penalty', None), + logit_bias=model_settings.get('logit_bias', None), # type: ignore + logprobs=model_settings.get('logprobs', None), + top_logprobs=model_settings.get('top_logprobs', None), + extra_body=model_settings.get('extra_body'), # type: ignore + ) + except aiohttp.ClientResponseError as e: + raise ModelHTTPError( + status_code=e.status, + model_name=self.model_name, + body=e.response_error_payload, # type: ignore + ) from e + except HfHubHTTPError as e: + raise ModelHTTPError( + status_code=e.response.status_code, + model_name=self.model_name, + body=e.response.content, + ) from e + + def _process_response(self, response: ChatCompletionOutput) -> ModelResponse: + """Process a non-streamed response, and prepare a message to return.""" + if response.created: + timestamp = datetime.fromtimestamp(response.created, tz=timezone.utc) + else: + timestamp = _now_utc() + + choice = response.choices[0] + content = choice.message.content + tool_calls = choice.message.tool_calls + + items: list[ModelResponsePart] = [] + + if content is not None: + items.extend(split_content_into_text_and_thinking(content)) + if tool_calls is not None: + for c in tool_calls: + items.append(ToolCallPart(c.function.name, c.function.arguments, tool_call_id=c.id)) + return ModelResponse( + items, + usage=_map_usage(response), + model_name=response.model, + timestamp=timestamp, + vendor_id=response.id, + ) + + async def _process_streamed_response(self, response: AsyncIterable[ChatCompletionStreamOutput]) -> StreamedResponse: + """Process a streamed response, and prepare a streaming response to return.""" + peekable_response = _utils.PeekableAsyncStream(response) + first_chunk = await peekable_response.peek() + if isinstance(first_chunk, _utils.Unset): + raise UnexpectedModelBehavior( # pragma: no cover + 'Streamed response ended without content or tool calls' + ) + + return HuggingFaceStreamedResponse( + _model_name=self._model_name, + _response=peekable_response, + _timestamp=datetime.fromtimestamp(first_chunk.created, tz=timezone.utc), + ) + + def _get_tools(self, model_request_parameters: ModelRequestParameters) -> list[ChatCompletionInputTool]: + tools = [self._map_tool_definition(r) for r in model_request_parameters.function_tools] + if model_request_parameters.output_tools: + tools += [self._map_tool_definition(r) for r in model_request_parameters.output_tools] + return tools + + async def _map_messages( + self, messages: list[ModelMessage] + ) -> 
list[ChatCompletionInputMessage | ChatCompletionOutputMessage]: + """Just maps a `pydantic_ai.Message` to a `huggingface_hub.ChatCompletionInputMessage`.""" + hf_messages: list[ChatCompletionInputMessage | ChatCompletionOutputMessage] = [] + for message in messages: + if isinstance(message, ModelRequest): + async for item in self._map_user_message(message): + hf_messages.append(item) + elif isinstance(message, ModelResponse): + texts: list[str] = [] + tool_calls: list[ChatCompletionInputToolCall] = [] + for item in message.parts: + if isinstance(item, TextPart): + texts.append(item.content) + elif isinstance(item, ToolCallPart): + tool_calls.append(self._map_tool_call(item)) + elif isinstance(item, ThinkingPart): + # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this, + # please open an issue. The below code is the code to send thinking to the provider. + # texts.append(f'\n{item.content}\n') + pass + else: + assert_never(item) + message_param = ChatCompletionInputMessage(role='assistant') # type: ignore + if texts: + # Note: model responses from this model should only have one text item, so the following + # shouldn't merge multiple texts into one unless you switch models between runs: + message_param['content'] = '\n\n'.join(texts) + if tool_calls: + message_param['tool_calls'] = tool_calls + hf_messages.append(message_param) + else: + assert_never(message) + if instructions := self._get_instructions(messages): + hf_messages.insert(0, ChatCompletionInputMessage(content=instructions, role='system')) # type: ignore + return hf_messages + + @staticmethod + def _map_tool_call(t: ToolCallPart) -> ChatCompletionInputToolCall: + return ChatCompletionInputToolCall.parse_obj_as_instance( # type: ignore + { + 'id': _guard_tool_call_id(t=t), + 'type': 'function', + 'function': { + 'name': t.tool_name, + 'arguments': t.args_as_json_str(), + }, + } + ) + + @staticmethod + def _map_tool_definition(f: ToolDefinition) -> ChatCompletionInputTool: + tool_param: ChatCompletionInputTool = ChatCompletionInputTool.parse_obj_as_instance( # type: ignore + { + 'type': 'function', + 'function': { + 'name': f.name, + 'description': f.description, + 'parameters': f.parameters_json_schema, + }, + } + ) + if f.strict is not None: + tool_param['function']['strict'] = f.strict + return tool_param + + async def _map_user_message( + self, message: ModelRequest + ) -> AsyncIterable[ChatCompletionInputMessage | ChatCompletionOutputMessage]: + for part in message.parts: + if isinstance(part, SystemPromptPart): + yield ChatCompletionInputMessage.parse_obj_as_instance({'role': 'system', 'content': part.content}) # type: ignore + elif isinstance(part, UserPromptPart): + yield await self._map_user_prompt(part) + elif isinstance(part, ToolReturnPart): + yield ChatCompletionOutputMessage.parse_obj_as_instance( # type: ignore + { + 'role': 'tool', + 'tool_call_id': _guard_tool_call_id(t=part), + 'content': part.model_response_str(), + } + ) + elif isinstance(part, RetryPromptPart): + if part.tool_name is None: + yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore + {'role': 'user', 'content': part.model_response()} + ) + else: + yield ChatCompletionInputMessage.parse_obj_as_instance( # type: ignore + { + 'role': 'tool', + 'tool_call_id': _guard_tool_call_id(t=part), + 'content': part.model_response(), + } + ) + else: + assert_never(part) + + @staticmethod + async def _map_user_prompt(part: UserPromptPart) -> ChatCompletionInputMessage: + content: str | 
list[ChatCompletionInputMessage] + if isinstance(part.content, str): + content = part.content + else: + content = [] + for item in part.content: + if isinstance(item, str): + content.append(ChatCompletionInputMessageChunk(type='text', text=item)) # type: ignore + elif isinstance(item, ImageUrl): + url = ChatCompletionInputURL(url=item.url) # type: ignore + content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore + elif isinstance(item, BinaryContent): + base64_encoded = base64.b64encode(item.data).decode('utf-8') + if item.is_image: + url = ChatCompletionInputURL(url=f'data:{item.media_type};base64,{base64_encoded}') # type: ignore + content.append(ChatCompletionInputMessageChunk(type='image_url', image_url=url)) # type: ignore + else: # pragma: no cover + raise RuntimeError(f'Unsupported binary content type: {item.media_type}') + elif isinstance(item, AudioUrl): + raise NotImplementedError('AudioUrl is not supported for Hugging Face') + elif isinstance(item, DocumentUrl): + raise NotImplementedError('DocumentUrl is not supported for Hugging Face') + elif isinstance(item, VideoUrl): + raise NotImplementedError('VideoUrl is not supported for Hugging Face') + else: + assert_never(item) + return ChatCompletionInputMessage(role='user', content=content) # type: ignore + + +@dataclass +class HuggingFaceStreamedResponse(StreamedResponse): + """Implementation of `StreamedResponse` for Hugging Face models.""" + + _model_name: str + _response: AsyncIterable[ChatCompletionStreamOutput] + _timestamp: datetime + + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async for chunk in self._response: + self._usage += _map_usage(chunk) + + try: + choice = chunk.choices[0] + except IndexError: + continue + + # Handle the text part of the response + content = choice.delta.content + if content is not None: + yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=content) + + for dtc in choice.delta.tool_calls or []: + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=dtc.index, + tool_name=dtc.function and dtc.function.name, # type: ignore + args=dtc.function and dtc.function.arguments, + tool_call_id=dtc.id, + ) + if maybe_event is not None: + yield maybe_event + + @property + def model_name(self) -> str: + """Get the model name of the response.""" + return self._model_name + + @property + def timestamp(self) -> datetime: + """Get the timestamp of the response.""" + return self._timestamp + + +def _map_usage(response: ChatCompletionOutput | ChatCompletionStreamOutput) -> usage.Usage: + response_usage = response.usage + if response_usage is None: + return usage.Usage() + + return usage.Usage( + request_tokens=response_usage.prompt_tokens, + response_tokens=response_usage.completion_tokens, + total_tokens=response_usage.total_tokens, + details=None, + ) diff --git a/pydantic_ai_slim/pydantic_ai/providers/__init__.py b/pydantic_ai_slim/pydantic_ai/providers/__init__.py index 5e2112ac66..f756120cf4 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/providers/__init__.py @@ -111,6 +111,10 @@ def infer_provider_class(provider: str) -> type[Provider[Any]]: # noqa: C901 from .heroku import HerokuProvider return HerokuProvider + elif provider == 'huggingface': + from .huggingface import HuggingFaceProvider + + return HuggingFaceProvider elif provider == 'github': from .github import GitHubProvider diff --git 
a/pydantic_ai_slim/pydantic_ai/providers/huggingface.py b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py
new file mode 100644
index 0000000000..8afb415914
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/providers/huggingface.py
@@ -0,0 +1,88 @@
+from __future__ import annotations as _annotations
+
+import os
+from typing import overload
+
+from httpx import AsyncClient
+
+from pydantic_ai.exceptions import UserError
+
+try:
+    from huggingface_hub import AsyncInferenceClient
+except ImportError as _import_error:  # pragma: no cover
+    raise ImportError(
+        'Please install the `huggingface_hub` package to use the HuggingFace provider, '
+        "you can use the `huggingface` optional group — `pip install 'pydantic-ai-slim[huggingface]'`"
+    ) from _import_error
+
+from . import Provider
+
+
+class HuggingFaceProvider(Provider[AsyncInferenceClient]):
+    """Provider for Hugging Face."""
+
+    @property
+    def name(self) -> str:
+        return 'huggingface'
+
+    @property
+    def base_url(self) -> str:
+        return self.client.model  # type: ignore
+
+    @property
+    def client(self) -> AsyncInferenceClient:
+        return self._client
+
+    @overload
+    def __init__(self, *, base_url: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, provider_name: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, base_url: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, hf_client: AsyncInferenceClient, provider_name: str, api_key: str | None = None) -> None: ...
+    @overload
+    def __init__(self, *, api_key: str | None = None) -> None: ...
+
+    def __init__(
+        self,
+        base_url: str | None = None,
+        api_key: str | None = None,
+        hf_client: AsyncInferenceClient | None = None,
+        http_client: AsyncClient | None = None,
+        provider_name: str | None = None,
+    ) -> None:
+        """Create a new Hugging Face provider.
+
+        Args:
+            base_url: The base url for the Hugging Face requests.
+            api_key: The API key to use for authentication. If not provided, the `HF_TOKEN` environment variable
+                will be used if available.
+            hf_client: An existing
+                [`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient)
+                client to use. If not provided, a new instance will be created.
+            http_client: An existing `httpx.AsyncClient` to use for making HTTP requests. Not supported by this provider; use `hf_client` instead.
+            provider_name: Name of the provider to use for inference. Available providers can be found in the [HF Inference Providers documentation](https://huggingface.co/docs/inference-providers/index#partners).
+                Defaults to "auto", which selects the first provider available for the model, sorted by the user's preferred order in https://hf.co/settings/inference-providers.
+                If `base_url` is passed, `provider_name` is not used.
+        """
+        api_key = api_key or os.environ.get('HF_TOKEN')
+
+        if api_key is None:
+            raise UserError(
+                'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)` '
+                'to use the HuggingFace provider.'
+ ) + + if http_client is not None: + raise ValueError('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead.') + + if base_url is not None and provider_name is not None: + raise ValueError('Cannot provide both `base_url` and `provider_name`.') + + if hf_client is None: + self._client = AsyncInferenceClient(api_key=api_key, provider=provider_name, base_url=base_url) # type: ignore + else: + self._client = hf_client diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 6371e0d7c4..2705ca9144 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -69,6 +69,7 @@ anthropic = ["anthropic>=0.52.0"] groq = ["groq>=0.19.0"] mistral = ["mistralai>=1.2.5"] bedrock = ["boto3>=1.37.24"] +huggingface = ["huggingface-hub[inference]>=0.33.2"] # Tools duckduckgo = ["ddgs>=9.0.0"] tavily = ["tavily-python>=0.5.0"] diff --git a/pyproject.toml b/pyproject.toml index 18c7853c70..534f156db7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -46,7 +46,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,cli,mcp,evals]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/tests/conftest.py b/tests/conftest.py index ce95301d3f..f94f5f0477 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -316,6 +316,11 @@ def openrouter_api_key() -> str: return os.getenv('OPENROUTER_API_KEY', 'mock-api-key') +@pytest.fixture(scope='session') +def huggingface_api_key() -> str: + return os.getenv('HF_TOKEN', 'hf_token') + + @pytest.fixture(scope='session') def heroku_inference_key() -> str: return os.getenv('HEROKU_INFERENCE_KEY', 'mock-api-key') @@ -398,6 +403,7 @@ def model( groq_api_key: str, co_api_key: str, gemini_api_key: str, + huggingface_api_key: str, bedrock_provider: BedrockProvider, ) -> Model: # pragma: lax no cover try: @@ -440,6 +446,14 @@ def model( from pydantic_ai.models.bedrock import BedrockConverseModel return BedrockConverseModel('us.amazon.nova-micro-v1:0', provider=bedrock_provider) + elif request.param == 'huggingface': + from pydantic_ai.models.huggingface import HuggingFaceModel + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + return HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), + ) else: raise ValueError(f'Unknown model: {request.param}') except ImportError: diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml new file mode 100644 index 0000000000..d8a5ee07e3 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_hf_model_instructions.yaml @@ -0,0 +1,121 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - 
keep-alive + content-length: + - '701' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bd-diYmxjldwbIbFgWNRPBqJ3SEIak" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: live + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '560' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Paris + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1751470757 + id: chatcmpl-b3936940372c481b8d886e596dc75524 + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 2 + completion_tokens_details: null + prompt_tokens: 26 + prompt_tokens_details: null + total_tokens: 28 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_hf_model_thinking_part.yaml b/tests/models/cassettes/test_huggingface/test_hf_model_thinking_part.yaml new file mode 100644 index 0000000000..10be947804 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_hf_model_thinking_part.yaml @@ -0,0 +1,291 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen3-235B-A22B?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '470' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"1d6-5wPQfbCXoh8XtBVekhfceCwHN4Y" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 680daa4ac41c05ba341b67d1 + id: Qwen/Qwen3-235B-A22B + inferenceProviderMapping: + fireworks-ai: + providerId: 
accounts/fireworks/models/qwen3-235b-a22b + status: live + task: conversational + nebius: + providerId: Qwen/Qwen3-235B-A22B + status: live + task: conversational + novita: + providerId: qwen/qwen3-235b-a22b-fp8 + status: live + task: conversational + nscale: + providerId: Qwen/Qwen3-235B-A22B + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '5526' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: "\nOkay, the user is asking how to cross the street safely. Let me break this down step by step. + First, they need to look both ways to check for cars. But wait, should they check left, right, then left again? + I remember that's a common safety tip. They might be in a country where people drive on the right side or the + left, so maybe should I mention that?\n\nAlso, traffic signals and signs are important. What about pedestrian + crossings or traffic lights? Explaining when to walk when the signal is green and the cars have stopped. Oh, right, + sometimes people might not realize to wait for the walk signal. And even when using a crosswalk, you still need + to look both ways because cars might not stop.\n\nDistractions like phones or headphones. Yeah, people often get + hurt because they're looking at their phone while crossing. Should advise them to put away distractions and stay + alert. Kids and elderly folks might need extra care, like holding an adult's hand.\n\nWhat about if there's no + traffic light or crosswalk? Then finding the safest spot with good visibility, maybe near a corner where cars + can see them better. And teaching kids the basics of street safety.\n\nAlso, the confidence aspect—don't rush, + take your time, make eye contact with drivers. And what to do if stuck in the middle? Wait for the next signal. + Oh, and bicycles! In some places, bike lanes cross sidewalks, so being watchful for cyclists too.\n\nWait, should + I structure these points in a numbered list? Start with stopping at the curb, then looking both ways, checking + traffic signals, obeying signs, avoiding distractions, using crosswalks if possible, teaching kids, staying visible, + making eye contact, and what to do if stuck. Maybe add something about not assuming drivers see them and being + cautious.\n\nLet me make sure not to miss any key points. Also, mention that it's safer to cross at intersections. + And maybe a final note about local laws or practices varying by country. Yeah, that covers the main points. I + should present it clearly so it's easy to follow step by step without getting overwhelmed.\n\n\nCrossing + the street safely requires attention, patience, and following key steps. Here's a clear guide:\n\n1. **Stop at + the Curb**: Find a safe spot to pause before stepping onto the road.\n\n2. 
**Look Both Ways (Left, Right, Then + Left Again!)** \n - **First check left**: Look for oncoming traffic from your left (if driving is on the right + side in your country). \n - **Then check right**: Check for vehicles coming from the right. \n - **Final + glance left**: Recheck the direction of traffic closest to you before stepping off the curb. \n *(Reverse this + order if driving is on the left, as in the UK or Japan.)*\n\n3. **Use Traffic Signals and Crosswalks**: \n - + Wait for the pedestrian \"walk\" signal (green hand or similar). \n - If there’s no signal, only cross once + all vehicles have come to a complete stop and you’ve made eye contact with drivers. \n - Follow any painted + crosswalk lines and stay within them.\n\n4. **Obey Traffic Signs/Lights**: \n - Red/yellow lights mean stop. + Green means it’s safe to start crossing, but still watch for turning vehicles. \n - If the \"don’t walk\" signal + flashes while you’re mid-crossing, finish crossing without rushing.\n\n5. **Avoid Distractions**: \n - Put + away phones, earbuds, or anything that blocks your senses. \n - Keep your head up and stay alert to your surroundings.\n\n6. + **Be Visible and Predictable**: \n - Wear bright/light-colored clothing, especially at night. \n - Walk + (don’t run) and follow the flow of traffic. Avoid sudden changes in direction.\n\n7. **Teach Children Safely**: + \ \n - Hold young children’s hands. \n - Practice the \"stop, look, listen\" rule together. \n - Teach + them to make eye contact with drivers before crossing.\n\n8. **Cross at Intersections When Possible**: \n - + Drivers expect pedestrians at crosswalks and intersections. \n - If no crosswalk exists, choose a spot with + clear visibility (e.g., where you can see around parked cars).\n\n9. **Don’t Assume Drivers See You**: \n - + Even if a car stops, check for other vehicles that might not yield. \n - At night, use a flashlight or phone + light to stay visible.\n\n10. **What to Do if Stuck Mid-Street**: \n - If the light changes before you reach + the other side, stay calm. \n - Stop at the median or safety island and wait for the next signal. \n\n**Bonus + Tip**: In areas with bike lanes, check for cyclists even once you’ve started crossing. In some places, bikes ride + against traffic flow, so look both ways even on one-way streets.\n\n**Local Laws Matter**: Check rules in your + area—e.g., some places require yielding to pedestrians, while others prioritize drivers. Always prioritize your + safety over assumptions.\n\nFollow these steps, and you’ll cross the street confidently and safely every time! 
+ \U0001F6B6♀️ ✅" + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1752067065 + id: chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9 + model: Qwen/Qwen3-235B-A22B + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 1090 + completion_tokens_details: null + prompt_tokens: 15 + prompt_tokens_details: null + total_tokens: 1105 + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '9391' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: "\nOkay, the user previously asked how to cross the street, and I gave a detailed answer. Now they're + asking about crossing a river analogously. Let me start by understanding the connection. They want a similar structured + approach but for crossing a river.\n\nFirst, I need to figure out the equivalents between crossing a street and + a river. The original steps included looking both ways, using signals, avoiding distractions, etc. For a river, + physical steps might involve checking the current, choosing a safe spot, maybe using a bridge or boat.\n\nI should + map each street-crossing step to a river scenario. For example, \"stop at the curb\" becomes \"assess the riverbank.\" + Instead of traffic signals, check for ferry schedules or bridge access. Use safety equipment like life jackets + instead of wearing bright clothes.\n\nWait, the user mentioned \"analogously,\" so the structure should mirror + the previous answer but with river-specific actions. Maybe start by pausing to observe the river, checking water + flow instead of traffic. Use bridges as crosswalks and traffic signals. Boating has its own signals, like flags + or lights.\n\nAlso, think about hazards unique to rivers: strong currents, slippery rocks, wildlife. Safety considerations + here would involve knowing how to swim, having a floatation device, not relying on unreliable methods like jumping + on rocks.\n\nNeed to include steps like evaluating the crossing point, checking for bridges or fords, using boats + if necessary, avoiding hazards. Maybe mention time-sensitive factors like weather or flooding, similar to obeying + traffic lights.\n\nI should ensure each point from the previous answer has a parallel. For example, distractions + like phones would be like not paying attention to the river's flow. Visibility could mean wearing bright colors + to be seen on the water.\n\nAlso, consider group scenarios or children, making sure everyone knows the plan. Teaching + kids to cross safely here might involve knowing how to use floatation devices.\n\nI need to make sure the answer + is comprehensive but clear, following the same logical structure as the street crossing answer. 
Maybe list steps + in order, similar to the original list. Check for completeness: assessment, choosing the method, using proper + equipment, following safety protocols, dealing with emergencies mid-crossing, etc.\n\nWait, the original answer + had 10 points plus a bonus tip. Should mirror that structure? Maybe create a list with parallels. Also, ensure + that analogies are accurate—traffic becomes currents, crosswalks become bridges or fords, traffic signals become + navigational markers or ferry schedules.\n\nAvoiding distractions would involve focusing on the water's movement + and surroundings, not looking at devices. Visibility in water might involve reflective gear or signaling.\n\nAlso, + legal aspects: crossing at authorized spots only, checking if permits are needed. Local hazards or regulations + could be like local traffic laws.\n\nPotential mistakes: Maybe comparing the river's current to traffic is too + direct. Or missing steps like testing the stability of a path across the river, which isn't present in street + crossing. Need to add those unique elements naturally.\n\nAlso, emergencies: If mid-river and something goes wrong, + similar to being stuck on the street. Having a plan, like heading to a safer area or using safety equipment.\n\nAlright, + structure each point as an analogy, then expand with river-specific advice. Make sure the order parallels the + original. Use similar section titles but adapted. Maybe end with a note on local conditions and safety first.\n\n\nCrossing + a river safely shares many principles with crossing a street but involves adapting to the unpredictable nature + of water. Below is an analogy-based guide, mirroring the street-crossing steps while addressing river-specific + challenges:\n\n---\n\n### **1. Stop at the Riverbank: Assess the Situation**\n - **Pause and Observe** \n Like + stopping at the curb, stand back from the edge to evaluate the river’s width, speed, and hazards (swift currents, + rocks, debris, depth). \n - **Check for Bridges/Piers/Crossings** \n Just as intersections prioritize + pedestrian safety, bridges or marked fords exist for safer passage. Use them if accessible. \n\n---\n\n### **2. + Read the River: Look Both Upstream and Downstream** \n - **Scan Both Directions** \n Just as you look + left/right for cars, search **upstream (left if driving is right-hand)** for hazards like floating debris or sudden + surges. Check **downstream (right)** for exit points in case you’re swept away. \n - **Check the Flow** \n + \ Assess current strength: Is it a gentle trickle or a raging torrent? Avoid crossing if water is above knee-deep + or too fast. \n\n---\n\n### **3. Use Safe Routes: Bridges, Ferries, or Designated Fords** \n - **Follow Traffic + Signals → Follow Nautical Rules** \n Wait for ferry schedules, flashing lights (if present), or buoys marking + safe paths. Cross only when signals (like a ferry’s horn) indicate it’s safe. \n - **Choose a Footbridge or + Ferry** \n Bridges eliminate water risks entirely, much like crosswalks. Ferries or boats (with licensed + operators) are safest for wider rivers. \n\n---\n\n### **4. Prioritize Your Path: Know Where to Step or Swim** + \ \n - **Identify Stable Rocks or Shallows** \n If wading, pick a route with flat, secure footing (like + stepping stones) or the shallowest stretch, avoiding slippery algae-covered surfaces. \n - **Test the Current** + \ \n Before fully entering, use a stick or rock to gauge the force of the water. 
Swift currents can sweep + you off your feet faster than a car can strike. \n\n---\n\n### **5. Avoid Distractions: Focus on the Movement** + \ \n - **Put Away Devices** \n A phone distraction here risks losing balance in the river versus stepping + blindly into traffic. Keep both hands free for stability. \n - **Listen to the River** \n Gurgling or + roaring water warns of hidden holes or rapids—similar to hearing a car engine approaching. \n\n---\n\n### **6. + Be Predictable and Visible: Wear Bright Gear or Floats** \n - **Wear a Life Jacket** \n Like high-visibility + clothing, a life jacket keeps you buoyant and makes you easier for rescuers or boat operators to spot. \n - + **Stick to a Straight Route** \n Zigzagging in water wastes energy and increases the risk of losing balance, + just as darting across lanes on a street invites accidents. \n\n---\n\n### **7. Communicate: Make Eye Contact + with Boaters or Guides** \n - **Signal to Operators** \n In small boats or rafts, wave to catch the attention + of passing vessels (like making eye contact with drivers) to ensure they see you. \n - **Use Hand Signals or + Whistles** \n Agree on emergency signals with your group beforehand (e.g., pointing downstream to signal + danger). \n\n---\n\n### **8. Cross at the Safest Spot: Avoid Mid-River Surprises** \n - **Choose Wide, Slow + Sections** \n Like crossing at intersections, wide shallow areas have gentler currents. Avoid narrows where + water funnels into rapids. \n - **Watch for Hidden Dangers** \n Submerged logs, sudden drop-offs, or hypothermic + cold water can be as lethal as a speeding car. \n\n---\n\n### **9. Don’t Assume Safety: Verify Every Step or + Stroke** \n - **Test Each Footstep** \n Tap the riverbed before transferring weight to avoid stepping + into a hole or loose gravel (like checking for icy patches on a sidewalk). \n - **Swim Only If Trained** \n + \ If the river is too deep to wade, only swim if you know how. Use floatation devices if unsure—similar to + holding an adult’s hand as a child crossing a street. \n\n---\n\n### **10. Mid-River Emergencies: What to Do + if Stuck** \n - **If Struck by Current** \n Stay calm, float on your back with feet downstream (to avoid + head-first collisions), and steer toward eddies or the shore. \n - **If Trapped on a Rock** \n Hug a large + rock and wait for help, like pausing at a median. Don’t risk swimming diagonally across the river’s flow. \n\n---\n\n### + **Bonus Tip: Adapt to Local Conditions** \n - **Research Hazards** \n Some rivers have undertows, wildlife, + or pollution. Check local warnings (like road signs for blind corners or school zones). \n - **Weather Watch** + \ \n Sudden rainstorms can cause flash floods—delay crossing if clouds mass on the horizon. \n\n---\n\nBy + applying street-crossing principles to river navigation—patience, situational awareness, and prioritizing safe + infrastructure—you can minimize risks. Always assume the river is more dangerous than it appears, just as you’d + treat an unfamiliar road. 
**Safety first, crossing second!** \U0001F30A \U0001F6A4 ⚠️" + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1752067094 + id: chatcmpl-35fdec1307634f94a39f7e26f52e12a7 + model: Qwen/Qwen3-235B-A22B + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 1860 + completion_tokens_details: null + prompt_tokens: 691 + prompt_tokens_details: null + total_tokens: 2551 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_image_as_binary_content_input.yaml b/tests/models/cassettes/test_huggingface/test_image_as_binary_content_input.yaml new file mode 100644 index 0000000000..8b295d4404 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_image_as_binary_content_input.yaml @@ -0,0 +1,106 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-VL-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '293' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"125-DEMuQsKZBCb9/68jW5UsI3Q7x7E" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 6797079422990ae89b5aff86 + id: Qwen/Qwen2.5-VL-72B-Instruct + inferenceProviderMapping: + hyperbolic: + providerId: Qwen/Qwen2.5-VL-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-VL-72B-Instruct + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '776' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: The fruit in the image is a kiwi. It has been sliced in half, revealing its bright green flesh with small + black seeds arranged in a circular pattern around a white center. The outer skin of the kiwi is fuzzy and brown. 
+ function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1751986733 + id: chatcmpl-bd957b950cce4d61839e2af25f56f684 + model: Qwen/Qwen2.5-VL-72B-Instruct + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 49 + completion_tokens_details: null + prompt_tokens: 7625 + prompt_tokens_details: null + total_tokens: 7674 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_image_url_input.yaml b/tests/models/cassettes/test_huggingface/test_image_url_input.yaml new file mode 100644 index 0000000000..791a0aede5 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_image_url_input.yaml @@ -0,0 +1,105 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-VL-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '293' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"125-DEMuQsKZBCb9/68jW5UsI3Q7x7E" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 6797079422990ae89b5aff86 + id: Qwen/Qwen2.5-VL-72B-Instruct + inferenceProviderMapping: + hyperbolic: + providerId: Qwen/Qwen2.5-VL-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-VL-72B-Instruct + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '612' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! How can I assist you with this image of a potato? 
+ function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1751983479 + id: chatcmpl-49aa100effab4ca28514d5ccc00d7944 + model: Qwen/Qwen2.5-VL-72B-Instruct + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 15 + completion_tokens_details: null + prompt_tokens: 269 + prompt_tokens_details: null + total_tokens: 284 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_max_completion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml new file mode 100644 index 0000000000..8395c16fc6 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[Qwen-Qwen2.5-72B-Instruct].yaml @@ -0,0 +1,122 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '704' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2c0-CGiQuUurY/UiBTJC7RlRRjJtbZU" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: error + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '693' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! How can I assist you today? Whether you have questions, need help with something specific, or just + want to chat, I'm here to help! 
+ function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1752050598 + id: chatcmpl-5295b41092674918b860d41f723660cb + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 33 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 63 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_max_completion_tokens[deepseek-ai-DeepSeek-R1-0528].yaml b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[deepseek-ai-DeepSeek-R1-0528].yaml new file mode 100644 index 0000000000..6f9868de9b --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[deepseek-ai-DeepSeek-R1-0528].yaml @@ -0,0 +1,128 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/deepseek-ai/DeepSeek-R1-0528?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '678' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2a6-gQg+B654Px2F2NUtLDU93uSoBDU" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 6836db82a3626cb7b5343be8 + id: deepseek-ai/DeepSeek-R1-0528 + inferenceProviderMapping: + fireworks-ai: + providerId: accounts/fireworks/models/deepseek-r1-0528 + status: live + task: conversational + hyperbolic: + providerId: deepseek-ai/DeepSeek-R1-0528 + status: live + task: conversational + nebius: + providerId: deepseek-ai/DeepSeek-R1-0528 + status: live + task: conversational + novita: + providerId: deepseek/deepseek-r1-0528 + status: live + task: conversational + sambanova: + providerId: DeepSeek-R1-0528 + status: live + task: conversational + together: + providerId: deepseek-ai/DeepSeek-R1 + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '1325' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: "\nOkay, the user just said “hello”. A simple greeting. They might be testing if I'm online, starting + a casual chat, or preparing a deeper question. \n\nSince they didn't add context, I'll match their tone—friendly + and open-ended. Short response invites them to lead. Adding the emoji makes it warmer. 
No need to overthink yet. + \n\nHmm… if they're new, they might need reassurance that I'm responsive. If they're regular users, they're probably + just warming up. Either way, keeping it light feels safe. \n\nWatch for clues in their next message—if they dive + into a topic, they were just being polite before asking. If they reply with small talk, they might want companionship.\n\nHello! + \U0001F60A How can I assist you today?" + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1752050599 + id: chatcmpl-25472217e5b643e0a1f3f20dd44ed2c1 + kv_transfer_params: null + model: deepseek-ai/DeepSeek-R1-0528 + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 165 + completion_tokens_details: null + prompt_tokens: 6 + prompt_tokens_details: null + total_tokens: 171 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_max_completion_tokens[meta-llama-Llama-3.3-70B-Instruct].yaml b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[meta-llama-Llama-3.3-70B-Instruct].yaml new file mode 100644 index 0000000000..101f8f9e22 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_max_completion_tokens[meta-llama-Llama-3.3-70B-Instruct].yaml @@ -0,0 +1,142 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/meta-llama/Llama-3.3-70B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '1215' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"4bf-2c5rXKFDCLWF+O3TnkXoII8pC2U" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 6745f28f9333dfcc06268b1e + id: meta-llama/Llama-3.3-70B-Instruct + inferenceProviderMapping: + cerebras: + providerId: llama-3.3-70b + status: live + task: conversational + featherless-ai: + providerId: meta-llama/Llama-3.3-70B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/llama-v3p3-70b-instruct + status: live + task: conversational + groq: + providerId: llama-3.3-70b-versatile + status: live + task: conversational + hyperbolic: + providerId: meta-llama/Llama-3.3-70B-Instruct + status: live + task: conversational + nebius: + providerId: meta-llama/Llama-3.3-70B-Instruct-fast + status: live + task: conversational + novita: + providerId: meta-llama/llama-3.3-70b-instruct + status: live + task: conversational + nscale: + providerId: meta-llama/Llama-3.3-70B-Instruct + status: live + task: conversational + ovhcloud: + providerId: Meta-Llama-3_3-70B-Instruct + status: error + task: conversational + sambanova: + providerId: Meta-Llama-3.3-70B-Instruct + status: live + task: conversational + together: + providerId: meta-llama/Llama-3.3-70B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: 
+ access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '686' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: '{"type": "function", "name": "print_output", "parameters": {"output": "hello"}}' + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: 128008 + created: 1752050609 + id: chatcmpl-e4e88c8a58b34ea8bd5c47e6265a0de3 + kv_transfer_params: null + model: meta-llama/Llama-3.3-70B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 23 + completion_tokens_details: null + prompt_tokens: 92 + prompt_tokens_details: null + total_tokens: 115 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml new file mode 100644 index 0000000000..6996da0333 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_request_simple_success_with_vcr.yaml @@ -0,0 +1,126 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '800' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"320-IoLwHc4XKGzRoHW0ok1gY7tY/NI" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hf-inference: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: live + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - 
X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '680' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with + anything specific. + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1749475549 + id: chatcmpl-6050852c70164258bb9bab4e93e2b69c + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 29 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 59 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_request_simple_usage.yaml b/tests/models/cassettes/test_huggingface/test_request_simple_usage.yaml new file mode 100644 index 0000000000..4025ce48a1 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_request_simple_usage.yaml @@ -0,0 +1,122 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '703' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bf-bkSLwumMG89/DZCsDWwBvtIEsEs" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '712' + content-type: + - application/json + cross-origin-opener-policy: + - 
same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + audio: null + content: Hello! It's great to meet you. How can I assist you today? Whether you have any questions, need some advice, + or just want to chat, feel free to let me know! + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1751982062 + id: chatcmpl-f366f315c05040fd9c4a505b516bce4b + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 40 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 70 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_simple_completion.yaml b/tests/models/cassettes/test_huggingface/test_simple_completion.yaml new file mode 100644 index 0000000000..a5f1d979ec --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_simple_completion.yaml @@ -0,0 +1,122 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '703' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bf-bkSLwumMG89/DZCsDWwBvtIEsEs" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '680' + content-type: + - application/json + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + vary: + - Origin + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + 
audio: null + content: Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with + anything specific. + function_call: null + reasoning_content: null + refusal: null + role: assistant + tool_calls: [] + stop_reason: null + created: 1751982153 + id: chatcmpl-d445c0d473a84791af2acf356cc00df7 + model: Qwen/Qwen2.5-72B-Instruct-fast + object: chat.completion + prompt_logprobs: null + service_tier: null + system_fingerprint: null + usage: + completion_tokens: 29 + completion_tokens_details: null + prompt_tokens: 30 + prompt_tokens_details: null + total_tokens: 59 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/cassettes/test_huggingface/test_stream_completion.yaml b/tests/models/cassettes/test_huggingface/test_stream_completion.yaml new file mode 100644 index 0000000000..e592d3f271 --- /dev/null +++ b/tests/models/cassettes/test_huggingface/test_stream_completion.yaml @@ -0,0 +1,319 @@ +interactions: +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '703' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bf-bkSLwumMG89/DZCsDWwBvtIEsEs" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + body: + string: |+ + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"Hello"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"!"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" 
It"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" seems"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" like"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" your"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" message"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" might"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" have"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" been"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" cut"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" off"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" or"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" not"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" fully"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" sent"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: 
{"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"."},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" Could"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" you"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" please"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" provide"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" more"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" details"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" so"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" I"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" can"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" assist"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" you"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":" better"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":"?"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: 
{"id":"chatcmpl-da9066b0c0ff4cdbae89c40870e43764","choices":[{"delta":{"content":""},"finish_reason":"stop","index":0,"logprobs":null,"stop_reason":null}],"created":1751980879,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: [DONE] + + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + transfer-encoding: + - chunked + vary: + - Origin + status: + code: 200 + message: OK +- request: + body: null + headers: + accept: + - '*/*' + accept-encoding: + - gzip, deflate + connection: + - keep-alive + method: GET + uri: https://huggingface.co/api/models/Qwen/Qwen2.5-72B-Instruct?expand=inferenceProviderMapping + response: + headers: + access-control-allow-origin: + - https://huggingface.co + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-length: + - '703' + content-type: + - application/json; charset=utf-8 + cross-origin-opener-policy: + - same-origin + etag: + - W/"2bf-bkSLwumMG89/DZCsDWwBvtIEsEs" + referrer-policy: + - strict-origin-when-cross-origin + vary: + - Origin + parsed_body: + _id: 66e81cefd1b1391042d0e47e + id: Qwen/Qwen2.5-72B-Instruct + inferenceProviderMapping: + featherless-ai: + providerId: Qwen/Qwen2.5-72B-Instruct + status: live + task: conversational + fireworks-ai: + providerId: accounts/fireworks/models/qwen2p5-72b-instruct + status: live + task: conversational + hyperbolic: + providerId: Qwen/Qwen2.5-72B-Instruct + status: error + task: conversational + nebius: + providerId: Qwen/Qwen2.5-72B-Instruct-fast + status: live + task: conversational + novita: + providerId: qwen/qwen-2.5-72b-instruct + status: error + task: conversational + together: + providerId: Qwen/Qwen2.5-72B-Instruct-Turbo + status: live + task: conversational + status: + code: 200 + message: OK +- request: + body: null + headers: {} + method: POST + uri: https://router.huggingface.co/nebius/v1/chat/completions + response: + body: + string: |+ + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"","role":"assistant"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"Hello"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"!"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" How"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: 
{"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" can"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" I"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" assist"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" you"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" today"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"?"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" Feel"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" free"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" to"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" ask"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" me"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" any"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" questions"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" or"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" 
let"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" me"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" know"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" if"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" you"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" need"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" help"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" with"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" anything"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":" specific"},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":"."},"finish_reason":null,"index":0,"logprobs":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: {"id":"chatcmpl-dad488ace0b540629381a97ed61f6426","choices":[{"delta":{"content":""},"finish_reason":"stop","index":0,"logprobs":null,"stop_reason":null}],"created":1751980905,"model":"Qwen/Qwen2.5-72B-Instruct-fast","object":"chat.completion.chunk"} + + data: [DONE] + + headers: + access-control-allow-credentials: + - 'true' + access-control-allow-origin: + - '*' + access-control-expose-headers: + - X-Repo-Commit,X-Request-Id,X-Error-Code,X-Error-Message,X-Total-Count,ETag,Link,Accept-Ranges,Content-Range,X-Linked-Size,X-Linked-ETag,X-Xet-Hash + connection: + - keep-alive + content-type: + - text/event-stream; charset=utf-8 + cross-origin-opener-policy: + - same-origin + referrer-policy: + - strict-origin-when-cross-origin + strict-transport-security: + - max-age=31536000; includeSubDomains + transfer-encoding: + - chunked + vary: + - Origin + status: + code: 200 + message: OK +version: 1 +... 
diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py new file mode 100644 index 0000000000..bc6a7a359d --- /dev/null +++ b/tests/models/test_huggingface.py @@ -0,0 +1,999 @@ +from __future__ import annotations as _annotations + +import json +from collections.abc import Sequence +from dataclasses import asdict, dataclass, field +from datetime import datetime, timezone +from functools import cached_property +from typing import Any, Literal, Union, cast +from unittest.mock import Mock + +import pytest +from inline_snapshot import snapshot +from typing_extensions import TypedDict + +from pydantic_ai import Agent, ModelRetry, UnexpectedModelBehavior +from pydantic_ai.exceptions import ModelHTTPError +from pydantic_ai.messages import ( + AudioUrl, + BinaryContent, + DocumentUrl, + ImageUrl, + ModelRequest, + ModelResponse, + RetryPromptPart, + SystemPromptPart, + TextPart, + ThinkingPart, + ToolCallPart, + ToolReturnPart, + UserPromptPart, + VideoUrl, +) +from pydantic_ai.result import Usage +from pydantic_ai.settings import ModelSettings +from pydantic_ai.tools import RunContext + +from ..conftest import IsDatetime, IsInstance, IsNow, IsStr, raise_if_exception, try_import +from .mock_async_stream import MockAsyncStream + +with try_import() as imports_successful: + import aiohttp + from huggingface_hub import ( + AsyncInferenceClient, + ChatCompletionInputMessage, + ChatCompletionOutput, + ChatCompletionOutputComplete, + ChatCompletionOutputFunctionDefinition, + ChatCompletionOutputMessage, + ChatCompletionOutputToolCall, + ChatCompletionOutputUsage, + ChatCompletionStreamOutput, + ChatCompletionStreamOutputChoice, + ChatCompletionStreamOutputDelta, + ChatCompletionStreamOutputDeltaToolCall, + ChatCompletionStreamOutputFunction, + ChatCompletionStreamOutputUsage, + ) + from huggingface_hub.errors import HfHubHTTPError + + from pydantic_ai.models.huggingface import HuggingFaceModel + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + MockChatCompletion = Union[ChatCompletionOutput, Exception] + MockStreamEvent = Union[ChatCompletionStreamOutput, Exception] + +pytestmark = [ + pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed'), + pytest.mark.anyio, + pytest.mark.filterwarnings('ignore::ResourceWarning'), +] + + +@dataclass +class MockHuggingFace: + completions: MockChatCompletion | Sequence[MockChatCompletion] | None = None + stream: Sequence[MockStreamEvent] | Sequence[Sequence[MockStreamEvent]] | None = None + index: int = 0 + chat_completion_kwargs: list[dict[str, Any]] = field(default_factory=list) + + @cached_property + def chat(self) -> Any: + completions = type('Completions', (), {'create': self.chat_completions_create}) + return type('Chat', (), {'completions': completions}) + + @classmethod + def create_mock(cls, completions: MockChatCompletion | Sequence[MockChatCompletion]) -> AsyncInferenceClient: + return cast(AsyncInferenceClient, cls(completions=completions)) + + @classmethod + def create_stream_mock( + cls, stream: Sequence[MockStreamEvent] | Sequence[Sequence[MockStreamEvent]] + ) -> AsyncInferenceClient: + return cast(AsyncInferenceClient, cls(stream=stream)) + + async def chat_completions_create( + self, *_args: Any, stream: bool = False, **kwargs: Any + ) -> ChatCompletionOutput | MockAsyncStream[MockStreamEvent]: + self.chat_completion_kwargs.append(kwargs) + if stream or self.stream: + assert self.stream is not None, 'you can only use `stream=True` if `stream` is provided' + if 
isinstance(self.stream[0], Sequence): + response = MockAsyncStream(iter(cast(list[MockStreamEvent], self.stream[self.index]))) + else: + response = MockAsyncStream(iter(cast(list[MockStreamEvent], self.stream))) + else: + assert self.completions is not None, 'you can only use `stream=False` if `completions` are provided' + if isinstance(self.completions, Sequence): + raise_if_exception(self.completions[self.index]) + response = cast(ChatCompletionOutput, self.completions[self.index]) + else: + raise_if_exception(self.completions) + response = cast(ChatCompletionOutput, self.completions) + self.index += 1 + return response + + +def get_mock_chat_completion_kwargs(hf_client: AsyncInferenceClient) -> list[dict[str, Any]]: + if isinstance(hf_client, MockHuggingFace): + return hf_client.chat_completion_kwargs + else: # pragma: no cover + raise RuntimeError('Not a MockHuggingFace instance') + + +def completion_message( + message: ChatCompletionInputMessage | ChatCompletionOutputMessage, *, usage: ChatCompletionOutputUsage | None = None +) -> ChatCompletionOutput: + choices = [ChatCompletionOutputComplete(finish_reason='stop', index=0, message=message)] # type:ignore + return ChatCompletionOutput.parse_obj_as_instance( # type: ignore + { + 'id': '123', + 'choices': choices, + 'created': 1704067200, # 2024-01-01 + 'model': 'hf-model', + 'object': 'chat.completion', + 'usage': usage, + } + ) + + +@pytest.mark.vcr() +async def test_simple_completion(allow_model_requests: None, huggingface_api_key: str): + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), + ) + agent = Agent(model) + + result = await agent.run('hello') + assert ( + result.output + == 'Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' + ) + messages = result.all_messages() + request = messages[0] + response = messages[1] + assert request.parts[0].content == 'hello' # type: ignore + assert response == ModelResponse( + parts=[ + TextPart( + content='Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' + ) + ], + usage=Usage(requests=1, request_tokens=30, response_tokens=29, total_tokens=59), + model_name='Qwen/Qwen2.5-72B-Instruct-fast', + timestamp=datetime(2025, 7, 8, 13, 42, 33, tzinfo=timezone.utc), + vendor_id='chatcmpl-d445c0d473a84791af2acf356cc00df7', + ) + + +@pytest.mark.vcr() +async def test_request_simple_usage(allow_model_requests: None, huggingface_api_key: str): + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), + ) + agent = Agent(model) + + result = await agent.run('Hello') + assert ( + result.output + == "Hello! It's great to meet you. How can I assist you today? Whether you have any questions, need some advice, or just want to chat, feel free to let me know!" 
+ ) + assert result.usage() == snapshot(Usage(requests=1, request_tokens=30, response_tokens=40, total_tokens=70)) + + +async def test_request_structured_response( + allow_model_requests: None, +): + tool_call = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'final_result', + 'arguments': '{"response": [1, 2, 123]}', + } + ), + 'id': '123', + 'type': 'function', + } + ) + message = ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call], + } + ) + c = completion_message(message) + + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', hf_client=mock_client, api_key='x'), + ) + agent = Agent(model, output_type=list[int]) + + result = await agent.run('Hello') + assert result.output == [1, 2, 123] + messages = result.all_messages() + assert messages[0].parts[0].content == 'Hello' # type: ignore + assert messages[1] == ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args='{"response": [1, 2, 123]}', + tool_call_id='123', + ) + ], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, tzinfo=timezone.utc), + vendor_id='123', + ) + + +async def test_stream_completion(allow_model_requests: None): + stream = [text_chunk('hello '), text_chunk('world', finish_reason='stop')] + mock_client = MockHuggingFace.create_stream_mock(stream) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + + async with agent.run_stream('') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + + +async def test_multiple_stream_calls(allow_model_requests: None): + stream = [ + [text_chunk('first '), text_chunk('call', finish_reason='stop')], + [text_chunk('second '), text_chunk('call', finish_reason='stop')], + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + + async with agent.run_stream('first') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['first ', 'first call']) + + async with agent.run_stream('second') as result: + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['second ', 'second call']) + + +async def test_request_tool_call(allow_model_requests: None): + tool_call_1 = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'get_location', + 'arguments': '{"loc_name": "San Fransisco"}', + } + ), + 'id': '1', + 'type': 'function', + } + ) + usage_1 = ChatCompletionOutputUsage.parse_obj_as_instance( # type:ignore + { + 'prompt_tokens': 1, + 'completion_tokens': 1, + 'total_tokens': 2, + } + ) + tool_call_2 = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'get_location', + 'arguments': '{"loc_name": "London"}', + } + ), + 'id': '2', + 'type': 'function', + } + ) + usage_2 = ChatCompletionOutputUsage.parse_obj_as_instance( # type:ignore + { + 'prompt_tokens': 2, + 
'completion_tokens': 1, + 'total_tokens': 3, + } + ) + responses = [ + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call_1], + } + ), + usage=usage_1, + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call_2], + } + ), + usage=usage_2, + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': 'final response', + 'role': 'assistant', + } + ), + ), + ] + mock_client = MockHuggingFace.create_mock(responses) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model, system_prompt='this is the system prompt') + + @agent.tool_plain + async def get_location(loc_name: str) -> str: + if loc_name == 'London': + return json.dumps({'lat': 51, 'lng': 0}) + else: + raise ModelRetry('Wrong location, please try again') + + result = await agent.run('Hello') + assert result.output == 'final response' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + SystemPromptPart(content='this is the system prompt', timestamp=IsNow(tz=timezone.utc)), + UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc)), + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args='{"loc_name": "San Fransisco"}', + tool_call_id='1', + ) + ], + usage=Usage(requests=1, request_tokens=1, response_tokens=1, total_tokens=2), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Wrong location, please try again', + tool_name='get_location', + tool_call_id='1', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_location', + args='{"loc_name": "London"}', + tool_call_id='2', + ) + ], + usage=Usage(requests=1, request_tokens=2, response_tokens=1, total_tokens=3), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_location', + content='{"lat": 51, "lng": 0}', + tool_call_id='2', + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='final response')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + + +FinishReason = Literal['stop', 'length', 'tool_calls', 'content_filter', 'function_call'] + + +def chunk( + delta: list[ChatCompletionStreamOutputDelta], finish_reason: FinishReason | None = None +) -> ChatCompletionStreamOutput: + return ChatCompletionStreamOutput.parse_obj_as_instance( # type: ignore + { + 'id': 'x', + 'choices': [ + ChatCompletionStreamOutputChoice(index=index, delta=delta, finish_reason=finish_reason) # type: ignore + for index, delta in enumerate(delta) + ], + 'created': 1704067200, # 2024-01-01 + 'model': 'hf-model', + 'object': 'chat.completion.chunk', + 'usage': ChatCompletionStreamOutputUsage(completion_tokens=1, prompt_tokens=2, total_tokens=3), # type: ignore + } + ) + + +def text_chunk(text: str, finish_reason: FinishReason | None = None) -> ChatCompletionStreamOutput: + return chunk([ChatCompletionStreamOutputDelta(content=text, role='assistant')], finish_reason=finish_reason) # type: ignore + + 
+async def test_stream_text(allow_model_requests: None): + stream = [text_chunk('hello '), text_chunk('world'), chunk([])] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + + +async def test_stream_text_finish_reason(allow_model_requests: None): + stream = [ + text_chunk('hello '), + text_chunk('world'), + text_chunk('.', finish_reason='stop'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot( + ['hello ', 'hello world', 'hello world.'] + ) + assert result.is_complete + + +def struc_chunk( + tool_name: str | None, tool_arguments: str | None, finish_reason: FinishReason | None = None +) -> ChatCompletionStreamOutput: + return chunk( + [ + ChatCompletionStreamOutputDelta.parse_obj_as_instance( # type: ignore + { + 'role': 'assistant', + 'tool_calls': [ + ChatCompletionStreamOutputDeltaToolCall.parse_obj_as_instance( # type: ignore + { + 'index': 0, + 'function': ChatCompletionStreamOutputFunction.parse_obj_as_instance( # type: ignore + { + 'name': tool_name, + 'arguments': tool_arguments, + } + ), + } + ) + ], + } + ), + ], + finish_reason=finish_reason, + ) + + +class MyTypedDict(TypedDict, total=False): + first: str + second: str + + +async def test_stream_structured(allow_model_requests: None): + stream = [ + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + chunk([ChatCompletionStreamOutputDelta(role='assistant', tool_calls=[])]), # type: ignore + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + struc_chunk('final_result', None), + chunk( + [ + ChatCompletionStreamOutputDelta( + role='assistant', # type: ignore + tool_calls=[ # type: ignore + ChatCompletionStreamOutputDeltaToolCall(id='0', type='function', index=0, function=None) # type: ignore + ], + ) + ] + ), + struc_chunk(None, '{"first": "One'), + struc_chunk(None, '", "second": "Two"'), + struc_chunk(None, '}'), + chunk([]), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {}, + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + 
assert result.usage() == snapshot(Usage(requests=1, request_tokens=20, response_tokens=10, total_tokens=30)) + # double check usage matches stream count + assert result.usage().response_tokens == len(stream) + + +async def test_stream_structured_finish_reason(allow_model_requests: None): + stream = [ + struc_chunk('final_result', None), + struc_chunk(None, '{"first": "One'), + struc_chunk(None, '", "second": "Two"'), + struc_chunk(None, '}'), + struc_chunk(None, None, finish_reason='stop'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m, output_type=MyTypedDict) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [dict(c) async for c in result.stream(debounce_by=None)] == snapshot( + [ + {'first': 'One'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + {'first': 'One', 'second': 'Two'}, + ] + ) + assert result.is_complete + + +async def test_no_content(allow_model_requests: None): + stream = [ + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + chunk([ChatCompletionStreamOutputDelta(role='assistant')]), # type: ignore + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m, output_type=MyTypedDict) + + with pytest.raises(UnexpectedModelBehavior, match='Received empty model response'): + async with agent.run_stream(''): + pass + + +async def test_no_delta(allow_model_requests: None): + stream = [ + chunk([]), + text_chunk('hello '), + text_chunk('world'), + ] + mock_client = MockHuggingFace.create_stream_mock(stream) + m = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + + async with agent.run_stream('') as result: + assert not result.is_complete + assert [c async for c in result.stream_text(debounce_by=None)] == snapshot(['hello ', 'hello world']) + assert result.is_complete + assert result.usage() == snapshot(Usage(requests=1, request_tokens=6, response_tokens=3, total_tokens=9)) + + +@pytest.mark.vcr() +async def test_image_url_input(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 'Qwen/Qwen2.5-VL-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), + ) + agent = Agent(m) + + result = await agent.run( + [ + 'hello', + ImageUrl(url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg'), + ] + ) + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content=[ + 'hello', + ImageUrl( + url='https://t3.ftcdn.net/jpg/00/85/79/92/360_F_85799278_0BBGV9OAdQDTLnKwAPBCcg1J7QtiieJY.jpg' + ), + ], + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='Hello! 
How can I assist you with this image of a potato?')], + usage=Usage(requests=1, request_tokens=269, response_tokens=15, total_tokens=284), + model_name='Qwen/Qwen2.5-VL-72B-Instruct', + timestamp=datetime(2025, 7, 8, 14, 4, 39, tzinfo=timezone.utc), + vendor_id='chatcmpl-49aa100effab4ca28514d5ccc00d7944', + ), + ] + ) + + +@pytest.mark.vcr() +async def test_image_as_binary_content_input( + allow_model_requests: None, image_content: BinaryContent, huggingface_api_key: str +): + m = HuggingFaceModel( + 'Qwen/Qwen2.5-VL-72B-Instruct', + provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key), + ) + agent = Agent(m) + result = await agent.run(['What fruit is in the image?', image_content]) + assert result.output == snapshot( + 'The fruit in the image is a kiwi. It has been sliced in half, revealing its bright green flesh with small black seeds arranged in a circular pattern around a white center. The outer skin of the kiwi is fuzzy and brown.' + ) + + +def test_model_status_error(allow_model_requests: None) -> None: + error = HfHubHTTPError(message='test_error', response=Mock(status_code=500, content={'error': 'test error'})) + mock_client = MockHuggingFace.create_mock(error) + m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + with pytest.raises(ModelHTTPError) as exc_info: + agent.run_sync('hello') + assert str(exc_info.value) == snapshot("status_code: 500, model_name: not_a_model, body: {'error': 'test error'}") + + +@pytest.mark.vcr() +async def test_request_simple_success_with_vcr(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) + ) + agent = Agent(m) + result = await agent.run('hello') + assert result.output == snapshot( + 'Hello! How can I assist you today? Feel free to ask me any questions or let me know if you need help with anything specific.' + ) + + +@pytest.mark.vcr() +async def test_hf_model_instructions(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 'Qwen/Qwen2.5-72B-Instruct', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) + ) + + def simple_instructions(ctx: RunContext): + return 'You are a helpful assistant.' 
+ + agent = Agent(m, instructions=simple_instructions) + + result = await agent.run('What is the capital of France?') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[UserPromptPart(content='What is the capital of France?', timestamp=IsDatetime())], + instructions='You are a helpful assistant.', + ), + ModelResponse( + parts=[TextPart(content='Paris')], + usage=Usage(requests=1, request_tokens=26, response_tokens=2, total_tokens=28), + model_name='Qwen/Qwen2.5-72B-Instruct-fast', + timestamp=IsDatetime(), + vendor_id='chatcmpl-b3936940372c481b8d886e596dc75524', + ), + ] + ) + + +@pytest.mark.parametrize( + 'model_name', ['Qwen/Qwen2.5-72B-Instruct', 'deepseek-ai/DeepSeek-R1-0528', 'meta-llama/Llama-3.3-70B-Instruct'] +) +@pytest.mark.vcr() +async def test_max_completion_tokens(allow_model_requests: None, model_name: str, huggingface_api_key: str): + m = HuggingFaceModel(model_name, provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key)) + agent = Agent(m, model_settings=ModelSettings(max_tokens=100)) + + result = await agent.run('hello') + assert result.output == IsStr() + + +def test_system_property(): + model = HuggingFaceModel('some-model', provider=HuggingFaceProvider(hf_client=Mock(), api_key='x')) + assert model.system == 'huggingface' + + +async def test_model_client_response_error(allow_model_requests: None) -> None: + request_info = Mock(spec=aiohttp.RequestInfo) + request_info.url = 'http://test.com' + request_info.method = 'POST' + request_info.headers = {} + request_info.real_url = 'http://test.com' + error = aiohttp.ClientResponseError(request_info, history=(), status=400, message='Bad Request') + error.response_error_payload = {'error': 'test error'} # type: ignore + + mock_client = MockHuggingFace.create_mock(error) + m = HuggingFaceModel('not_a_model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(m) + with pytest.raises(ModelHTTPError) as exc_info: + await agent.run('hello') + assert str(exc_info.value) == snapshot("status_code: 400, model_name: not_a_model, body: {'error': 'test error'}") + + +async def test_process_response_no_created_timestamp(allow_model_requests: None): + c = completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'response', 'role': 'assistant'}), # type: ignore + ) + c.created = None # type: ignore + + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel( + 'test-model', + provider=HuggingFaceProvider(hf_client=mock_client, api_key='x'), + ) + agent = Agent(model) + result = await agent.run('Hello') + messages = result.all_messages() + response_message = messages[1] + assert isinstance(response_message, ModelResponse) + assert response_message.timestamp == IsNow(tz=timezone.utc) + + +async def test_retry_prompt_without_tool_name(allow_model_requests: None): + responses = [ + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'invalid-response', 'role': 'assistant'}) # type: ignore + ), + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance({'content': 'final-response', 'role': 'assistant'}) # type: ignore + ), + ] + + mock_client = MockHuggingFace.create_mock(responses) + model = HuggingFaceModel( + 'test-model', + provider=HuggingFaceProvider(hf_client=mock_client, api_key='x'), + ) + agent = Agent(model) + + @agent.output_validator + def response_validator(value: str) -> str: + if value == 'invalid-response': + raise ModelRetry('Response is invalid') + return value 
+ + result = await agent.run('Hello') + assert result.output == 'final-response' + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='Hello', timestamp=IsNow(tz=timezone.utc))]), + ModelResponse( + parts=[TextPart(content='invalid-response')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ModelRequest( + parts=[ + RetryPromptPart( + content='Response is invalid', + tool_name=None, + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='final-response')], + usage=Usage(requests=1), + model_name='hf-model', + timestamp=datetime(2024, 1, 1, 0, 0, tzinfo=timezone.utc), + vendor_id='123', + ), + ] + ) + kwargs = get_mock_chat_completion_kwargs(mock_client)[1] + messages = kwargs['messages'] + assert {k: v for k, v in asdict(messages[-2]).items() if v is not None} == { + 'role': 'assistant', + 'content': 'invalid-response', + } + assert {k: v for k, v in asdict(messages[-1]).items() if v is not None} == { + 'role': 'user', + 'content': 'Validation feedback:\nResponse is invalid\n\nFix the errors and try again.', + } + + +async def test_thinking_part_in_history(allow_model_requests: None): + c = completion_message(ChatCompletionOutputMessage(content='response', role='assistant')) # type: ignore + mock_client = MockHuggingFace.create_mock(c) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + messages = [ + ModelRequest(parts=[UserPromptPart(content='request')]), + ModelResponse( + parts=[ + TextPart(content='thought 1'), + ThinkingPart(content='this should be ignored'), + TextPart(content='thought 2'), + ], + model_name='hf-model', + timestamp=datetime.now(timezone.utc), + ), + ] + + await agent.run('another request', message_history=messages) + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + sent_messages = kwargs['messages'] + assert [{k: v for k, v in asdict(m).items() if v is not None} for m in sent_messages] == snapshot( + [ + {'content': 'request', 'role': 'user'}, + {'content': 'thought 1\n\nthought 2', 'role': 'assistant'}, + {'content': 'another request', 'role': 'user'}, + ] + ) + + +@pytest.mark.parametrize('strict', [True, False, None]) +async def test_tool_strict_mode(allow_model_requests: None, strict: bool | None): + tool_call = ChatCompletionOutputToolCall.parse_obj_as_instance( # type:ignore + { + 'function': ChatCompletionOutputFunctionDefinition.parse_obj_as_instance( # type:ignore + { + 'name': 'my_tool', + 'arguments': '{"x": 42}', + } + ), + 'id': '1', + 'type': 'function', + } + ) + responses = [ + completion_message( + ChatCompletionOutputMessage.parse_obj_as_instance( # type:ignore + { + 'content': None, + 'role': 'assistant', + 'tool_calls': [tool_call], + } + ) + ), + completion_message(ChatCompletionOutputMessage(content='final response', role='assistant')), # type: ignore + ] + mock_client = MockHuggingFace.create_mock(responses) + model = HuggingFaceModel('hf-model', provider=HuggingFaceProvider(hf_client=mock_client, api_key='x')) + agent = Agent(model) + + @agent.tool_plain(strict=strict) + def my_tool(x: int) -> int: + return x + + result = await agent.run('hello') + assert result.output == 'final response' + + kwargs = get_mock_chat_completion_kwargs(mock_client)[0] + tools = kwargs['tools'] + if strict is not None: + assert tools[0]['function']['strict'] is strict + else: + assert 
'strict' not in tools[0]['function'] + + +@pytest.mark.parametrize( + 'content_item, error_message', + [ + (AudioUrl(url='url'), 'AudioUrl is not supported for Hugging Face'), + (DocumentUrl(url='url'), 'DocumentUrl is not supported for Hugging Face'), + (VideoUrl(url='url'), 'VideoUrl is not supported for Hugging Face'), + ], +) +async def test_unsupported_media_types(allow_model_requests: None, content_item: Any, error_message: str): + model = HuggingFaceModel( + 'Qwen/Qwen2.5-VL-72B-Instruct', + provider=HuggingFaceProvider(api_key='x'), + ) + agent = Agent(model) + + with pytest.raises(NotImplementedError, match=error_message): + await agent.run(['hello', content_item]) + + +@pytest.mark.vcr() +async def test_hf_model_thinking_part(allow_model_requests: None, huggingface_api_key: str): + m = HuggingFaceModel( + 'Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) + ) + agent = Agent(m) + + result = await agent.run('How do I cross the street?') + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), + ModelResponse( + parts=[ + IsInstance(ThinkingPart), + IsInstance(TextPart), + ], + usage=Usage(requests=1, request_tokens=15, response_tokens=1090, total_tokens=1105), + model_name='Qwen/Qwen3-235B-A22B', + timestamp=IsDatetime(), + vendor_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9', + ), + ] + ) + + result = await agent.run( + 'Considering the way to cross the street, analogously, how do I cross the river?', + model=HuggingFaceModel( + 'Qwen/Qwen3-235B-A22B', provider=HuggingFaceProvider(provider_name='nebius', api_key=huggingface_api_key) + ), + message_history=result.all_messages(), + ) + assert result.all_messages() == snapshot( + [ + ModelRequest(parts=[UserPromptPart(content='How do I cross the street?', timestamp=IsDatetime())]), + ModelResponse( + parts=[ + IsInstance(ThinkingPart), + IsInstance(TextPart), + ], + usage=Usage(requests=1, request_tokens=15, response_tokens=1090, total_tokens=1105), + model_name='Qwen/Qwen3-235B-A22B', + timestamp=IsDatetime(), + vendor_id='chatcmpl-957db61fe60d4440bcfe1f11f2c5b4b9', + ), + ModelRequest( + parts=[ + UserPromptPart( + content='Considering the way to cross the street, analogously, how do I cross the river?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + IsInstance(ThinkingPart), + TextPart(content=IsStr()), + ], + usage=Usage(requests=1, request_tokens=691, response_tokens=1860, total_tokens=2551), + model_name='Qwen/Qwen3-235B-A22B', + timestamp=IsDatetime(), + vendor_id='chatcmpl-35fdec1307634f94a39f7e26f52e12a7', + ), + ] + ) diff --git a/tests/models/test_model_names.py b/tests/models/test_model_names.py index 52a3397a4f..db6f22cd8d 100644 --- a/tests/models/test_model_names.py +++ b/tests/models/test_model_names.py @@ -16,6 +16,7 @@ from pydantic_ai.models.cohere import CohereModelName from pydantic_ai.models.gemini import GeminiModelName from pydantic_ai.models.groq import GroqModelName + from pydantic_ai.models.huggingface import HuggingFaceModelName from pydantic_ai.models.mistral import MistralModelName from pydantic_ai.models.openai import OpenAIModelName @@ -54,6 +55,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: ] bedrock_names = [f'bedrock:{n}' for n in get_model_names(BedrockModelName)] deepseek_names = ['deepseek:deepseek-chat', 'deepseek:deepseek-reasoner'] + huggingface_names = [f'huggingface:{n}' for n in 
get_model_names(HuggingFaceModelName)] heroku_names = get_heroku_model_names() extra_names = ['test'] @@ -66,6 +68,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: + openai_names + bedrock_names + deepseek_names + + huggingface_names + heroku_names + extra_names ) diff --git a/tests/providers/test_huggingface.py b/tests/providers/test_huggingface.py new file mode 100644 index 0000000000..c9570a54dc --- /dev/null +++ b/tests/providers/test_huggingface.py @@ -0,0 +1,142 @@ +from __future__ import annotations as _annotations + +import re +from unittest.mock import MagicMock, Mock, patch + +import httpx +import pytest + +from pydantic_ai.exceptions import UserError + +from ..conftest import TestEnv, try_import + +with try_import() as imports_successful: + from huggingface_hub import AsyncInferenceClient + + from pydantic_ai.providers.huggingface import HuggingFaceProvider + + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='huggingface_hub not installed') + + +def test_huggingface_provider(): + hf_client = AsyncInferenceClient(api_key='api-key') + provider = HuggingFaceProvider(api_key='api-key', hf_client=hf_client) + assert provider.name == 'huggingface' + assert isinstance(provider.client, AsyncInferenceClient) + assert provider.client.token == 'api-key' + + +def test_huggingface_provider_need_api_key(env: TestEnv) -> None: + env.remove('HF_TOKEN') + with pytest.raises( + UserError, + match=re.escape( + 'Set the `HF_TOKEN` environment variable or pass it via `HuggingFaceProvider(api_key=...)`' + 'to use the HuggingFace provider.' + ), + ): + HuggingFaceProvider() + + +def test_huggingface_provider_pass_http_client() -> None: + http_client = httpx.AsyncClient() + with pytest.raises( + ValueError, + match=re.escape('`http_client` is ignored for HuggingFace provider, please use `hf_client` instead'), + ): + HuggingFaceProvider(http_client=http_client, api_key='api-key') # type: ignore + + +def test_huggingface_provider_pass_hf_client() -> None: + hf_client = AsyncInferenceClient(api_key='api-key') + provider = HuggingFaceProvider(hf_client=hf_client, api_key='api-key') + assert provider.client == hf_client + + +def test_hf_provider_with_base_url() -> None: + # Test with environment variable for base_url + provider = HuggingFaceProvider( + hf_client=AsyncInferenceClient(base_url='https://router.huggingface.co/nebius/v1'), api_key='test-api-key' + ) + assert provider.base_url == 'https://router.huggingface.co/nebius/v1' + + +def test_huggingface_provider_properties(): + mock_client = Mock(spec=AsyncInferenceClient) + mock_client.model = 'test-model' + provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') + assert provider.name == 'huggingface' + assert provider.client is mock_client + + +def test_huggingface_provider_init_api_key_error(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('HF_TOKEN', raising=False) + with pytest.raises(UserError, match='Set the `HF_TOKEN` environment variable'): + HuggingFaceProvider() + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_api_key_from_env( + MockAsyncInferenceClient: MagicMock, monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setenv('HF_TOKEN', 'env-key') + HuggingFaceProvider() + MockAsyncInferenceClient.assert_called_with(api_key='env-key', provider=None, base_url=None) + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_api_key_from_arg( + MockAsyncInferenceClient: MagicMock, 
monkeypatch: pytest.MonkeyPatch +): + monkeypatch.setenv('HF_TOKEN', 'env-key') + HuggingFaceProvider(api_key='arg-key') + MockAsyncInferenceClient.assert_called_with(api_key='arg-key', provider=None, base_url=None) + + +def test_huggingface_provider_init_http_client_error(): + with pytest.raises(ValueError, match='`http_client` is ignored'): + HuggingFaceProvider(api_key='key', http_client=Mock()) # type: ignore[call-overload] + + +def test_huggingface_provider_init_base_url_and_provider_name_error(): + with pytest.raises(ValueError, match='Cannot provide both `base_url` and `provider_name`'): + HuggingFaceProvider(api_key='key', base_url='url', provider_name='provider') # type: ignore[call-overload] + + +def test_huggingface_provider_init_with_hf_client(): + mock_client = Mock(spec=AsyncInferenceClient) + provider = HuggingFaceProvider(hf_client=mock_client, api_key='key') + assert provider.client is mock_client + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_without_hf_client(MockAsyncInferenceClient: MagicMock): + provider = HuggingFaceProvider(api_key='key') + assert provider.client is MockAsyncInferenceClient.return_value + MockAsyncInferenceClient.assert_called_with(api_key='key', provider=None, base_url=None) + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_with_provider_name(MockAsyncInferenceClient: MagicMock): + HuggingFaceProvider(api_key='key', provider_name='test-provider') + MockAsyncInferenceClient.assert_called_once_with(api_key='key', provider='test-provider', base_url=None) + + +@patch('pydantic_ai.providers.huggingface.AsyncInferenceClient') +def test_huggingface_provider_init_with_base_url(MockAsyncInferenceClient: MagicMock): + HuggingFaceProvider(api_key='key', base_url='test-url') + MockAsyncInferenceClient.assert_called_once_with(api_key='key', provider=None, base_url='test-url') + + +def test_huggingface_provider_init_api_key_is_none(monkeypatch: pytest.MonkeyPatch): + monkeypatch.delenv('HF_TOKEN', raising=False) + with pytest.raises(UserError): + HuggingFaceProvider(api_key=None) + + +def test_huggingface_provider_base_url(): + mock_client = Mock(spec=AsyncInferenceClient) + mock_client.model = 'test-model' + provider = HuggingFaceProvider(hf_client=mock_client, api_key='test-api-key') + assert provider.base_url == 'test-model' diff --git a/tests/test_cli.py b/tests/test_cli.py index 024116249c..8efc0da005 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -144,6 +144,7 @@ def test_list_models(capfd: CaptureFixture[str]): 'cohere', 'deepseek', 'heroku', + 'huggingface', ) models = {line.strip().split(' ')[0] for line in output[3:]} for provider in providers: diff --git a/tests/test_examples.py b/tests/test_examples.py index c1dda22a49..5735a11ffa 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -149,6 +149,7 @@ def print(self, *args: Any, **kwargs: Any) -> None: env.set('CO_API_KEY', 'testing') env.set('MISTRAL_API_KEY', 'testing') env.set('ANTHROPIC_API_KEY', 'testing') + env.set('HF_TOKEN', 'hf_testing') env.set('AWS_ACCESS_KEY_ID', 'testing') env.set('AWS_SECRET_ACCESS_KEY', 'testing') env.set('AWS_DEFAULT_REGION', 'us-east-1') diff --git a/uv.lock b/uv.lock index 2b2c5c2c3b..bea12cf7af 100644 --- a/uv.lock +++ b/uv.lock @@ -829,16 +829,16 @@ wheels = [ [[package]] name = "ddgs" -version = "9.0.0" +version = "9.2.3" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "lxml" }, 
{ name = "primp" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1f/08/0e84549a1d7d5950573f73d7bc5d36f2a00f92ad8e644b59066afd430a6f/ddgs-9.0.0.tar.gz", hash = "sha256:53b47c74a8060457cb02cbb64acdf59655d799ce8e0934e945bcd878fcab3a7f", size = 21795, upload-time = "2025-07-06T15:43:50.862Z" } +sdist = { url = "https://files.pythonhosted.org/packages/57/e1/8154854084b24908ec782f1c2713a66b205bdcd2b20a9bc3ce274afccc24/ddgs-9.2.3.tar.gz", hash = "sha256:5ec4e0bf0a9055a991c958695b1c0194c2511d254449ab88eb874297879ed1a5", size = 26553, upload-time = "2025-07-14T17:17:24.232Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/e5/05/bd3ed9a28212b313f5678533152c4d79fbc386e44245ca5eed426d75f019/ddgs-9.0.0-py3-none-any.whl", hash = "sha256:5dd11d666d6caf1cfdbd341579637bb670c4b2f41df5413b76705519d8e7a22c", size = 17944, upload-time = "2025-07-06T15:43:49.564Z" }, + { url = "https://files.pythonhosted.org/packages/1d/af/d42b3f4eff55cdcddf8b33631be602e40d63d7cf0cffcf15503166a46b22/ddgs-9.2.3-py3-none-any.whl", hash = "sha256:4b658edf52db3bfe80c12492077e7cc9d39312b0dbb03f8669753ac1313d3784", size = 30148, upload-time = "2025-07-14T17:17:22.969Z" }, ] [[package]] @@ -942,16 +942,16 @@ wheels = [ [[package]] name = "duckduckgo-search" -version = "8.1.1" +version = "7.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "click" }, { name = "lxml" }, { name = "primp" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/10/ef/07791a05751e6cc9de1dd49fb12730259ee109b18e6d097e25e6c32d5617/duckduckgo_search-8.1.1.tar.gz", hash = "sha256:9da91c9eb26a17e016ea1da26235d40404b46b0565ea86d75a9f78cc9441f935", size = 22868, upload-time = "2025-07-06T15:30:59.73Z" } +sdist = { url = "https://files.pythonhosted.org/packages/17/a8/18404f6525aefa80290afa920ed76fbab16472f19015fdb957b7113f3a9e/duckduckgo_search-7.5.0.tar.gz", hash = "sha256:3e28dc5ec9188efa3a7c8532aa05aaf03bb34b79370855760abd55e6051ff79b", size = 24657, upload-time = "2025-02-24T14:50:49.356Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/db/72/c027b3b488b1010cf71670032fcf7e681d44b81829d484bb04e31a949a8d/duckduckgo_search-8.1.1-py3-none-any.whl", hash = "sha256:f48adbb06626ee05918f7e0cef3a45639e9939805c4fc179e68c48a12f1b5062", size = 18932, upload-time = "2025-07-06T15:30:58.339Z" }, + { url = "https://files.pythonhosted.org/packages/75/21/fc2c821a2c92c021f8f8adf9fb36235d1b49525b7cd953e85624296aab94/duckduckgo_search-7.5.0-py3-none-any.whl", hash = "sha256:6a2d3f12ae29b3e076cd43be61f5f73cd95261e0a0f318fe0ad3648d7a5dff03", size = 20238, upload-time = "2025-02-24T14:50:48.179Z" }, ] [[package]] @@ -992,16 +992,17 @@ wheels = [ [[package]] name = "fasta2a" -version = "0.4.1" +version = "0.5.0" source = { registry = "https://pypi.org/simple" } dependencies = [ + { name = "eval-type-backport", marker = "python_full_version < '3.10'" }, { name = "opentelemetry-api" }, { name = "pydantic" }, { name = "starlette" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/1e/65/3728453396e5efa6166cf58b32e5aef7dabeba438c8bb20e1c9461fceaed/fasta2a-0.4.1.tar.gz", hash = "sha256:2c664d572480662a73201485ce0f909d607d5d28ba33409646454cef5d0645ed", size = 13966, upload-time = "2025-07-10T08:18:55.859Z" } +sdist = { url = "https://files.pythonhosted.org/packages/5d/2a/f9d212026bdc74068ef9aef493a2b37ce0d4201694d158180759e07489b5/fasta2a-0.5.0.tar.gz", hash = "sha256:0bca45f675fb3354ae6cd0e6dd0be1d504ee135b8e802b4058fb3485521f61e9", size = 1436123, upload-time = "2025-07-10T16:31:01.502Z" } 
wheels = [ - { url = "https://files.pythonhosted.org/packages/8d/1e/e85fcd71af3a3f6c8262c89027a5c5e6c3faf348b1f9101b79d996c801df/fasta2a-0.4.1-py3-none-any.whl", hash = "sha256:f0b4a8162bd7fc9a363ef3724c395c5cb97e87d9d03769b9faaf675389c7fdfb", size = 16934, upload-time = "2025-07-10T08:18:45.014Z" }, + { url = "https://files.pythonhosted.org/packages/c5/08/d25f303013a04e2bec68ed97c4f4f85ad9c178fc582e8e4345147fd141fb/fasta2a-0.5.0-py3-none-any.whl", hash = "sha256:806f4bbd6cd2858ca631d47e75f3bbf4746ff0752ccca38edbfe85930c4ffbe2", size = 25198, upload-time = "2025-07-10T16:30:59.938Z" }, ] [[package]] @@ -1195,7 +1196,7 @@ wheels = [ [[package]] name = "google-genai" -version = "1.24.0" +version = "1.25.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -1207,9 +1208,9 @@ dependencies = [ { name = "typing-extensions" }, { name = "websockets" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/8d/cf/37ac8cd4752e28e547b8a52765fe48a2ada2d0d286ea03f46e4d8c69ff4f/google_genai-1.24.0.tar.gz", hash = "sha256:bc896e30ad26d05a2af3d17c2ba10ea214a94f1c0cdb93d5c004dc038774e75a", size = 226740, upload-time = "2025-07-01T22:14:24.365Z" } +sdist = { url = "https://files.pythonhosted.org/packages/7f/59/c9b9148c8702b60253f5a251f16ae436534c5d4362da193c9db05ac9858c/google_genai-1.25.0.tar.gz", hash = "sha256:a08a79c819a5d949d9948cd372e36e512bf85cd28158994daaa36d0ec4cb2b02", size = 228141, upload-time = "2025-07-09T20:53:47.885Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/30/28/a35f64fc02e599808101617a21d447d241dadeba2aac1f4dc2d1179b8218/google_genai-1.24.0-py3-none-any.whl", hash = "sha256:98be8c51632576289ecc33cd84bcdaf4356ef0bef04ac7578660c49175af22b9", size = 226065, upload-time = "2025-07-01T22:14:23.177Z" }, + { url = "https://files.pythonhosted.org/packages/f6/ec/149f3d49b56cf848142071772aabb1c290b535bd9b5327a5dfccf1d00332/google_genai-1.25.0-py3-none-any.whl", hash = "sha256:fb5cee79b9a0a1b2afd5cfdf279099ecebd186551eefcaa6ec0c6016244e6138", size = 226847, upload-time = "2025-07-09T20:53:46.532Z" }, ] [[package]] @@ -1340,6 +1341,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d0/9e/984486f2d0a0bd2b024bf4bc1c62688fcafa9e61991f041fb0e2def4a982/h2-4.2.0-py3-none-any.whl", hash = "sha256:479a53ad425bb29af087f3458a61d30780bc818e4ebcf01f0b536ba916462ed0", size = 60957, upload-time = "2025-02-01T11:02:26.481Z" }, ] +[[package]] +name = "hf-xet" +version = "1.1.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/75/dc/dc091aeeb671e71cbec30e84963f9c0202c17337b24b0a800e7d205543e8/hf_xet-1.1.3.tar.gz", hash = "sha256:a5f09b1dd24e6ff6bcedb4b0ddab2d81824098bb002cf8b4ffa780545fa348c3", size = 488127, upload-time = "2025-06-04T00:47:27.456Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/9b/1f/bc01a4c0894973adebbcd4aa338a06815c76333ebb3921d94dcbd40dae6a/hf_xet-1.1.3-cp37-abi3-macosx_10_12_x86_64.whl", hash = "sha256:c3b508b5f583a75641aebf732853deb058953370ce8184f5dabc49f803b0819b", size = 2256929, upload-time = "2025-06-04T00:47:21.206Z" }, + { url = "https://files.pythonhosted.org/packages/78/07/6ef50851b5c6b45b77a6e018fa299c69a2db3b8bbd0d5af594c0238b1ceb/hf_xet-1.1.3-cp37-abi3-macosx_11_0_arm64.whl", hash = "sha256:b788a61977fbe6b5186e66239e2a329a3f0b7e7ff50dad38984c0c74f44aeca1", size = 2153719, upload-time = "2025-06-04T00:47:19.302Z" }, + { url = 
"https://files.pythonhosted.org/packages/52/48/e929e6e3db6e4758c2adf0f2ca2c59287f1b76229d8bdc1a4c9cfc05212e/hf_xet-1.1.3-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:fd2da210856444a34aad8ada2fc12f70dabed7cc20f37e90754d1d9b43bc0534", size = 4820519, upload-time = "2025-06-04T00:47:17.244Z" }, + { url = "https://files.pythonhosted.org/packages/28/2e/03f89c5014a5aafaa9b150655f811798a317036646623bdaace25f485ae8/hf_xet-1.1.3-cp37-abi3-manylinux_2_28_aarch64.whl", hash = "sha256:8203f52827e3df65981984936654a5b390566336956f65765a8aa58c362bb841", size = 4964121, upload-time = "2025-06-04T00:47:15.17Z" }, + { url = "https://files.pythonhosted.org/packages/47/8b/5cd399a92b47d98086f55fc72d69bc9ea5e5c6f27a9ed3e0cdd6be4e58a3/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_aarch64.whl", hash = "sha256:30c575a5306f8e6fda37edb866762140a435037365eba7a17ce7bd0bc0216a8b", size = 5283017, upload-time = "2025-06-04T00:47:23.239Z" }, + { url = "https://files.pythonhosted.org/packages/53/e3/2fcec58d2fcfd25ff07feb876f466cfa11f8dcf9d3b742c07fe9dd51ee0a/hf_xet-1.1.3-cp37-abi3-musllinux_1_2_x86_64.whl", hash = "sha256:7c1a6aa6abed1f696f8099aa9796ca04c9ee778a58728a115607de9cc4638ff1", size = 4970349, upload-time = "2025-06-04T00:47:25.383Z" }, + { url = "https://files.pythonhosted.org/packages/53/bf/10ca917e335861101017ff46044c90e517b574fbb37219347b83be1952f6/hf_xet-1.1.3-cp37-abi3-win_amd64.whl", hash = "sha256:b578ae5ac9c056296bb0df9d018e597c8dc6390c5266f35b5c44696003cde9f3", size = 2310934, upload-time = "2025-06-04T00:47:29.632Z" }, +] + [[package]] name = "hpack" version = "4.1.0" @@ -1388,20 +1404,26 @@ wheels = [ [[package]] name = "huggingface-hub" -version = "0.29.1" +version = "0.33.2" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "filelock" }, { name = "fsspec" }, + { name = "hf-xet", marker = "platform_machine == 'aarch64' or platform_machine == 'amd64' or platform_machine == 'arm64' or platform_machine == 'x86_64'" }, { name = "packaging" }, { name = "pyyaml" }, { name = "requests" }, { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/22/37/797d6476f13e5ef6af5fc48a5d641d32b39c37e166ccf40c3714c5854a85/huggingface_hub-0.29.1.tar.gz", hash = "sha256:9524eae42077b8ff4fc459ceb7a514eca1c1232b775276b009709fe2a084f250", size = 389776, upload-time = "2025-02-20T09:24:59.839Z" } +sdist = { url = "https://files.pythonhosted.org/packages/fa/42/8a95c5632080ae312c0498744b2b852195e10b05a20b1be11c5141092f4c/huggingface_hub-0.33.2.tar.gz", hash = "sha256:84221defaec8fa09c090390cd68c78b88e3c4c2b7befba68d3dc5aacbc3c2c5f", size = 426637, upload-time = "2025-07-02T06:26:05.156Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/ae/05/75b90de9093de0aadafc868bb2fa7c57651fd8f45384adf39bd77f63980d/huggingface_hub-0.29.1-py3-none-any.whl", hash = "sha256:352f69caf16566c7b6de84b54a822f6238e17ddd8ae3da4f8f2272aea5b198d5", size = 468049, upload-time = "2025-02-20T09:24:57.962Z" }, + { url = "https://files.pythonhosted.org/packages/44/f4/5f3f22e762ad1965f01122b42dae5bf0e009286e2dba601ce1d0dba72424/huggingface_hub-0.33.2-py3-none-any.whl", hash = "sha256:3749498bfa91e8cde2ddc2c1db92c79981f40e66434c20133b39e5928ac9bcc5", size = 515373, upload-time = "2025-07-02T06:26:03.072Z" }, +] + +[package.optional-dependencies] +inference = [ + { name = "aiohttp" }, ] [[package]] @@ -2966,7 +2988,7 @@ wheels = [ name = "pydantic-ai" source = { editable = "." 
} dependencies = [ - { name = "pydantic-ai-slim", extra = ["anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] }, ] [package.optional-dependencies] @@ -3005,7 +3027,7 @@ requires-dist = [ { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["a2a", "examples", "logfire"] @@ -3107,6 +3129,9 @@ google = [ groq = [ { name = "groq" }, ] +huggingface = [ + { name = "huggingface-hub", extra = ["inference"] }, +] logfire = [ { name = "logfire" }, ] @@ -3163,6 +3188,7 @@ requires-dist = [ { name = "griffe", specifier = ">=1.3.2" }, { name = "groq", marker = "extra == 'groq'", specifier = ">=0.19.0" }, { name = "httpx", specifier = ">=0.27" }, + { name = "huggingface-hub", extras = ["inference"], marker = "extra == 'huggingface'", specifier = ">=0.33.2" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.9.4" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, @@ -3177,7 +3203,7 @@ requires-dist = [ { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] [package.metadata.requires-dev] dev = [ From 849aa4c9529e0f0cc9e8756eaf0bec0494edd119 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 16 Jul 2025 12:45:42 -0600 Subject: [PATCH 20/89] Toolsets (#2024) --- docs/agents.md | 2 +- docs/api/ext.md | 5 + docs/api/output.md | 1 + docs/api/toolsets.md | 14 + docs/mcp/client.md | 159 ++--- docs/models/huggingface.md | 2 +- docs/output.md | 6 +- docs/testing.md | 2 +- docs/tools.md | 94 ++- docs/toolsets.md | 633 ++++++++++++++++++ mcp-run-python/README.md | 4 +- mkdocs.yml | 6 +- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 522 ++++++--------- pydantic_ai_slim/pydantic_ai/_output.py | 626 ++++++++--------- pydantic_ai_slim/pydantic_ai/_run_context.py | 22 +- pydantic_ai_slim/pydantic_ai/_tool_manager.py | 190 ++++++ pydantic_ai_slim/pydantic_ai/_utils.py | 19 +- pydantic_ai_slim/pydantic_ai/agent.py | 406 ++++++----- pydantic_ai_slim/pydantic_ai/exceptions.py | 12 + pydantic_ai_slim/pydantic_ai/ext/aci.py | 15 +- pydantic_ai_slim/pydantic_ai/ext/langchain.py | 10 +- pydantic_ai_slim/pydantic_ai/mcp.py | 231 ++++--- pydantic_ai_slim/pydantic_ai/output.py | 25 +- pydantic_ai_slim/pydantic_ai/result.py | 103 +-- 
pydantic_ai_slim/pydantic_ai/tools.py | 145 +--- .../pydantic_ai/toolsets/__init__.py | 22 + .../pydantic_ai/toolsets/abstract.py | 165 +++++ .../pydantic_ai/toolsets/combined.py | 88 +++ .../pydantic_ai/toolsets/deferred.py | 38 ++ .../pydantic_ai/toolsets/filtered.py | 24 + .../pydantic_ai/toolsets/function.py | 238 +++++++ .../pydantic_ai/toolsets/prefixed.py | 37 + .../pydantic_ai/toolsets/prepared.py | 36 + .../pydantic_ai/toolsets/renamed.py | 42 ++ .../pydantic_ai/toolsets/wrapper.py | 37 + .../test_agent_with_server_not_running.yaml | 391 +++++++++++ tests/ext/test_langchain.py | 45 +- tests/models/test_anthropic.py | 2 +- tests/models/test_gemini.py | 37 +- tests/models/test_model_test.py | 5 +- tests/test_a2a.py | 11 +- tests/test_agent.py | 346 +++++++++- tests/test_examples.py | 42 +- tests/test_logfire.py | 173 +++-- tests/test_mcp.py | 123 ++-- tests/test_streaming.py | 165 ++++- tests/test_tools.py | 222 +++++- tests/test_toolsets.py | 471 +++++++++++++ tests/typed_agent.py | 16 +- 49 files changed, 4617 insertions(+), 1413 deletions(-) create mode 100644 docs/api/ext.md create mode 100644 docs/api/toolsets.md create mode 100644 docs/toolsets.md create mode 100644 pydantic_ai_slim/pydantic_ai/_tool_manager.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/__init__.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/abstract.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/combined.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/deferred.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/filtered.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/function.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/prepared.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/renamed.py create mode 100644 pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py create mode 100644 tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml create mode 100644 tests/test_toolsets.py diff --git a/docs/agents.md b/docs/agents.md index 208d92a583..5b332e0b27 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -826,7 +826,7 @@ with capture_run_messages() as messages: # (2)! 
result = agent.run_sync('Please get me the volume of a box with size 6.') except UnexpectedModelBehavior as e: print('An error occurred:', e) - #> An error occurred: Tool exceeded max retries count of 1 + #> An error occurred: Tool 'calc_volume' exceeded max retries count of 1 print('cause:', repr(e.__cause__)) #> cause: ModelRetry('Please try again.') print('messages:', messages) diff --git a/docs/api/ext.md b/docs/api/ext.md new file mode 100644 index 0000000000..7f01b44d45 --- /dev/null +++ b/docs/api/ext.md @@ -0,0 +1,5 @@ +# `pydantic_ai.ext` + +::: pydantic_ai.ext.langchain + +::: pydantic_ai.ext.aci diff --git a/docs/api/output.md b/docs/api/output.md index 135ff597bc..bb584608c7 100644 --- a/docs/api/output.md +++ b/docs/api/output.md @@ -10,3 +10,4 @@ - PromptedOutput - TextOutput - StructuredDict + - DeferredToolCalls diff --git a/docs/api/toolsets.md b/docs/api/toolsets.md new file mode 100644 index 0000000000..8146864076 --- /dev/null +++ b/docs/api/toolsets.md @@ -0,0 +1,14 @@ +# `pydantic_ai.toolsets` + +::: pydantic_ai.toolsets + options: + members: + - AbstractToolset + - CombinedToolset + - DeferredToolset + - FilteredToolset + - FunctionToolset + - PrefixedToolset + - RenamedToolset + - PreparedToolset + - WrapperToolset diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 7f8c5fdd6a..15ef46f2e2 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -16,42 +16,54 @@ pip/uv-add "pydantic-ai-slim[mcp]" ## Usage -PydanticAI comes with two ways to connect to MCP servers: +PydanticAI comes with three ways to connect to MCP servers: -- [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] which connects to an MCP server using the [HTTP SSE](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) transport - [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] which connects to an MCP server using the [Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport +- [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] which connects to an MCP server using the [HTTP SSE](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) transport - [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] which runs the server as a subprocess and connects to it using the [stdio](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) transport -Examples of both are shown below; [mcp-run-python](run-python.md) is used as the MCP server in both examples. +Examples of all three are shown below; [mcp-run-python](run-python.md) is used as the MCP server in all examples. -### SSE Client +Each MCP server instance is a [toolset](../toolsets.md) and can be registered with an [`Agent`][pydantic_ai.Agent] using the `toolsets` argument. -[`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server. +You can use the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager to open and close connections to all registered servers (and in the case of stdio servers, start and stop the subprocesses) around the context where they'll be used in agent runs. You can also use [`async with server`][pydantic_ai.mcp.MCPServer.__aenter__] to manage the connection or subprocess of a specific server, for example if you'd like to use it with multiple agents. 
If you don't explicitly enter one of these context managers to set up the server, this will be done automatically when it's needed (e.g. to list the available tools or call a specific tool), but it's more efficient to do so around the entire context where you expect the servers to be used. + +### Streamable HTTP Client + +[`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] connects over HTTP using the +[Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport to a server. !!! note - [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before calling [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not managed by PydanticAI. + [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be + running and accepting HTTP connections before running the agent. Running the server is not + managed by Pydantic AI. -The name "HTTP" is used since this implementation will be adapted in future to use the new -[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. +Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport. -Before creating the SSE client, we need to run the server (docs [here](run-python.md)): +```python {title="streamable_http_server.py" py="3.10" dunder_name="not_main"} +from mcp.server.fastmcp import FastMCP -```bash {title="terminal (run sse server)"} -deno run \ - -N -R=node_modules -W=node_modules --node-modules-dir=auto \ - jsr:@pydantic/mcp-run-python sse +app = FastMCP() + +@app.tool() +def add(a: int, b: int) -> int: + return a + b + +if __name__ == '__main__': + app.run(transport='streamable-http') ``` -```python {title="mcp_sse_client.py" py="3.10"} -from pydantic_ai import Agent -from pydantic_ai.mcp import MCPServerSSE +Then we can create the client: -server = MCPServerSSE(url='http://localhost:3001/sse') # (1)! -agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! +```python {title="mcp_streamable_http_client.py" py="3.10"} +from pydantic_ai import Agent +from pydantic_ai.mcp import MCPServerStreamableHTTP +server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! +agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! async def main(): - async with agent.run_mcp_servers(): # (3)! + async with agent: # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -85,43 +97,34 @@ Will display as follows: ![Logfire run python code](../img/logfire-run-python-code.png) -### Streamable HTTP Client +### SSE Client -[`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] connects over HTTP using the -[Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport to a server. +[`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] connects over HTTP using the [HTTP + Server Sent Events transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) to a server. !!! note - [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be - running and accepting HTTP connections before calling - [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers]. Running the server is not - managed by PydanticAI. 
- -Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport. - -```python {title="streamable_http_server.py" py="3.10" dunder_name="not_main"} -from mcp.server.fastmcp import FastMCP + [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI. -app = FastMCP() +The name "HTTP" is used since this implementation will be adapted in future to use the new +[Streamable HTTP](https://github.com/modelcontextprotocol/specification/pull/206) currently in development. -@app.tool() -def add(a: int, b: int) -> int: - return a + b +Before creating the SSE client, we need to run the server (docs [here](run-python.md)): -if __name__ == '__main__': - app.run(transport='streamable-http') +```bash {title="terminal (run sse server)"} +deno run \ + -N -R=node_modules -W=node_modules --node-modules-dir=auto \ + jsr:@pydantic/mcp-run-python sse ``` -Then we can create the client: - -```python {title="mcp_streamable_http_client.py" py="3.10"} +```python {title="mcp_sse_client.py" py="3.10"} from pydantic_ai import Agent -from pydantic_ai.mcp import MCPServerStreamableHTTP +from pydantic_ai.mcp import MCPServerSSE + +server = MCPServerSSE(url='http://localhost:3001/sse') # (1)! +agent = Agent('openai:gpt-4o', toolsets=[server]) # (2)! -server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! -agent = Agent('openai:gpt-4o', mcp_servers=[server]) # (2)! async def main(): - async with agent.run_mcp_servers(): # (3)! + async with agent: # (3)! result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. @@ -137,9 +140,6 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n The other transport offered by MCP is the [stdio transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) where the server is run as a subprocess and communicates with the client over `stdin` and `stdout`. In this case, you'd use the [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] class. -!!! note - When using [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio] servers, the [`agent.run_mcp_servers()`][pydantic_ai.Agent.run_mcp_servers] context manager is responsible for starting and stopping the server. - ```python {title="mcp_stdio_client.py" py="3.10"} from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio @@ -156,11 +156,11 @@ server = MCPServerStdio( # (1)! 'stdio', ] ) -agent = Agent('openai:gpt-4o', mcp_servers=[server]) +agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025. 
@@ -188,23 +188,23 @@ from pydantic_ai.tools import RunContext async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, - tool_name: str, - args: dict[str, Any], + name: str, + tool_args: dict[str, Any], ) -> ToolResult: """A tool call processor that passes along the deps.""" - return await call_tool(tool_name, args, metadata={'deps': ctx.deps}) + return await call_tool(name, tool_args, {'deps': ctx.deps}) server = MCPServerStdio('python', ['mcp_server.py'], process_tool_call=process_tool_call) agent = Agent( model=TestModel(call_tools=['echo_deps']), deps_type=int, - mcp_servers=[server] + toolsets=[server] ) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Echo with deps set to 42', deps=42) print(result.output) #> {"echo_deps":{"echo":"This is an echo message","deps":42}} @@ -214,15 +214,7 @@ async def main(): When connecting to multiple MCP servers that might provide tools with the same name, you can use the `tool_prefix` parameter to avoid naming conflicts. This parameter adds a prefix to all tool names from a specific server. -### How It Works - -- If `tool_prefix` is set, all tools from that server will be prefixed with `{tool_prefix}_` -- When listing tools, the prefixed names are shown to the model -- When calling tools, the prefix is automatically removed before sending the request to the server - -This allows you to use multiple servers that might have overlapping tool names without conflicts. - -### Example with HTTP Server +This allows you to use multiple servers that might have overlapping tool names without conflicts: ```python {title="mcp_tool_prefix_http_client.py" py="3.10"} from pydantic_ai import Agent @@ -242,41 +234,9 @@ calculator_server = MCPServerSSE( # Both servers might have a tool named 'get_data', but they'll be exposed as: # - 'weather_get_data' # - 'calc_get_data' -agent = Agent('openai:gpt-4o', mcp_servers=[weather_server, calculator_server]) -``` - -### Example with Stdio Server - -```python {title="mcp_tool_prefix_stdio_client.py" py="3.10"} -from pydantic_ai import Agent -from pydantic_ai.mcp import MCPServerStdio - -python_server = MCPServerStdio( - 'deno', - args=[ - 'run', - '-N', - 'jsr:@pydantic/mcp-run-python', - 'stdio', - ], - tool_prefix='py' # Tools will be prefixed with 'py_' -) - -js_server = MCPServerStdio( - 'node', - args=[ - 'run', - 'mcp-js-server.js', - 'stdio', - ], - tool_prefix='js' # Tools will be prefixed with 'js_' -) - -agent = Agent('openai:gpt-4o', mcp_servers=[python_server, js_server]) +agent = Agent('openai:gpt-4o', toolsets=[weather_server, calculator_server]) ``` -When the model interacts with these servers, it will see the prefixed tool names, but the prefixes will be automatically handled when making tool calls. - ## MCP Sampling !!! info "What is MCP Sampling?" @@ -312,6 +272,8 @@ Pydantic AI supports sampling as both a client and server. See the [server](./se Sampling is automatically supported by Pydantic AI agents when they act as a client. +To be able to use sampling, an MCP server instance needs to have a [`sampling_model`][pydantic_ai.mcp.MCPServerStdio.sampling_model] set. This can be done either directly on the server using the constructor keyword argument or the property, or by using [`agent.set_mcp_sampling_model()`][pydantic_ai.Agent.set_mcp_sampling_model] to set the agent's model or one specified as an argument as the sampling model on all MCP servers registered with that agent. 
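+
+For example, a minimal sketch of the constructor route, using the `generate_svg.py` server from the example below; the choice of sampling model is just a placeholder:
+
+```python {test="skip"}
+from pydantic_ai import Agent
+from pydantic_ai.mcp import MCPServerStdio
+from pydantic_ai.models.openai import OpenAIModel
+
+server = MCPServerStdio(
+    command='python',
+    args=['generate_svg.py'],
+    sampling_model=OpenAIModel('gpt-4o'),  # handles sampling requests coming from this server
+)
+
+agent = Agent('openai:gpt-4o', toolsets=[server])
+```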
+ Let's say we have an MCP server that wants to use sampling (in this case to generate an SVG as per the tool arguments). ??? example "Sampling MCP Server" @@ -359,11 +321,12 @@ from pydantic_ai import Agent from pydantic_ai.mcp import MCPServerStdio server = MCPServerStdio(command='python', args=['generate_svg.py']) -agent = Agent('openai:gpt-4o', mcp_servers=[server]) +agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent: + agent.set_mcp_sampling_model() result = await agent.run('Create an image of a robot in a punk style.') print(result.output) #> Image file written to robot_punk.svg. diff --git a/docs/models/huggingface.md b/docs/models/huggingface.md index 61f8eef35f..4e6e2ee795 100644 --- a/docs/models/huggingface.md +++ b/docs/models/huggingface.md @@ -69,7 +69,7 @@ agent = Agent(model) ## Custom Hugging Face client [`HuggingFaceProvider`][pydantic_ai.providers.huggingface.HuggingFaceProvider] also accepts a custom -[`AsyncInferenceClient`][huggingface_hub.AsyncInferenceClient] client via the `hf_client` parameter, so you can customise +[`AsyncInferenceClient`](https://huggingface.co/docs/huggingface_hub/v0.29.3/en/package_reference/inference_client#huggingface_hub.AsyncInferenceClient) client via the `hf_client` parameter, so you can customise the `headers`, `bill_to` (billing to an HF organization you're a member of), `base_url` etc. as defined in the [Hugging Face Hub python library docs](https://huggingface.co/docs/huggingface_hub/package_reference/inference_client). diff --git a/docs/output.md b/docs/output.md index caa0c14b0f..2391a5efde 100644 --- a/docs/output.md +++ b/docs/output.md @@ -199,8 +199,8 @@ async def hand_off_to_sql_agent(ctx: RunContext, query: str) -> list[Row]: return output except UnexpectedModelBehavior as e: # Bubble up potentially retryable errors to the router agent - if (cause := e.__cause__) and hasattr(cause, 'tool_retry'): - raise ModelRetry(f'SQL agent failed: {cause.tool_retry.content}') from e + if (cause := e.__cause__) and isinstance(cause, ModelRetry): + raise ModelRetry(f'SQL agent failed: {cause.message}') from e else: raise @@ -276,6 +276,8 @@ In the default Tool Output mode, the output JSON schema of each output type (or If you'd like to change the name of the output tool, pass a custom description to aid the model, or turn on or off strict mode, you can wrap the type(s) in the [`ToolOutput`][pydantic_ai.output.ToolOutput] marker class and provide the appropriate arguments. Note that by default, the description is taken from the docstring specified on a Pydantic model or output function, so specifying it using the marker class is typically not necessary. +To dynamically modify or filter the available output tools during an agent run, you can define an agent-wide `prepare_output_tools` function that will be called ahead of each step of a run. This function should be of type [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc], which takes the [`RunContext`][pydantic_ai.tools.RunContext] and a list of [`ToolDefinition`][pydantic_ai.tools.ToolDefinition], and returns a new list of tool definitions (or `None` to disable all tools for that step). This is analogous to the [`prepare_tools` function](tools.md#prepare-tools) for non-output tools. 
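+
+As a rough sketch, assuming the function is passed via a `prepare_output_tools` keyword argument to the `Agent` constructor (mirroring `prepare_tools`), this marks every output tool as strict:
+
+```python {test="skip"}
+from dataclasses import replace
+from typing import Union
+
+from pydantic import BaseModel
+
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.tools import ToolDefinition
+
+
+class Answer(BaseModel):
+    answer: str
+
+
+async def make_output_tools_strict(
+    ctx: RunContext, tool_defs: list[ToolDefinition]
+) -> Union[list[ToolDefinition], None]:
+    # Return modified copies of the output tool definitions for this step.
+    return [replace(tool_def, strict=True) for tool_def in tool_defs]
+
+
+agent = Agent(
+    'openai:gpt-4o',
+    output_type=Answer,
+    prepare_output_tools=make_output_tools_strict,  # assumed to mirror `prepare_tools`
+)
+```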
+ ```python {title="tool_output.py"} from pydantic import BaseModel diff --git a/docs/testing.md b/docs/testing.md index ac7d2ea249..b40bb1dc9c 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -10,7 +10,7 @@ Unless you're really sure you know better, you'll probably want to follow roughl * If you find yourself typing out long assertions, use [inline-snapshot](https://15r10nk.github.io/inline-snapshot/latest/) * Similarly, [dirty-equals](https://dirty-equals.helpmanual.io/latest/) can be useful for comparing large data structures * Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the usage, latency and variability of real LLM calls -* Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace your model inside your application logic +* Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace an agent's model, dependencies, or toolsets inside your application logic * Set [`ALLOW_MODEL_REQUESTS=False`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] globally to block any requests from being made to non-test models accidentally ### Unit testing with `TestModel` diff --git a/docs/tools.md b/docs/tools.md index 44133f5759..134a8f96ea 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -1,25 +1,30 @@ # Function Tools -Function tools provide a mechanism for models to retrieve extra information to help them generate a response. +Function tools provide a mechanism for models to perform actions and retrieve extra information to help them generate a response. -They're useful when you want to enable the model to take some action and use the result, when it is impractical or impossible to put all the context an agent might need into the system prompt, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. +They're useful when you want to enable the model to take some action and use the result, when it is impractical or impossible to put all the context an agent might need into the instructions, or when you want to make agents' behavior more deterministic or reliable by deferring some of the logic required to generate a response to another (not necessarily AI-powered) tool. If you want a model to be able to call a function as its final action, without the result being sent back to the model, you can use an [output function](output.md#output-functions) instead. -!!! info "Function tools vs. RAG" - Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information. - - The main semantic difference between PydanticAI Tools and RAG is RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. 
See [#58](https://github.com/pydantic/pydantic-ai/issues/58))
-
 There are a number of ways to register tools with an agent:
 
 * via the [`@agent.tool`][pydantic_ai.Agent.tool] decorator — for tools that need access to the agent [context][pydantic_ai.tools.RunContext]
 * via the [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext]
 * via the [`tools`][pydantic_ai.Agent.__init__] keyword argument to `Agent` which can take either plain functions, or instances of [`Tool`][pydantic_ai.tools.Tool]
 
-## Registering Function Tools via Decorator
+For more advanced use cases, the [toolsets](toolsets.md) feature lets you manage collections of tools (built by you or provided by an [MCP server](mcp/client.md) or other [third party](#third-party-tools)) and register them with an agent in one go via the [`toolsets`][pydantic_ai.Agent.__init__] keyword argument to `Agent`.
+
+!!! info "Function tools vs. RAG"
+    Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information.
+
+    The main semantic difference between PydanticAI Tools and RAG is RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58))
 
-`@agent.tool` is considered the default decorator since in the majority of cases tools will need access to the agent context.
+!!! info "Function Tools vs. Structured Outputs"
+    As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for [structured output](output.md) when using the default [tool output mode](output.md#tool-output), thus a model might have access to many tools, some of which call function tools while others end the run and produce a final output.
+
+## Registering via Decorator {#registering-function-tools-via-decorator}
+
+`@agent.tool` is considered the default decorator since in the majority of cases tools will need access to the agent [context][pydantic_ai.tools.RunContext].
 
 Here's an example using both:
 
@@ -58,7 +63,7 @@ print(dice_result.output)
 1. This is a pretty simple task, so we can use the fast and cheap Gemini flash model.
 2. We pass the user's name as the dependency, to keep things simple we use just the name as a string as the dependency.
-3. This tool doesn't need any context, it just returns a random number. You could probably use a dynamic system prompt in this case.
+3. This tool doesn't need any context, it just returns a random number. You could probably use dynamic instructions in this case.
 4. This tool needs the player's name, so it uses `RunContext` to access dependencies which are just the player's name in this case.
 5. Run the agent, passing the player's name as the dependency.
 
@@ -176,7 +181,7 @@ sequenceDiagram
     Note over Agent: Game session complete
 ```
 
-## Registering Function Tools via Agent Argument
+## Registering via Agent Argument {#registering-function-tools-via-agent-argument}
 
 As well as using the decorators, we can register tools via the `tools` argument to the [`Agent` constructor][pydantic_ai.Agent.__init__]. This is useful when you want to reuse tools, and can also give more fine-grained control over the tools.
@@ -232,7 +237,7 @@ print(dice_result['b'].output) _(This example is complete, it can be run "as is")_ -## Function Tool Output +## Tool Output {#function-tool-output} Tools can return anything that Pydantic can serialize to JSON, as well as audio, video, image or document content depending on the types of [multi-modal input](input.md) the model supports: @@ -353,11 +358,7 @@ print(result.output) This separation allows you to provide rich context to the model while maintaining clean, structured return values for your application logic. -## Function Tools vs. Structured Outputs - -As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call. Tools or functions are also used to define the schema(s) for structured responses, thus a model might have access to many tools, some of which call function tools while others end the run and produce a final output. - -## Function tools and schema +## Tool Schema {#function-tools-and-schema} Function parameters are extracted from the function signature, and all parameters except `RunContext` are used to build the schema for that tool call. @@ -469,7 +470,9 @@ print(test_model.last_model_request_parameters.function_tools) _(This example is complete, it can be run "as is")_ -If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of *args or **kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the `Tool.from_schema` function. With this you provide the name, description and JSON schema for the function directly: +### Custom Tool Schema + +If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of *args or **kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the [`Tool.from_schema`][pydantic_ai.Tool.from_schema] function. With this you provide the name, description and JSON schema for the function directly: ```python from pydantic_ai import Agent, Tool @@ -505,7 +508,7 @@ print(result.output) Please note that validation of the tool arguments will not be performed, and this will pass all arguments as keyword arguments. -## Dynamic Function tools {#tool-prepare} +## Dynamic Tools {#tool-prepare} Tools can optionally be defined with another function: `prepare`, which is called at each step of a run to customize the definition of the tool passed to the model, or omit the tool completely from that step. @@ -606,14 +609,15 @@ print(test_model.last_model_request_parameters.function_tools) _(This example is complete, it can be run "as is")_ -## Agent-wide Dynamic Tool Preparation {#prepare-tools} +### Agent-wide Dynamic Tools {#prepare-tools} In addition to per-tool `prepare` methods, you can also define an agent-wide `prepare_tools` function. This function is called at each step of a run and allows you to filter or modify the list of all tool definitions available to the agent for that step. This is especially useful if you want to enable or disable multiple tools at once, or apply global logic based on the current context. The `prepare_tools` function should be of type [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc], which takes the [`RunContext`][pydantic_ai.tools.RunContext] and a list of [`ToolDefinition`][pydantic_ai.tools.ToolDefinition], and returns a new list of tool definitions (or `None` to disable all tools for that step). !!! 
note - The list of tool definitions passed to `prepare_tools` includes both regular tools and tools from any MCP servers attached to the agent. + The list of tool definitions passed to `prepare_tools` includes both regular function tools and tools from any [toolsets](toolsets.md) registered to the agent, but not [output tools](output.md#tool-output). + To modify output tools, you can set a `prepare_output_tools` function instead. Here's an example that makes all tools strict if the model is an OpenAI model: @@ -724,11 +728,11 @@ Raising `ModelRetry` also generates a `RetryPromptPart` containing the exception ### MCP Tools {#mcp-tools} -See the [MCP Client](./mcp/client.md) documentation for how to use MCP servers with Pydantic AI. +See the [MCP Client](./mcp/client.md) documentation for how to use MCP servers with Pydantic AI as [toolsets](toolsets.md). ### LangChain Tools {#langchain-tools} -If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with Pydantic AI, you can use the `pydancic_ai.ext.langchain.tool_from_langchain` convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. +If you'd like to use a tool from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with Pydantic AI, you can use the [`tool_from_langchain`][pydantic_ai.ext.langchain.tool_from_langchain] convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. You will need to install the `langchain-community` package and any others required by the tool in question. @@ -740,6 +744,7 @@ from langchain_community.tools import DuckDuckGoSearchRun from pydantic_ai import Agent from pydantic_ai.ext.langchain import tool_from_langchain + search = DuckDuckGoSearchRun() search_tool = tool_from_langchain(search) @@ -755,9 +760,25 @@ print(result.output) 1. The release date of this game is the 30th of May 2025, which is after the knowledge cutoff for Gemini 2.0 (August 2024). +If you'd like to use multiple LangChain tools or a LangChain [toolkit](https://python.langchain.com/docs/concepts/tools/#toolkits), you can use the [`LangChainToolset`][pydantic_ai.ext.langchain.LangChainToolset] [toolset](toolsets.md) which takes a list of LangChain tools: + +```python {test="skip"} +from langchain_community.agent_toolkits import SlackToolkit + +from pydantic_ai import Agent +from pydantic_ai.ext.langchain import LangChainToolset + + +toolkit = SlackToolkit() +toolset = LangChainToolset(toolkit.get_tools()) + +agent = Agent('openai:gpt-4o', toolsets=[toolset]) +# ... +``` + ### ACI.dev Tools {#aci-tools} -If you'd like to use a tool from the [ACI.dev tool library](https://www.aci.dev/tools) with Pydantic AI, you can use the `pydancic_ai.ext.aci.tool_from_aci` convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the ACI tool, and up to the ACI tool to raise an error if the arguments are invalid. 
+If you'd like to use a tool from the [ACI.dev tool library](https://www.aci.dev/tools) with Pydantic AI, you can use the [`tool_from_aci`][pydantic_ai.ext.aci.tool_from_aci] convenience method. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the ACI tool, and up to the ACI tool to raise an error if the arguments are invalid. You will need to install the `aci-sdk` package, set your ACI API key in the `ACI_API_KEY` environment variable, and pass your ACI "linked account owner ID" to the function. @@ -769,14 +790,15 @@ import os from pydantic_ai import Agent from pydantic_ai.ext.aci import tool_from_aci + tavily_search = tool_from_aci( 'TAVILY__SEARCH', - linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID') + linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), ) agent = Agent( 'google-gla:gemini-2.0-flash', - tools=[tavily_search] + tools=[tavily_search], ) result = agent.run_sync('What is the release date of Elden Ring Nightreign?') # (1)! @@ -785,3 +807,23 @@ print(result.output) ``` 1. The release date of this game is the 30th of May 2025, which is after the knowledge cutoff for Gemini 2.0 (August 2024). + +If you'd like to use multiple ACI.dev tools, you can use the [`ACIToolset`][pydantic_ai.ext.aci.ACIToolset] [toolset](toolsets.md) which takes a list of ACI tool names as well as the `linked_account_owner_id`: + +```python {test="skip"} +import os + +from pydantic_ai import Agent +from pydantic_ai.ext.aci import ACIToolset + + +toolset = ACIToolset( + [ + 'OPEN_WEATHER_MAP__CURRENT_WEATHER', + 'OPEN_WEATHER_MAP__FORECAST', + ], + linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), +) + +agent = Agent('openai:gpt-4o', toolsets=[toolset]) +``` diff --git a/docs/toolsets.md b/docs/toolsets.md new file mode 100644 index 0000000000..fa7073798b --- /dev/null +++ b/docs/toolsets.md @@ -0,0 +1,633 @@ + +# Toolsets + +A toolset represents a collection of [tools](tools.md) that can be registered with an agent in one go. They can be reused by different agents, swapped out at runtime or during testing, and composed in order to dynamically filter which tools are available, modify tool definitions, or change tool execution behavior. A toolset can contain locally defined functions, depend on an external service to provide them, or implement custom logic to list available tools and handle them being called. + +Toolsets are used (among many other things) to define [MCP servers](mcp/client.md) available to an agent. Pydantic AI includes many kinds of toolsets which are described below, and you can define a [custom toolset](#building-a-custom-toolset) by inheriting from the [`AbstractToolset`][pydantic_ai.toolsets.AbstractToolset] class. + +The toolsets that will be available during an agent run can be specified in three different ways: + +* at agent construction time, via the [`toolsets`][pydantic_ai.Agent.__init__] keyword argument to `Agent` +* at agent run time, via the `toolsets` keyword argument to [`agent.run()`][pydantic_ai.Agent.run], [`agent.run_sync()`][pydantic_ai.Agent.run_sync], [`agent.run_stream()`][pydantic_ai.Agent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. These toolsets will be additional to those provided to the `Agent` constructor +* as a contextual override, via the `toolsets` keyword argument to the [`agent.override()`][pydantic_ai.Agent.iter] context manager. 
These toolsets will replace those provided at agent construction or run time during the life of the context manager + +```python {title="toolsets.py"} +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import FunctionToolset + + +def agent_tool(): + return "I'm registered directly on the agent" + + +def extra_tool(): + return "I'm passed as an extra tool for a specific run" + + +def override_tool(): + return "I override all other tools" + + +agent_toolset = FunctionToolset(tools=[agent_tool]) # (1)! +extra_toolset = FunctionToolset(tools=[extra_tool]) +override_toolset = FunctionToolset(tools=[override_tool]) + +test_model = TestModel() # (2)! +agent = Agent(test_model, toolsets=[agent_toolset]) + +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['agent_tool'] + +result = agent.run_sync('What tools are available?', toolsets=[extra_toolset]) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['agent_tool', 'extra_tool'] + +with agent.override(toolsets=[override_toolset]): + result = agent.run_sync('What tools are available?', toolsets=[extra_toolset]) # (3)! + print([t.name for t in test_model.last_model_request_parameters.function_tools]) + #> ['override_tool'] +``` + +1. The [`FunctionToolset`][pydantic_ai.toolsets.FunctionToolset] will be explained in detail in the next section. +2. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. +3. This `extra_toolset` will be ignored because we're inside an override context. + +_(This example is complete, it can be run "as is")_ + +## Function Toolset + +As the name suggests, a [`FunctionToolset`][pydantic_ai.toolsets.FunctionToolset] makes locally defined functions available as tools. + +Functions can be added as tools in three different ways: + +* via the [`@toolset.tool`][pydantic_ai.toolsets.FunctionToolset.tool] decorator +* via the [`tools`][pydantic_ai.toolsets.FunctionToolset.__init__] keyword argument to the constructor which can take either plain functions, or instances of [`Tool`][pydantic_ai.tools.Tool] +* via the [`toolset.add_function()`][pydantic_ai.toolsets.FunctionToolset.add_function] and [`toolset.add_tool()`][pydantic_ai.toolsets.FunctionToolset.add_tool] methods which can take a plain function or an instance of [`Tool`][pydantic_ai.tools.Tool] respectively + +Functions registered in any of these ways can define an initial `ctx: RunContext` argument in order to receive the agent [context][pydantic_ai.tools.RunContext]. The `add_function()` and `add_tool()` methods can also be used from a tool function to dynamically register new tools during a run to be available in future run steps. 
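+
+As a quick sketch of that last point, a tool can register another tool mid-run; the "plugin" names here are purely illustrative:
+
+```python {test="skip"}
+from pydantic_ai import Agent, RunContext
+from pydantic_ai.toolsets import FunctionToolset
+
+dynamic_toolset = FunctionToolset()
+
+
+@dynamic_toolset.tool
+def load_plugin(ctx: RunContext, plugin_name: str) -> str:
+    def run_plugin(data: str) -> str:
+        return f'{plugin_name} processed {data!r}'
+
+    # Registered mid-run; the new tool becomes available on the next run step.
+    dynamic_toolset.add_function(run_plugin, name=f'run_{plugin_name}')
+    return f'Plugin {plugin_name!r} loaded'
+
+
+agent = Agent('openai:gpt-4o', toolsets=[dynamic_toolset])
+```
+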
+ +```python {title="function_toolset.py"} +from datetime import datetime + +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import FunctionToolset + + +def temperature_celsius(city: str) -> float: + return 21.0 + + +def temperature_fahrenheit(city: str) -> float: + return 69.8 + + +weather_toolset = FunctionToolset(tools=[temperature_celsius, temperature_fahrenheit]) + + +@weather_toolset.tool +def conditions(ctx: RunContext, city: str) -> str: + if ctx.run_step % 2 == 0: + return "It's sunny" + else: + return "It's raining" + + +datetime_toolset = FunctionToolset() +datetime_toolset.add_function(lambda: datetime.now(), name='now') + +test_model = TestModel() # (1)! +agent = Agent(test_model) + +result = agent.run_sync('What tools are available?', toolsets=[weather_toolset]) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['temperature_celsius', 'temperature_fahrenheit', 'conditions'] + +result = agent.run_sync('What tools are available?', toolsets=[datetime_toolset]) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['now'] +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +## Toolset Composition + +Toolsets can be composed to dynamically filter which tools are available, modify tool definitions, or change tool execution behavior. Multiple toolsets can also be combined into one. + +### Combining Toolsets + +[`CombinedToolset`][pydantic_ai.toolsets.CombinedToolset] takes a list of toolsets and lets them be used as one. + +```python {title="combined_toolset.py" requires="function_toolset.py"} +from function_toolset import weather_toolset, datetime_toolset + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import CombinedToolset + + +combined_toolset = CombinedToolset([weather_toolset, datetime_toolset]) + +test_model = TestModel() # (1)! +agent = Agent(test_model, toolsets=[combined_toolset]) +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['temperature_celsius', 'temperature_fahrenheit', 'conditions', 'now'] +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +### Filtering Tools + +[`FilteredToolset`][pydantic_ai.toolsets.FilteredToolset] wraps a toolset and filters available tools ahead of each step of the run based on a user-defined function that is passed the agent [run context][pydantic_ai.tools.RunContext] and each tool's [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] and returns a boolean to indicate whether or not a given tool should be available. + +To easily chain different modifications, you can also call [`filtered()`][pydantic_ai.toolsets.AbstractToolset.filtered] on any toolset instead of directly constructing a `FilteredToolset`. 
+ +```python {title="filtered_toolset.py" requires="function_toolset.py,combined_toolset.py"} +from combined_toolset import combined_toolset + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel + +filtered_toolset = combined_toolset.filtered(lambda ctx, tool_def: 'fahrenheit' not in tool_def.name) + +test_model = TestModel() # (1)! +agent = Agent(test_model, toolsets=[filtered_toolset]) +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['weather_temperature_celsius', 'weather_conditions', 'datetime_now'] +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +### Prefixing Tool Names + +[`PrefixedToolset`][pydantic_ai.toolsets.PrefixedToolset] wraps a toolset and adds a prefix to each tool name to prevent tool name conflicts between different toolsets. + +To easily chain different modifications, you can also call [`prefixed()`][pydantic_ai.toolsets.AbstractToolset.prefixed] on any toolset instead of directly constructing a `PrefixedToolset`. + +```python {title="combined_toolset.py" requires="function_toolset.py"} +from function_toolset import weather_toolset, datetime_toolset + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import CombinedToolset + + +combined_toolset = CombinedToolset( + [ + weather_toolset.prefixed('weather'), + datetime_toolset.prefixed('datetime') + ] +) + +test_model = TestModel() # (1)! +agent = Agent(test_model, toolsets=[combined_toolset]) +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +""" +[ + 'weather_temperature_celsius', + 'weather_temperature_fahrenheit', + 'weather_conditions', + 'datetime_now', +] +""" +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +### Renaming Tools + +[`RenamedToolset`][pydantic_ai.toolsets.RenamedToolset] wraps a toolset and lets you rename tools using a dictionary mapping new names to original names. This is useful when the names provided by a toolset are ambiguous or would conflict with tools defined by other toolsets, but [prefixing them](#prefixing-tool-names) creates a name that is unnecessarily long or could be confusing to the model. + +To easily chain different modifications, you can also call [`renamed()`][pydantic_ai.toolsets.AbstractToolset.renamed] on any toolset instead of directly constructing a `RenamedToolset`. + +```python {title="renamed_toolset.py" requires="function_toolset.py,combined_toolset.py"} +from combined_toolset import combined_toolset + +from pydantic_ai import Agent +from pydantic_ai.models.test import TestModel + + +renamed_toolset = combined_toolset.renamed( + { + 'current_time': 'datetime_now', + 'temperature_celsius': 'weather_temperature_celsius', + 'temperature_fahrenheit': 'weather_temperature_fahrenheit' + } +) + +test_model = TestModel() # (1)! 
+agent = Agent(test_model, toolsets=[renamed_toolset]) +result = agent.run_sync('What tools are available?') +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +""" +['temperature_celsius', 'temperature_fahrenheit', 'weather_conditions', 'current_time'] +""" +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +_(This example is complete, it can be run "as is")_ + +### Preparing Tool Definitions + +[`PreparedToolset`][pydantic_ai.toolsets.PreparedToolset] lets you modify the entire list of available tools ahead of each step of the agent run using a user-defined function that takes the agent [run context][pydantic_ai.tools.RunContext] and a list of [`ToolDefinition`s][pydantic_ai.tools.ToolDefinition] and returns a list of modified `ToolDefinition`s. + +This is the toolset-specific equivalent of the [`prepare_tools`](tools.md#prepare-tools) argument to `Agent` that prepares all tool definitions registered to an agent across toolsets. + +Note that it is not possible to add or rename tools using `PreparedToolset`. Instead, you can use [`FunctionToolset.add_function()`](#function-toolset) or [`RenamedToolset`](#renaming-tools). + +To easily chain different modifications, you can also call [`prepared()`][pydantic_ai.toolsets.AbstractToolset.prepared] on any toolset instead of directly constructing a `PreparedToolset`. + +```python {title="prepared_toolset.py" requires="function_toolset.py,combined_toolset.py,renamed_toolset.py"} +from dataclasses import replace +from typing import Union + +from renamed_toolset import renamed_toolset + +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import ToolDefinition + +descriptions = { + 'temperature_celsius': 'Get the temperature in degrees Celsius', + 'temperature_fahrenheit': 'Get the temperature in degrees Fahrenheit', + 'weather_conditions': 'Get the current weather conditions', + 'current_time': 'Get the current time', +} + +async def add_descriptions(ctx: RunContext, tool_defs: list[ToolDefinition]) -> Union[list[ToolDefinition], None]: + return [ + replace(tool_def, description=description) + if (description := descriptions.get(tool_def.name, None)) + else tool_def + for tool_def + in tool_defs + ] + +prepared_toolset = renamed_toolset.prepared(add_descriptions) + +test_model = TestModel() # (1)! 
+agent = Agent(test_model, toolsets=[prepared_toolset]) +result = agent.run_sync('What tools are available?') +print(test_model.last_model_request_parameters.function_tools) +""" +[ + ToolDefinition( + name='temperature_celsius', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + description='Get the temperature in degrees Celsius', + ), + ToolDefinition( + name='temperature_fahrenheit', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + description='Get the temperature in degrees Fahrenheit', + ), + ToolDefinition( + name='weather_conditions', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'city': {'type': 'string'}}, + 'required': ['city'], + 'type': 'object', + }, + description='Get the current weather conditions', + ), + ToolDefinition( + name='current_time', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {}, + 'type': 'object', + }, + description='Get the current time', + ), +] +""" +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. + +### Wrapping a Toolset + +[`WrapperToolset`][pydantic_ai.toolsets.WrapperToolset] wraps another toolset and delegates all responsibility to it. + +To easily chain different modifications, you can also call [`wrap()`][pydantic_ai.toolsets.AbstractToolset.wrap] on any toolset instead of directly constructing an instance of (a subclass of) `WrapperToolset`. + +`WrapperToolset` is a no-op by default, but enables some useful abilities: + +#### Changing Tool Execution + +You can subclass `WrapperToolset` to change the wrapped toolset's tool execution behavior by overriding the [`call_tool()`][pydantic_ai.toolsets.AbstractToolset.call_tool] method. + +```python {title="logging_toolset.py" requires="function_toolset.py,combined_toolset.py,renamed_toolset.py,prepared_toolset.py"} +from typing_extensions import Any + +from prepared_toolset import prepared_toolset + +from pydantic_ai.agent import Agent +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import RunContext +from pydantic_ai.toolsets import WrapperToolset, ToolsetTool + +LOG = [] + +class LoggingToolset(WrapperToolset): + async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: + LOG.append(f'Calling tool {name!r} with args: {tool_args!r}') + try: + result = await super().call_tool(name, tool_args, ctx, tool) + LOG.append(f'Finished calling tool {name!r} with result: {result!r}') + except Exception as e: + LOG.append(f'Error calling tool {name!r}: {e}') + raise e + else: + return result + + +logging_toolset = prepared_toolset.wrap(LoggingToolset) + +agent = Agent(TestModel(), toolsets=[logging_toolset]) # (1)! 
+result = agent.run_sync('Call all the tools') +print(LOG) +""" +[ + "Calling tool 'temperature_celsius' with args: {'city': 'a'}", + "Calling tool 'temperature_fahrenheit' with args: {'city': 'a'}", + "Calling tool 'weather_conditions' with args: {'city': 'a'}", + "Calling tool 'current_time' with args: {}", + "Finished calling tool 'temperature_celsius' with result: 21.0", + "Finished calling tool 'temperature_fahrenheit' with result: 69.8", + 'Finished calling tool \'weather_conditions\' with result: "It\'s raining"', + "Finished calling tool 'current_time' with result: datetime.datetime(...)", +] +""" +``` + +1. We use [`TestModel`][pydantic_ai.models.test.TestModel] here as it will automatically call each tool. + +_(This example is complete, it can be run "as is")_ + +#### Modifying Toolsets During a Run + +You can change the `WrapperToolset`'s `wrapped` property during an agent run to swap out one toolset for another starting at the next run step. + +To add or remove available toolsets, you can wrap a [`CombinedToolset`](#combining-toolsets) and replace it during the run with one that can include fewer, more, or entirely different toolsets. + +```python {title="wrapper_toolset.py" requires="function_toolset.py"} +from function_toolset import weather_toolset, datetime_toolset + +from pydantic_ai import Agent, RunContext +from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import WrapperToolset, FunctionToolset + +togglable_toolset = WrapperToolset(weather_toolset) + +def toggle(ctx: RunContext[WrapperToolset]): + if ctx.deps.wrapped == weather_toolset: + ctx.deps.wrapped = datetime_toolset + else: + ctx.deps.wrapped = weather_toolset + +test_model = TestModel() # (1)! +agent = Agent( + test_model, + deps_type=WrapperToolset, # (2)! + toolsets=[togglable_toolset, FunctionToolset([toggle])] +) +result = agent.run_sync('Toggle the toolset', deps=togglable_toolset) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) # (3)! +#> ['now', 'toggle'] + +result = agent.run_sync('Toggle the toolset', deps=togglable_toolset) +print([t.name for t in test_model.last_model_request_parameters.function_tools]) +#> ['temperature_celsius', 'temperature_fahrenheit', 'conditions', 'toggle'] +``` + +1. We're using [`TestModel`][pydantic_ai.models.test.TestModel] here because it makes it easy to see which tools were available on each run. +2. We're using the agent's dependencies to give the `toggle` tool access to the `togglable_toolset` via the `RunContext` argument. +3. This shows the available tools _after_ the `toggle` tool was executed, as the "last model request" was the one that returned the `toggle` tool result to the model. + +## Building a Custom Toolset + +To define a fully custom toolset with its own logic to list available tools and handle them being called, you can subclass [`AbstractToolset`][pydantic_ai.toolsets.AbstractToolset] and implement the [`get_tools()`][pydantic_ai.toolsets.AbstractToolset.get_tools] and [`call_tool()`][pydantic_ai.toolsets.AbstractToolset.call_tool] methods. + +If you want to reuse a network connection or session across tool listings and calls during an agent run step, you can implement [`__aenter__()`][pydantic_ai.toolsets.AbstractToolset.__aenter__] and [`__aexit__()`][pydantic_ai.toolsets.AbstractToolset.__aexit__], which will be called when the agent that uses the toolset is itself entered using the [`async with agent`][pydantic_ai.Agent.__aenter__] context manager. 
+
+### Deferred Toolset
+
+A deferred tool is one that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via a protocol like [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools).
+
+!!! note
+    This is not typically something you need to bother with, unless you are implementing support for such a protocol between an upstream tool provider and Pydantic AI.
+
+When the model calls a deferred tool, the agent run ends with a [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] object containing the deferred tool call names and arguments, which is expected to be returned to the upstream tool provider. This upstream service is then expected to generate a response for each tool call and start a new Pydantic AI agent run with the message history and new [`ToolReturnPart`s][pydantic_ai.messages.ToolReturnPart] corresponding to each deferred call, after which the run will continue.
+
+To enable an agent to call deferred tools, you create a [`DeferredToolset`][pydantic_ai.toolsets.DeferredToolset], pass it a list of [`ToolDefinition`s][pydantic_ai.tools.ToolDefinition], and provide it to the agent using one of the methods described above. Additionally, you need to add `DeferredToolCalls` to the `Agent`'s [output types](output.md#structured-output) so that the agent run's output type is correctly inferred. Finally, you should handle the possible `DeferredToolCalls` result by returning it to the upstream tool provider.
+
+If your agent can also be used in a context where no deferred tools are available, you will not want to include `DeferredToolCalls` in the `output_type` passed to the `Agent` constructor as you'd have to deal with that type everywhere you use the agent. Instead, you can pass the `toolsets` and `output_type` keyword arguments when you run the agent using [`agent.run()`][pydantic_ai.Agent.run], [`agent.run_sync()`][pydantic_ai.Agent.run_sync], [`agent.run_stream()`][pydantic_ai.Agent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. Note that while `toolsets` provided at this stage are additional to the toolsets provided to the constructor, the `output_type` overrides the one specified at construction time (for type inference reasons), so you'll need to include the original output types explicitly.
+
+To demonstrate, let us first define a simple agent _without_ deferred tools:
+
+```python {title="deferred_toolset_agent.py"}
+from pydantic import BaseModel
+
+from pydantic_ai import Agent
+from pydantic_ai.toolsets.function import FunctionToolset
+
+toolset = FunctionToolset()
+
+
+@toolset.tool
+def get_default_language():
+    return 'en-US'
+
+
+@toolset.tool
+def get_user_name():
+    return 'David'
+
+
+class PersonalizedGreeting(BaseModel):
+    greeting: str
+    language_code: str
+
+
+agent = Agent('openai:gpt-4o', toolsets=[toolset], output_type=PersonalizedGreeting)
+
+result = agent.run_sync('Greet the user in a personalized way')
+print(repr(result.output))
+#> PersonalizedGreeting(greeting='Hello, David!', language_code='en-US')
+```
+
+Next, let's define a function for a hypothetical "run agent" API endpoint that can be called by the frontend and takes a list of messages to send to the model plus a list of frontend tool definitions.
This is where `DeferredToolset` and `DeferredToolCalls` come in: + +```python {title="deferred_toolset_api.py" requires="deferred_toolset_agent.py"} +from deferred_toolset_agent import agent, PersonalizedGreeting + +from typing import Union + +from pydantic_ai.output import DeferredToolCalls +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets import DeferredToolset +from pydantic_ai.messages import ModelMessage + +def run_agent( + messages: list[ModelMessage] = [], frontend_tools: list[ToolDefinition] = {} +) -> tuple[Union[PersonalizedGreeting, DeferredToolCalls], list[ModelMessage]]: + deferred_toolset = DeferredToolset(frontend_tools) + result = agent.run_sync( + toolsets=[deferred_toolset], # (1)! + output_type=[agent.output_type, DeferredToolCalls], # (2)! + message_history=messages, # (3)! + ) + return result.output, result.new_messages() +``` + +1. As mentioned above, these `toolsets` are additional to those provided to the `Agent` constructor +2. As mentioned above, this `output_type` overrides the one provided to the `Agent` constructor, so we have to make sure to not lose it +3. We don't include an `user_prompt` keyword argument as we expect the frontend to provide it via `messages` + +Now, imagine that the code below is implemented on the frontend, and `run_agent` stands in for an API call to the backend that runs the agent. This is where we actually execute the deferred tool calls and start a new run with the new result included: + +```python {title="deferred_tools.py" requires="deferred_toolset_agent.py,deferred_toolset_api.py"} +from deferred_toolset_api import run_agent + +from pydantic_ai.messages import ModelMessage, ModelRequest, RetryPromptPart, ToolReturnPart, UserPromptPart +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.output import DeferredToolCalls + +frontend_tool_definitions = [ + ToolDefinition( + name='get_preferred_language', + parameters_json_schema={'type': 'object', 'properties': {'default_language': {'type': 'string'}}}, + description="Get the user's preferred language from their browser", + ) +] +def get_preferred_language(default_language: str) -> str: + return 'es-MX' # (1)! +frontend_tool_functions = {'get_preferred_language': get_preferred_language} + +messages: list[ModelMessage] = [ + ModelRequest( + parts=[ + UserPromptPart(content='Greet the user in a personalized way') + ] + ) +] + +final_output = None +while True: + output, new_messages = run_agent(messages, frontend_tool_definitions) + messages += new_messages + + if not isinstance(output, DeferredToolCalls): + final_output = output + break + + print(output.tool_calls) + """ + [ + ToolCallPart( + tool_name='get_preferred_language', + args={'default_language': 'en-US'}, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + """ + for tool_call in output.tool_calls: + if function := frontend_tool_functions.get(tool_call.tool_name): + part = ToolReturnPart( + tool_name=tool_call.tool_name, + content=function(**tool_call.args_as_dict()), + tool_call_id=tool_call.tool_call_id, + ) + else: + part = RetryPromptPart( + tool_name=tool_call.tool_name, + content=f'Unknown tool {tool_call.tool_name!r}', + tool_call_id=tool_call.tool_call_id, + ) + messages.append(ModelRequest(parts=[part])) + +print(repr(final_output)) +""" +PersonalizedGreeting(greeting='Hola, David! Espero que tengas un gran día!', language_code='es-MX') +""" +``` + +1. 
Imagine that this returns [`navigator.language`](https://developer.mozilla.org/en-US/docs/Web/API/Navigator/language) + +_(This example is complete, it can be run "as is")_ + +## Third-Party Toolsets + +### MCP Servers + +See the [MCP Client](./mcp/client.md) documentation for how to use MCP servers with Pydantic AI. + +### LangChain Tools {#langchain-tools} + +If you'd like to use tools or a [toolkit](https://python.langchain.com/docs/concepts/tools/#toolkits) from LangChain's [community tool library](https://python.langchain.com/docs/integrations/tools/) with Pydantic AI, you can use the [`LangChainToolset`][pydantic_ai.ext.langchain.LangChainToolset] which takes a list of LangChain tools. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the LangChain tool, and up to the LangChain tool to raise an error if the arguments are invalid. + +You will need to install the `langchain-community` package and any others required by the tools in question. + +```python {test="skip"} +from langchain_community.agent_toolkits import SlackToolkit + +from pydantic_ai import Agent +from pydantic_ai.ext.langchain import LangChainToolset + + +toolkit = SlackToolkit() +toolset = LangChainToolset(toolkit.get_tools()) + +agent = Agent('openai:gpt-4o', toolsets=[toolset]) +# ... +``` + +### ACI.dev Tools {#aci-tools} + +If you'd like to use tools from the [ACI.dev tool library](https://www.aci.dev/tools) with Pydantic AI, you can use the [`ACIToolset`][pydantic_ai.ext.aci.ACIToolset] [toolset](toolsets.md) which takes a list of ACI tool names as well as the `linked_account_owner_id`. Note that Pydantic AI will not validate the arguments in this case -- it's up to the model to provide arguments matching the schema specified by the ACI tool, and up to the ACI tool to raise an error if the arguments are invalid. + +You will need to install the `aci-sdk` package, set your ACI API key in the `ACI_API_KEY` environment variable, and pass your ACI "linked account owner ID" to the function. 
+ +```python {test="skip"} +import os + +from pydantic_ai import Agent +from pydantic_ai.ext.aci import ACIToolset + + +toolset = ACIToolset( + [ + 'OPEN_WEATHER_MAP__CURRENT_WEATHER', + 'OPEN_WEATHER_MAP__FORECAST', + ], + linked_account_owner_id=os.getenv('LINKED_ACCOUNT_OWNER_ID'), +) + +agent = Agent('openai:gpt-4o', toolsets=[toolset]) +``` diff --git a/mcp-run-python/README.md b/mcp-run-python/README.md index 360ca23471..edd84ddb88 100644 --- a/mcp-run-python/README.md +++ b/mcp-run-python/README.md @@ -52,11 +52,11 @@ server = MCPServerStdio('deno', 'jsr:@pydantic/mcp-run-python', 'stdio', ]) -agent = Agent('claude-3-5-haiku-latest', mcp_servers=[server]) +agent = Agent('claude-3-5-haiku-latest', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('How many days between 2000-01-01 and 2025-03-18?') print(result.output) #> There are 9,208 days between January 1, 2000, and March 18, 2025.w diff --git a/mkdocs.yml b/mkdocs.yml index a950d52c0c..fc6cd27999 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -28,9 +28,10 @@ nav: - models/cohere.md - models/groq.md - models/mistral.md + - models/huggingface.md - dependencies.md - tools.md - - common-tools.md + - toolsets.md - output.md - message-history.md - testing.md @@ -41,6 +42,7 @@ nav: - input.md - thinking.md - direct.md + - common-tools.md - MCP: - mcp/index.md - mcp/client.md @@ -64,6 +66,7 @@ nav: - API Reference: - api/agent.md - api/tools.md + - api/toolsets.md - api/common_tools.md - api/output.md - api/result.md @@ -75,6 +78,7 @@ nav: - api/format_as_xml.md - api/format_prompt.md - api/direct.md + - api/ext.md - api/models/base.md - api/models/openai.md - api/models/anthropic.md diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index fda19acda4..f6b4a51c3f 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -3,6 +3,7 @@ import asyncio import dataclasses import hashlib +from collections import defaultdict, deque from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import asynccontextmanager, contextmanager from contextvars import ContextVar @@ -13,17 +14,18 @@ from typing_extensions import TypeGuard, TypeVar, assert_never from pydantic_ai._function_schema import _takes_ctx as is_takes_ctx # type: ignore +from pydantic_ai._tool_manager import ToolManager from pydantic_ai._utils import is_async_callable, run_in_executor from pydantic_graph import BaseNode, Graph, GraphRunContext from pydantic_graph.nodes import End, NodeRunEndT from . 
import _output, _system_prompt, exceptions, messages as _messages, models, result, usage as _usage +from .exceptions import ToolRetryError from .output import OutputDataT, OutputSpec from .settings import ModelSettings, merge_model_settings -from .tools import RunContext, Tool, ToolDefinition, ToolsPrepareFunc +from .tools import RunContext, ToolDefinition, ToolKind if TYPE_CHECKING: - from .mcp import MCPServer from .models.instrumented import InstrumentationSettings __all__ = ( @@ -77,11 +79,13 @@ class GraphAgentState: retries: int run_step: int - def increment_retries(self, max_result_retries: int, error: Exception | None = None) -> None: + def increment_retries(self, max_result_retries: int, error: BaseException | None = None) -> None: self.retries += 1 if self.retries > max_result_retries: - message = f'Exceeded maximum retries ({max_result_retries}) for result validation' + message = f'Exceeded maximum retries ({max_result_retries}) for output validation' if error: + if isinstance(error, exceptions.UnexpectedModelBehavior) and error.__cause__ is not None: + error = error.__cause__ raise exceptions.UnexpectedModelBehavior(message) from error else: raise exceptions.UnexpectedModelBehavior(message) @@ -108,15 +112,11 @@ class GraphAgentDeps(Generic[DepsT, OutputDataT]): history_processors: Sequence[HistoryProcessor[DepsT]] - function_tools: dict[str, Tool[DepsT]] = dataclasses.field(repr=False) - mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) - default_retries: int + tool_manager: ToolManager[DepsT] tracer: Tracer instrumentation_settings: InstrumentationSettings | None = None - prepare_tools: ToolsPrepareFunc[DepsT] | None = None - class AgentNode(BaseNode[GraphAgentState, GraphAgentDeps[DepsT, Any], result.FinalResult[NodeRunEndT]]): """The base class for all agent nodes. @@ -248,59 +248,27 @@ async def _prepare_request_parameters( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], ) -> models.ModelRequestParameters: """Build tools and create an agent model.""" - function_tool_defs_map: dict[str, ToolDefinition] = {} - run_context = build_run_context(ctx) - - async def add_tool(tool: Tool[DepsT]) -> None: - ctx = run_context.replace_with(retry=tool.current_retry, tool_name=tool.name) - if tool_def := await tool.prepare_tool_def(ctx): - # prepare_tool_def may change tool_def.name - if tool_def.name in function_tool_defs_map: - if tool_def.name != tool.name: - # Prepare tool def may have renamed the tool - raise exceptions.UserError( - f"Renaming tool '{tool.name}' to '{tool_def.name}' conflicts with existing tool." - ) - else: - raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}.') - function_tool_defs_map[tool_def.name] = tool_def - - async def add_mcp_server_tools(server: MCPServer) -> None: - if not server.is_running: - raise exceptions.UserError(f'MCP server is not running: {server}') - tool_defs = await server.list_tools() - for tool_def in tool_defs: - if tool_def.name in function_tool_defs_map: - raise exceptions.UserError( - f"MCP Server '{server}' defines a tool whose name conflicts with existing tool: {tool_def.name!r}. Consider using `tool_prefix` to avoid name conflicts." 
- ) - function_tool_defs_map[tool_def.name] = tool_def - - await asyncio.gather( - *map(add_tool, ctx.deps.function_tools.values()), - *map(add_mcp_server_tools, ctx.deps.mcp_servers), - ) - function_tool_defs = list(function_tool_defs_map.values()) - if ctx.deps.prepare_tools: - # Prepare the tools using the provided function - # This also acts over tool definitions pulled from MCP servers - function_tool_defs = await ctx.deps.prepare_tools(run_context, function_tool_defs) or [] + ctx.deps.tool_manager = await ctx.deps.tool_manager.for_run_step(run_context) output_schema = ctx.deps.output_schema - - output_tools = [] output_object = None - if isinstance(output_schema, _output.ToolOutputSchema): - output_tools = output_schema.tool_defs() - elif isinstance(output_schema, _output.NativeOutputSchema): + if isinstance(output_schema, _output.NativeOutputSchema): output_object = output_schema.object_def # ToolOrTextOutputSchema, NativeOutputSchema, and PromptedOutputSchema all inherit from TextOutputSchema allow_text_output = isinstance(output_schema, _output.TextOutputSchema) + function_tools: list[ToolDefinition] = [] + output_tools: list[ToolDefinition] = [] + for tool_def in ctx.deps.tool_manager.tool_defs: + if tool_def.kind == 'output': + output_tools.append(tool_def) + else: + function_tools.append(tool_def) + return models.ModelRequestParameters( - function_tools=function_tool_defs, + function_tools=function_tools, output_mode=output_schema.mode, output_tools=output_tools, output_object=output_object, @@ -341,8 +309,8 @@ async def stream( ctx.deps.output_schema, ctx.deps.output_validators, build_run_context(ctx), - _output.build_trace_context(ctx), ctx.deps.usage_limits, + ctx.deps.tool_manager, ) yield agent_stream # In case the user didn't manually consume the full stream, ensure it is fully consumed here, @@ -438,7 +406,6 @@ class CallToolsNode(AgentNode[DepsT, NodeRunEndT]): _next_node: ModelRequestNode[DepsT, NodeRunEndT] | End[result.FinalResult[NodeRunEndT]] | None = field( default=None, repr=False ) - _tool_responses: list[_messages.ModelRequestPart] = field(default_factory=list, repr=False) async def run( self, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]] @@ -520,47 +487,30 @@ async def _handle_tool_calls( ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], tool_calls: list[_messages.ToolCallPart], ) -> AsyncIterator[_messages.HandleResponseEvent]: - output_schema = ctx.deps.output_schema run_context = build_run_context(ctx) - final_result: result.FinalResult[NodeRunEndT] | None = None - parts: list[_messages.ModelRequestPart] = [] - - # first, look for the output tool call - if isinstance(output_schema, _output.ToolOutputSchema): - for call, output_tool in output_schema.find_tool(tool_calls): - try: - trace_context = _output.build_trace_context(ctx) - result_data = await output_tool.process(call, run_context, trace_context) - result_data = await _validate_output(result_data, ctx, call) - except _output.ToolRetryError as e: - # TODO: Should only increment retry stuff once per node execution, not for each tool call - # Also, should increment the tool-specific retry count rather than the run retry count - ctx.state.increment_retries(ctx.deps.max_result_retries, e) - parts.append(e.tool_retry) - else: - final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - break + output_parts: list[_messages.ModelRequestPart] = [] + output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1) - # 
Then build the other request parts based on end strategy - tool_responses: list[_messages.ModelRequestPart] = self._tool_responses async for event in process_function_tools( - tool_calls, - final_result and final_result.tool_name, - final_result and final_result.tool_call_id, - ctx, - tool_responses, + ctx.deps.tool_manager, tool_calls, None, ctx, output_parts, output_final_result ): yield event - if final_result: - self._next_node = self._handle_final_result(ctx, final_result, tool_responses) + if output_final_result: + final_result = output_final_result[0] + self._next_node = self._handle_final_result(ctx, final_result, output_parts) + elif deferred_tool_calls := ctx.deps.tool_manager.get_deferred_tool_calls(tool_calls): + if not ctx.deps.output_schema.allows_deferred_tool_calls: + raise exceptions.UserError( + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' + ) + final_result = result.FinalResult(cast(NodeRunEndT, deferred_tool_calls), None, None) + self._next_node = self._handle_final_result(ctx, final_result, output_parts) else: - if tool_responses: - parts.extend(tool_responses) instructions = await ctx.deps.get_instructions(run_context) self._next_node = ModelRequestNode[DepsT, NodeRunEndT]( - _messages.ModelRequest(parts=parts, instructions=instructions) + _messages.ModelRequest(parts=output_parts, instructions=instructions) ) def _handle_final_result( @@ -586,18 +536,18 @@ async def _handle_text_response( text = '\n\n'.join(texts) try: + run_context = build_run_context(ctx) if isinstance(output_schema, _output.TextOutputSchema): - run_context = build_run_context(ctx) - trace_context = _output.build_trace_context(ctx) - result_data = await output_schema.process(text, run_context, trace_context) + result_data = await output_schema.process(text, run_context) else: m = _messages.RetryPromptPart( content='Plain text responses are not permitted, please include your response in a tool call', ) - raise _output.ToolRetryError(m) + raise ToolRetryError(m) - result_data = await _validate_output(result_data, ctx, None) - except _output.ToolRetryError as e: + for validator in ctx.deps.output_validators: + result_data = await validator.validate(result_data, run_context) + except ToolRetryError as e: ctx.state.increment_retries(ctx.deps.max_result_retries, e) return ModelRequestNode[DepsT, NodeRunEndT](_messages.ModelRequest(parts=[e.tool_retry])) else: @@ -612,6 +562,9 @@ def build_run_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT usage=ctx.state.usage, prompt=ctx.deps.prompt, messages=ctx.state.message_history, + tracer=ctx.deps.tracer, + trace_include_content=ctx.deps.instrumentation_settings is not None + and ctx.deps.instrumentation_settings.include_content, run_step=ctx.state.run_step, ) @@ -623,269 +576,210 @@ def multi_modal_content_identifier(identifier: str | bytes) -> str: return hashlib.sha1(identifier).hexdigest()[:6] -async def process_function_tools( # noqa C901 +async def process_function_tools( # noqa: C901 + tool_manager: ToolManager[DepsT], tool_calls: list[_messages.ToolCallPart], - output_tool_name: str | None, - output_tool_call_id: str | None, + final_result: result.FinalResult[NodeRunEndT] | None, ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], output_parts: list[_messages.ModelRequestPart], + output_final_result: deque[result.FinalResult[NodeRunEndT]] = deque(maxlen=1), ) -> 
AsyncIterator[_messages.HandleResponseEvent]: """Process function (i.e., non-result) tool calls in parallel. Also add stub return parts for any other tools that need it. - Because async iterators can't have return values, we use `output_parts` as an output argument. + Because async iterators can't have return values, we use `output_parts` and `output_final_result` as output arguments. """ - stub_function_tools = bool(output_tool_name) and ctx.deps.end_strategy == 'early' - output_schema = ctx.deps.output_schema - - # we rely on the fact that if we found a result, it's the first output tool in the last - found_used_output_tool = False - run_context = build_run_context(ctx) - - calls_to_run: list[tuple[Tool[DepsT], _messages.ToolCallPart]] = [] + tool_calls_by_kind: dict[ToolKind | Literal['unknown'], list[_messages.ToolCallPart]] = defaultdict(list) for call in tool_calls: - if ( - call.tool_name == output_tool_name - and call.tool_call_id == output_tool_call_id - and not found_used_output_tool - ): - found_used_output_tool = True - output_parts.append( - _messages.ToolReturnPart( + tool_def = tool_manager.get_tool_def(call.tool_name) + kind = tool_def.kind if tool_def else 'unknown' + tool_calls_by_kind[kind].append(call) + + # First, we handle output tool calls + for call in tool_calls_by_kind['output']: + if final_result: + if final_result.tool_call_id == call.tool_call_id: + part = _messages.ToolReturnPart( tool_name=call.tool_name, content='Final result processed.', tool_call_id=call.tool_call_id, ) - ) - elif tool := ctx.deps.function_tools.get(call.tool_name): - if stub_function_tools: - output_parts.append( - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) else: - event = _messages.FunctionToolCallEvent(call) - yield event - calls_to_run.append((tool, call)) - elif mcp_tool := await _tool_from_mcp_server(call.tool_name, ctx): - if stub_function_tools: - # TODO(Marcelo): We should add coverage for this part of the code. - output_parts.append( # pragma: no cover - _messages.ToolReturnPart( - tool_name=call.tool_name, - content='Tool not executed - a final result was already processed.', - tool_call_id=call.tool_call_id, - ) - ) - else: - event = _messages.FunctionToolCallEvent(call) - yield event - calls_to_run.append((mcp_tool, call)) - elif call.tool_name in output_schema.tools: - # if tool_name is in output_schema, it means we found a output tool but an error occurred in - # validation, we don't add another part here - if output_tool_name is not None: yield _messages.FunctionToolCallEvent(call) - if found_used_output_tool: - content = 'Output tool not used - a final result was already processed.' - else: - # TODO: Include information about the validation failure, and/or merge this with the ModelRetry part - content = 'Output tool not used - result failed validation.' 
part = _messages.ToolReturnPart( tool_name=call.tool_name, - content=content, + content='Output tool not used - a final result was already processed.', tool_call_id=call.tool_call_id, ) yield _messages.FunctionToolResultEvent(part) - output_parts.append(part) - else: - yield _messages.FunctionToolCallEvent(call) - part = _unknown_tool(call.tool_name, call.tool_call_id, ctx) - yield _messages.FunctionToolResultEvent(part) output_parts.append(part) + else: + try: + result_data = await tool_manager.handle_call(call) + except exceptions.UnexpectedModelBehavior as e: + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + raise e # pragma: no cover + except ToolRetryError as e: + ctx.state.increment_retries(ctx.deps.max_result_retries, e) + yield _messages.FunctionToolCallEvent(call) + output_parts.append(e.tool_retry) + yield _messages.FunctionToolResultEvent(e.tool_retry) + else: + part = _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Final result processed.', + tool_call_id=call.tool_call_id, + ) + output_parts.append(part) + final_result = result.FinalResult(result_data, call.tool_name, call.tool_call_id) - if not calls_to_run: - return - - user_parts: list[_messages.UserPromptPart] = [] + # Then, we handle function tool calls + calls_to_run: list[_messages.ToolCallPart] = [] + if final_result and ctx.deps.end_strategy == 'early': + output_parts.extend( + [ + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + for call in tool_calls_by_kind['function'] + ] + ) + else: + calls_to_run.extend(tool_calls_by_kind['function']) - include_content = ( - ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content - ) + # Then, we handle unknown tool calls + if tool_calls_by_kind['unknown']: + ctx.state.increment_retries(ctx.deps.max_result_retries) + calls_to_run.extend(tool_calls_by_kind['unknown']) - # Run all tool tasks in parallel - results_by_index: dict[int, _messages.ModelRequestPart] = {} - with ctx.deps.tracer.start_as_current_span( - 'running tools', - attributes={ - 'tools': [call.tool_name for _, call in calls_to_run], - 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}', - }, - ): - tasks = [ - asyncio.create_task(tool.run(call, run_context, ctx.deps.tracer, include_content), name=call.tool_name) - for tool, call in calls_to_run - ] - - pending = tasks - while pending: - done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) - for task in done: - index = tasks.index(task) - result = task.result() - yield _messages.FunctionToolResultEvent(result) - - if isinstance(result, _messages.RetryPromptPart): - results_by_index[index] = result - elif isinstance(result, _messages.ToolReturnPart): - if isinstance(result.content, _messages.ToolReturn): - tool_return = result.content - if ( - isinstance(tool_return.return_value, _messages.MultiModalContentTypes) - or isinstance(tool_return.return_value, list) - and any( - isinstance(content, _messages.MultiModalContentTypes) - for content in tool_return.return_value # type: ignore - ) - ): - raise exceptions.UserError( - f"{result.tool_name}'s `return_value` contains invalid nested MultiModalContentTypes objects. " - f'Please use `content` instead.' 
- ) - result.content = tool_return.return_value # type: ignore - result.metadata = tool_return.metadata - if tool_return.content: - user_parts.append( - _messages.UserPromptPart( - content=list(tool_return.content), - timestamp=result.timestamp, - part_kind='user-prompt', - ) - ) - contents: list[Any] - single_content: bool - if isinstance(result.content, list): - contents = result.content # type: ignore - single_content = False - else: - contents = [result.content] - single_content = True - - processed_contents: list[Any] = [] - for content in contents: - if isinstance(content, _messages.ToolReturn): - raise exceptions.UserError( - f"{result.tool_name}'s return contains invalid nested ToolReturn objects. " - f'ToolReturn should be used directly.' - ) - elif isinstance(content, _messages.MultiModalContentTypes): - # Handle direct multimodal content - if isinstance(content, _messages.BinaryContent): - identifier = multi_modal_content_identifier(content.data) - else: - identifier = multi_modal_content_identifier(content.url) - - user_parts.append( - _messages.UserPromptPart( - content=[f'This is file {identifier}:', content], - timestamp=result.timestamp, - part_kind='user-prompt', - ) - ) - processed_contents.append(f'See file {identifier}') - else: - # Handle regular content - processed_contents.append(content) - - if single_content: - result.content = processed_contents[0] - else: - result.content = processed_contents + for call in calls_to_run: + yield _messages.FunctionToolCallEvent(call) - results_by_index[index] = result - else: - assert_never(result) + user_parts: list[_messages.UserPromptPart] = [] - # We append the results at the end, rather than as they are received, to retain a consistent ordering - # This is mostly just to simplify testing - for k in sorted(results_by_index): - output_parts.append(results_by_index[k]) + if calls_to_run: + # Run all tool tasks in parallel + parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {} + with ctx.deps.tracer.start_as_current_span( + 'running tools', + attributes={ + 'tools': [call.tool_name for call in calls_to_run], + 'logfire.msg': f'running {len(calls_to_run)} tool{"" if len(calls_to_run) == 1 else "s"}', + }, + ): + tasks = [ + asyncio.create_task(_call_function_tool(tool_manager, call), name=call.tool_name) + for call in calls_to_run + ] + + pending = tasks + while pending: + done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) + for task in done: + index = tasks.index(task) + tool_result_part, extra_parts = task.result() + yield _messages.FunctionToolResultEvent(tool_result_part) + + parts_by_index[index] = [tool_result_part, *extra_parts] + + # We append the results at the end, rather than as they are received, to retain a consistent ordering + # This is mostly just to simplify testing + for k in sorted(parts_by_index): + output_parts.extend(parts_by_index[k]) + + # Finally, we handle deferred tool calls + for call in tool_calls_by_kind['deferred']: + if final_result: + output_parts.append( + _messages.ToolReturnPart( + tool_name=call.tool_name, + content='Tool not executed - a final result was already processed.', + tool_call_id=call.tool_call_id, + ) + ) + else: + yield _messages.FunctionToolCallEvent(call) output_parts.extend(user_parts) + if final_result: + output_final_result.append(final_result) -async def _tool_from_mcp_server( - tool_name: str, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> Tool[DepsT] | None: - """Call each MCP server to find the 
tool with the given name. - - Args: - tool_name: The name of the tool to find. - ctx: The current run context. - Returns: - The tool with the given name, or `None` if no tool with the given name is found. - """ - - async def run_tool(ctx: RunContext[DepsT], **args: Any) -> Any: - # There's no normal situation where the server will not be running at this point, we check just in case - # some weird edge case occurs. - if not server.is_running: # pragma: no cover - raise exceptions.UserError(f'MCP server is not running: {server}') - - if server.process_tool_call is not None: - result = await server.process_tool_call(ctx, server.call_tool, tool_name, args) - else: - result = await server.call_tool(tool_name, args) - - return result - - for server in ctx.deps.mcp_servers: - tools = await server.list_tools() - if tool_name in {tool.name for tool in tools}: # pragma: no branch - return Tool(name=tool_name, function=run_tool, takes_ctx=True, max_retries=ctx.deps.default_retries) - return None +async def _call_function_tool( + tool_manager: ToolManager[DepsT], + tool_call: _messages.ToolCallPart, +) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]: + try: + tool_result = await tool_manager.handle_call(tool_call) + except ToolRetryError as e: + return (e.tool_retry, []) + + part = _messages.ToolReturnPart( + tool_name=tool_call.tool_name, + content=tool_result, + tool_call_id=tool_call.tool_call_id, + ) + extra_parts: list[_messages.ModelRequestPart] = [] + if isinstance(tool_result, _messages.ToolReturn): + if ( + isinstance(tool_result.return_value, _messages.MultiModalContentTypes) + or isinstance(tool_result.return_value, list) + and any( + isinstance(content, _messages.MultiModalContentTypes) + for content in tool_result.return_value # type: ignore + ) + ): + raise exceptions.UserError( + f'The `return_value` of tool {tool_call.tool_name!r} contains invalid nested `MultiModalContentTypes` objects. ' + f'Please use `content` instead.' + ) -def _unknown_tool( - tool_name: str, - tool_call_id: str, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, NodeRunEndT]], -) -> _messages.RetryPromptPart: - ctx.state.increment_retries(ctx.deps.max_result_retries) - tool_names = list(ctx.deps.function_tools.keys()) + part.content = tool_result.return_value # type: ignore + part.metadata = tool_result.metadata + if tool_result.content: + extra_parts.append( + _messages.UserPromptPart( + content=list(tool_result.content), + part_kind='user-prompt', + ) + ) + else: - output_schema = ctx.deps.output_schema - if isinstance(output_schema, _output.ToolOutputSchema): - tool_names.extend(output_schema.tool_names()) + def process_content(content: Any) -> Any: + if isinstance(content, _messages.ToolReturn): + raise exceptions.UserError( + f'The return value of tool {tool_call.tool_name!r} contains invalid nested `ToolReturn` objects. ' + f'`ToolReturn` should be used directly.' + ) + elif isinstance(content, _messages.MultiModalContentTypes): + if isinstance(content, _messages.BinaryContent): + identifier = multi_modal_content_identifier(content.data) + else: + identifier = multi_modal_content_identifier(content.url) - if tool_names: - msg = f'Available tools: {", ".join(tool_names)}' - else: - msg = 'No tools available.' 
+ extra_parts.append( + _messages.UserPromptPart( + content=[f'This is file {identifier}:', content], + part_kind='user-prompt', + ) + ) + return f'See file {identifier}' - return _messages.RetryPromptPart( - tool_name=tool_name, - tool_call_id=tool_call_id, - content=f'Unknown tool name: {tool_name!r}. {msg}', - ) + return content + if isinstance(tool_result, list): + contents = cast(list[Any], tool_result) + part.content = [process_content(content) for content in contents] + else: + part.content = process_content(tool_result) -async def _validate_output( - result_data: T, - ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, T]], - tool_call: _messages.ToolCallPart | None, -) -> T: - for validator in ctx.deps.output_validators: - run_context = build_run_context(ctx) - result_data = await validator.validate(result_data, tool_call, run_context) - return result_data + return (part, extra_parts) @dataclasses.dataclass diff --git a/pydantic_ai_slim/pydantic_ai/_output.py b/pydantic_ai_slim/pydantic_ai/_output.py index c3199dd95c..c8925b0678 100644 --- a/pydantic_ai_slim/pydantic_ai/_output.py +++ b/pydantic_ai_slim/pydantic_ai/_output.py @@ -1,24 +1,21 @@ from __future__ import annotations as _annotations -import dataclasses import inspect import json from abc import ABC, abstractmethod -from collections.abc import Awaitable, Iterable, Iterator, Sequence +from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Union, cast, overload -from opentelemetry.trace import Tracer from pydantic import TypeAdapter, ValidationError -from pydantic_core import SchemaValidator -from typing_extensions import TypedDict, TypeVar, assert_never - -from pydantic_graph.nodes import GraphRunContext +from pydantic_core import SchemaValidator, to_json +from typing_extensions import Self, TypedDict, TypeVar, assert_never from . 
import _function_schema, _utils, messages as _messages from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, UserError +from .exceptions import ModelRetry, ToolRetryError, UserError from .output import ( + DeferredToolCalls, NativeOutput, OutputDataT, OutputMode, @@ -29,12 +26,12 @@ TextOutput, TextOutputFunc, ToolOutput, + _OutputSpecItem, # type: ignore[reportPrivateUsage] ) from .tools import GenerateToolJsonSchema, ObjectJsonSchema, ToolDefinition +from .toolsets.abstract import AbstractToolset, ToolsetTool if TYPE_CHECKING: - from pydantic_ai._agent_graph import DepsT, GraphAgentDeps, GraphAgentState - from .profiles import ModelProfile T = TypeVar('T') @@ -72,77 +69,45 @@ DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation' -@dataclass(frozen=True) -class TraceContext: - """A context for tracing output processing.""" +async def execute_output_function_with_span( + function_schema: _function_schema.FunctionSchema, + run_context: RunContext[AgentDepsT], + args: dict[str, Any] | Any, +) -> Any: + """Execute a function call within a traced span, automatically recording the response.""" + # Set up span attributes + tool_name = run_context.tool_name or getattr(function_schema.function, '__name__', 'output_function') + attributes = { + 'gen_ai.tool.name': tool_name, + 'logfire.msg': f'running output function: {tool_name}', + } + if run_context.tool_call_id: + attributes['gen_ai.tool.call.id'] = run_context.tool_call_id + if run_context.trace_include_content: + attributes['tool_arguments'] = to_json(args).decode() + attributes['logfire.json_schema'] = json.dumps( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + }, + } + ) - tracer: Tracer - include_content: bool - call: _messages.ToolCallPart | None = None + with run_context.tracer.start_as_current_span('running output function', attributes=attributes) as span: + output = await function_schema.call(args, run_context) - def with_call(self, call: _messages.ToolCallPart): - return dataclasses.replace(self, call=call) + # Record response if content inclusion is enabled + if run_context.trace_include_content and span.is_recording(): + from .models.instrumented import InstrumentedModel - async def execute_function_with_span( - self, - function_schema: _function_schema.FunctionSchema, - run_context: RunContext[AgentDepsT], - args: dict[str, Any] | Any, - call: _messages.ToolCallPart, - include_tool_call_id: bool = True, - ) -> Any: - """Execute a function call within a traced span, automatically recording the response.""" - # Set up span attributes - attributes = { - 'gen_ai.tool.name': call.tool_name, - 'logfire.msg': f'running output function: {call.tool_name}', - } - if include_tool_call_id: - attributes['gen_ai.tool.call.id'] = call.tool_call_id - if self.include_content: - attributes['tool_arguments'] = call.args_as_json_str() - attributes['logfire.json_schema'] = json.dumps( - { - 'type': 'object', - 'properties': { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, - }, - } + span.set_attribute( + 'tool_response', + output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)), ) - # Execute function within span - with self.tracer.start_as_current_span('running output function', attributes=attributes) as span: - output = await function_schema.call(args, run_context) - - # Record response if content inclusion is enabled - if self.include_content and 
span.is_recording(): - from .models.instrumented import InstrumentedModel - - span.set_attribute( - 'tool_response', - output if isinstance(output, str) else json.dumps(InstrumentedModel.serialize_any(output)), - ) - - return output - - -def build_trace_context(ctx: GraphRunContext[GraphAgentState, GraphAgentDeps[DepsT, Any]]) -> TraceContext: - """Build a `TraceContext` from the current agent graph run context.""" - return TraceContext( - tracer=ctx.deps.tracer, - include_content=( - ctx.deps.instrumentation_settings is not None and ctx.deps.instrumentation_settings.include_content - ), - ) - - -class ToolRetryError(Exception): - """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" - - def __init__(self, tool_retry: _messages.RetryPromptPart): - self.tool_retry = tool_retry - super().__init__() + return output @dataclass @@ -158,23 +123,21 @@ def __post_init__(self): async def validate( self, result: T, - tool_call: _messages.ToolCallPart | None, run_context: RunContext[AgentDepsT], + wrap_validation_errors: bool = True, ) -> T: """Validate a result but calling the function. Args: result: The result data after Pydantic validation the message content. - tool_call: The original tool call message, `None` if there was no tool call. run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. + wrap_validation_errors: If true, wrap the validation errors in a retry message. Returns: Result of either the validated result data (ok) or a retry message (Err). """ if self._takes_ctx: - ctx = run_context.replace_with(tool_name=tool_call.tool_name if tool_call else None) - args = ctx, result + args = run_context, result else: args = (result,) @@ -186,24 +149,32 @@ async def validate( function = cast(Callable[[Any], T], self.function) result_data = await _utils.run_in_executor(function, *args) except ModelRetry as r: - m = _messages.RetryPromptPart(content=r.message) - if tool_call is not None: - m.tool_name = tool_call.tool_name - m.tool_call_id = tool_call.tool_call_id - raise ToolRetryError(m) from r + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + tool_name=run_context.tool_name, + ) + if run_context.tool_call_id: # pragma: no cover + m.tool_call_id = run_context.tool_call_id + raise ToolRetryError(m) from r + else: + raise r else: return result_data +@dataclass class BaseOutputSchema(ABC, Generic[OutputDataT]): + allows_deferred_tool_calls: bool + @abstractmethod def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: raise NotImplementedError() @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - return {} + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + return None @dataclass(init=False) @@ -235,7 +206,7 @@ def build( ) -> BaseOutputSchema[OutputDataT]: ... 
@classmethod - def build( + def build( # noqa: C901 cls, output_spec: OutputSpec[OutputDataT], *, @@ -245,117 +216,93 @@ def build( strict: bool | None = None, ) -> BaseOutputSchema[OutputDataT]: """Build an OutputSchema dataclass from an output type.""" - if output_spec is str: - return PlainTextOutputSchema() + raw_outputs = _flatten_output_spec(output_spec) + + outputs = [output for output in raw_outputs if output is not DeferredToolCalls] + allows_deferred_tool_calls = len(outputs) < len(raw_outputs) + if len(outputs) == 0 and allows_deferred_tool_calls: + raise UserError('At least one output type must be provided other than `DeferredToolCalls`.') + + if output := next((output for output in outputs if isinstance(output, NativeOutput)), None): + if len(outputs) > 1: + raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover - if isinstance(output_spec, NativeOutput): return NativeOutputSchema( - cls._build_processor( - _flatten_output_spec(output_spec.outputs), - name=output_spec.name, - description=output_spec.description, - strict=output_spec.strict, - ) + processor=cls._build_processor( + _flatten_output_spec(output.outputs), + name=output.name, + description=output.description, + strict=output.strict, + ), + allows_deferred_tool_calls=allows_deferred_tool_calls, ) - elif isinstance(output_spec, PromptedOutput): + elif output := next((output for output in outputs if isinstance(output, PromptedOutput)), None): + if len(outputs) > 1: + raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover + return PromptedOutputSchema( - cls._build_processor( - _flatten_output_spec(output_spec.outputs), - name=output_spec.name, - description=output_spec.description, + processor=cls._build_processor( + _flatten_output_spec(output.outputs), + name=output.name, + description=output.description, ), - template=output_spec.template, + template=output.template, + allows_deferred_tool_calls=allows_deferred_tool_calls, ) text_outputs: Sequence[type[str] | TextOutput[OutputDataT]] = [] tool_outputs: Sequence[ToolOutput[OutputDataT]] = [] other_outputs: Sequence[OutputTypeOrFunction[OutputDataT]] = [] - for output in _flatten_output_spec(output_spec): + for output in outputs: if output is str: text_outputs.append(cast(type[str], output)) elif isinstance(output, TextOutput): text_outputs.append(output) elif isinstance(output, ToolOutput): tool_outputs.append(output) + elif isinstance(output, NativeOutput): + # We can never get here because this is checked for above. + raise UserError('`NativeOutput` must be the only output type.') # pragma: no cover + elif isinstance(output, PromptedOutput): + # We can never get here because this is checked for above. 
+ raise UserError('`PromptedOutput` must be the only output type.') # pragma: no cover else: other_outputs.append(output) - tools = cls._build_tools(tool_outputs + other_outputs, name=name, description=description, strict=strict) + toolset = OutputToolset.build(tool_outputs + other_outputs, name=name, description=description, strict=strict) if len(text_outputs) > 0: if len(text_outputs) > 1: - raise UserError('Only one text output is allowed.') + raise UserError('Only one `str` or `TextOutput` is allowed.') text_output = text_outputs[0] text_output_schema = None if isinstance(text_output, TextOutput): text_output_schema = PlainTextOutputProcessor(text_output.output_function) - if len(tools) == 0: - return PlainTextOutputSchema(text_output_schema) + if toolset: + return ToolOrTextOutputSchema( + processor=text_output_schema, toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls + ) else: - return ToolOrTextOutputSchema(processor=text_output_schema, tools=tools) + return PlainTextOutputSchema( + processor=text_output_schema, allows_deferred_tool_calls=allows_deferred_tool_calls + ) if len(tool_outputs) > 0: - return ToolOutputSchema(tools) + return ToolOutputSchema(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls) if len(other_outputs) > 0: schema = OutputSchemaWithoutMode( processor=cls._build_processor(other_outputs, name=name, description=description, strict=strict), - tools=tools, + toolset=toolset, + allows_deferred_tool_calls=allows_deferred_tool_calls, ) if default_mode: schema = schema.with_default_mode(default_mode) return schema - raise UserError('No output type provided.') # pragma: no cover - - @staticmethod - def _build_tools( - outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]], - name: str | None = None, - description: str | None = None, - strict: bool | None = None, - ) -> dict[str, OutputTool[OutputDataT]]: - tools: dict[str, OutputTool[OutputDataT]] = {} - - default_name = name or DEFAULT_OUTPUT_TOOL_NAME - default_description = description - default_strict = strict - - multiple = len(outputs) > 1 - for output in outputs: - name = None - description = None - strict = None - if isinstance(output, ToolOutput): - # do we need to error on conflicts here? 
(DavidM): If this is internal maybe doesn't matter, if public, use overloads - name = output.name - description = output.description - strict = output.strict - - output = output.output - - description = description or default_description - if strict is None: - strict = default_strict - - processor = ObjectOutputProcessor(output=output, description=description, strict=strict) - - if name is None: - name = default_name - if multiple: - name += f'_{processor.object_def.name}' - - i = 1 - original_name = name - while name in tools: - i += 1 - name = f'{original_name}_{i}' - - tools[name] = OutputTool(name=name, processor=processor, multiple=multiple) - - return tools + raise UserError('At least one output type must be provided.') @staticmethod def _build_processor( @@ -387,32 +334,39 @@ def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDa @dataclass(init=False) class OutputSchemaWithoutMode(BaseOutputSchema[OutputDataT]): processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT] - _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + _toolset: OutputToolset[Any] | None def __init__( self, processor: ObjectOutputProcessor[OutputDataT] | UnionOutputProcessor[OutputDataT], - tools: dict[str, OutputTool[OutputDataT]], + toolset: OutputToolset[Any] | None, + allows_deferred_tool_calls: bool, ): + super().__init__(allows_deferred_tool_calls) self.processor = processor - self._tools = tools + self._toolset = toolset def with_default_mode(self, mode: StructuredOutputMode) -> OutputSchema[OutputDataT]: if mode == 'native': - return NativeOutputSchema(self.processor) + return NativeOutputSchema( + processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls + ) elif mode == 'prompted': - return PromptedOutputSchema(self.processor) + return PromptedOutputSchema( + processor=self.processor, allows_deferred_tool_calls=self.allows_deferred_tool_calls + ) elif mode == 'tool': - return ToolOutputSchema(self.tools) + return ToolOutputSchema(toolset=self.toolset, allows_deferred_tool_calls=self.allows_deferred_tool_calls) else: assert_never(mode) @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - # We return tools here as they're checked in Agent._register_tool. - # At that point we may don't know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time. - return self._tools + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + # We return a toolset here as they're checked for name conflicts with other toolsets in the Agent constructor. + # At that point we may not know yet what output mode we're going to use if no model was provided or it was deferred until agent.run time, + # but we cover ourselves just in case we end up using the tool output mode. + return self._toolset class TextOutputSchema(OutputSchema[OutputDataT], ABC): @@ -421,7 +375,6 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -444,7 +397,6 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -453,7 +405,6 @@ async def process( Args: text: The output text to validate. run_context: The current run context. 
- trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -464,7 +415,7 @@ async def process( return cast(OutputDataT, text) return await self.processor.process( - text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -486,13 +437,12 @@ def mode(self) -> OutputMode: def raise_if_unsupported(self, profile: ModelProfile) -> None: """Raise an error if the mode is not supported by the model.""" if not profile.supports_json_schema_output: - raise UserError('Structured output is not supported by the model.') + raise UserError('Native structured output is not supported by the model.') async def process( self, text: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -501,7 +451,6 @@ async def process( Args: text: The output text to validate. run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -509,7 +458,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ return await self.processor.process( - text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -545,7 +494,6 @@ async def process( self, text: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -554,7 +502,6 @@ async def process( Args: text: The output text to validate. run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. 
@@ -564,16 +511,17 @@ async def process( text = _utils.strip_markdown_fences(text) return await self.processor.process( - text, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + text, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @dataclass(init=False) class ToolOutputSchema(OutputSchema[OutputDataT]): - _tools: dict[str, OutputTool[OutputDataT]] = field(default_factory=dict) + _toolset: OutputToolset[Any] | None - def __init__(self, tools: dict[str, OutputTool[OutputDataT]]): - self._tools = tools + def __init__(self, toolset: OutputToolset[Any] | None, allows_deferred_tool_calls: bool): + super().__init__(allows_deferred_tool_calls) + self._toolset = toolset @property def mode(self) -> OutputMode: @@ -585,36 +533,9 @@ def raise_if_unsupported(self, profile: ModelProfile) -> None: raise UserError('Output tools are not supported by the model.') @property - def tools(self) -> dict[str, OutputTool[OutputDataT]]: - """Get the tools for this output schema.""" - return self._tools - - def tool_names(self) -> list[str]: - """Return the names of the tools.""" - return list(self.tools.keys()) - - def tool_defs(self) -> list[ToolDefinition]: - """Get tool definitions to register with the model.""" - return [t.tool_def for t in self.tools.values()] - - def find_named_tool( - self, parts: Iterable[_messages.ModelResponsePart], tool_name: str - ) -> tuple[_messages.ToolCallPart, OutputTool[OutputDataT]] | None: - """Find a tool that matches one of the calls, with a specific name.""" - for part in parts: # pragma: no branch - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if part.tool_name == tool_name: - return part, self.tools[tool_name] - - def find_tool( - self, - parts: Iterable[_messages.ModelResponsePart], - ) -> Iterator[tuple[_messages.ToolCallPart, OutputTool[OutputDataT]]]: - """Find a tool that matches one of the calls.""" - for part in parts: - if isinstance(part, _messages.ToolCallPart): # pragma: no branch - if result := self.tools.get(part.tool_name): - yield part, result + def toolset(self) -> OutputToolset[Any] | None: + """Get the toolset for this output schema.""" + return self._toolset @dataclass(init=False) @@ -622,10 +543,11 @@ class ToolOrTextOutputSchema(ToolOutputSchema[OutputDataT], PlainTextOutputSchem def __init__( self, processor: PlainTextOutputProcessor[OutputDataT] | None, - tools: dict[str, OutputTool[OutputDataT]], + toolset: OutputToolset[Any] | None, + allows_deferred_tool_calls: bool, ): + super().__init__(toolset=toolset, allows_deferred_tool_calls=allows_deferred_tool_calls) self.processor = processor - self._tools = tools @property def mode(self) -> OutputMode: @@ -647,7 +569,6 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -659,7 +580,7 @@ async def process( class ObjectOutputProcessor(BaseOutputProcessor[OutputDataT]): object_def: OutputObjectDefinition outer_typed_dict_key: str | None = None - _validator: SchemaValidator + validator: SchemaValidator _function_schema: _function_schema.FunctionSchema | None = None def __init__( @@ -672,7 +593,7 @@ def __init__( ): if inspect.isfunction(output) or inspect.ismethod(output): self._function_schema = _function_schema.function_schema(output, GenerateToolJsonSchema) - self._validator = self._function_schema.validator + self.validator = 
self._function_schema.validator json_schema = self._function_schema.json_schema json_schema['description'] = self._function_schema.description else: @@ -688,7 +609,7 @@ def __init__( type_adapter = TypeAdapter(response_data_typed_dict) # Really a PluggableSchemaValidator, but it's API-compatible - self._validator = cast(SchemaValidator, type_adapter.validator) + self.validator = cast(SchemaValidator, type_adapter.validator) json_schema = _utils.check_object_json_schema( type_adapter.json_schema(schema_generator=GenerateToolJsonSchema) ) @@ -717,7 +638,6 @@ async def process( self, data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: @@ -726,7 +646,6 @@ async def process( Args: data: The output data to validate. run_context: The current run context. - trace_context: The trace context to use for tracing the output processing. allow_partial: If true, allow partial validation. wrap_validation_errors: If true, wrap the validation errors in a retry message. @@ -734,11 +653,7 @@ async def process( Either the validated output data (left) or a retry message (right). """ try: - pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' - if isinstance(data, str): - output = self._validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) - else: - output = self._validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + output = self.validate(data, allow_partial) except ValidationError as e: if wrap_validation_errors: m = _messages.RetryPromptPart( @@ -748,30 +663,40 @@ async def process( else: raise + try: + output = await self.call(output, run_context) + except ModelRetry as r: + if wrap_validation_errors: + m = _messages.RetryPromptPart( + content=r.message, + ) + raise ToolRetryError(m) from r + else: + raise # pragma: no cover + + return output + + def validate( + self, + data: str | dict[str, Any] | None, + allow_partial: bool = False, + ) -> dict[str, Any]: + pyd_allow_partial: Literal['off', 'trailing-strings'] = 'trailing-strings' if allow_partial else 'off' + if isinstance(data, str): + return self.validator.validate_json(data or '{}', allow_partial=pyd_allow_partial) + else: + return self.validator.validate_python(data or {}, allow_partial=pyd_allow_partial) + + async def call( + self, + output: Any, + run_context: RunContext[AgentDepsT], + ): if k := self.outer_typed_dict_key: output = output[k] if self._function_schema: - # Wraps the output function call in an OpenTelemetry span. 
- if trace_context.call: - call = trace_context.call - include_tool_call_id = True - else: - function_name = getattr(self._function_schema.function, '__name__', 'output_function') - call = _messages.ToolCallPart(tool_name=function_name, args=data) - include_tool_call_id = False - try: - output = await trace_context.execute_function_with_span( - self._function_schema, run_context, output, call, include_tool_call_id - ) - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - content=r.message, - ) - raise ToolRetryError(m) from r - else: - raise + output = await execute_output_function_with_span(self._function_schema, run_context, output) return output @@ -876,12 +801,11 @@ async def process( self, data: str | dict[str, Any] | None, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: union_object = await self._union_processor.process( - data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) result = union_object.result @@ -897,7 +821,7 @@ async def process( raise return await processor.process( - data, run_context, trace_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors + data, run_context, allow_partial=allow_partial, wrap_validation_errors=wrap_validation_errors ) @@ -928,20 +852,12 @@ async def process( self, data: str, run_context: RunContext[AgentDepsT], - trace_context: TraceContext, allow_partial: bool = False, wrap_validation_errors: bool = True, ) -> OutputDataT: args = {self._str_argument_name: data} - # Wraps the output function call in an OpenTelemetry span. 
-        # Note: PlainTextOutputProcessor is used for text responses (not tool calls),
-        # so we don't have tool call attributes like gen_ai.tool.name or gen_ai.tool.call.id
-        function_name = getattr(self._function_schema.function, '__name__', 'text_output_function')
-        call = _messages.ToolCallPart(tool_name=function_name, args=args)
         try:
-            output = await trace_context.execute_function_with_span(
-                self._function_schema, run_context, args, call, include_tool_call_id=False
-            )
+            output = await execute_output_function_with_span(self._function_schema, run_context, args)
         except ModelRetry as r:
             if wrap_validation_errors:
                 m = _messages.RetryPromptPart(
@@ -955,91 +871,139 @@
 
 
 @dataclass(init=False)
-class OutputTool(Generic[OutputDataT]):
-    processor: ObjectOutputProcessor[OutputDataT]
-    tool_def: ToolDefinition
+class OutputToolset(AbstractToolset[AgentDepsT]):
+    """A toolset that contains output tools for agent output types."""
 
-    def __init__(self, *, name: str, processor: ObjectOutputProcessor[OutputDataT], multiple: bool):
-        self.processor = processor
-        object_def = processor.object_def
+    _tool_defs: list[ToolDefinition]
+    """The tool definitions for the output tools in this toolset."""
+    processors: dict[str, ObjectOutputProcessor[Any]]
+    """The processors for the output tools in this toolset."""
+    max_retries: int
+    output_validators: list[OutputValidator[AgentDepsT, Any]]
 
-        description = object_def.description
-        if not description:
-            description = DEFAULT_OUTPUT_TOOL_DESCRIPTION
-        if multiple:
-            description = f'{object_def.name}: {description}'
+    @classmethod
+    def build(
+        cls,
+        outputs: list[OutputTypeOrFunction[OutputDataT] | ToolOutput[OutputDataT]],
+        name: str | None = None,
+        description: str | None = None,
+        strict: bool | None = None,
+    ) -> Self | None:
+        if len(outputs) == 0:
+            return None
 
-        self.tool_def = ToolDefinition(
-            name=name,
-            description=description,
-            parameters_json_schema=object_def.json_schema,
-            strict=object_def.strict,
-            outer_typed_dict_key=processor.outer_typed_dict_key,
-        )
+        processors: dict[str, ObjectOutputProcessor[Any]] = {}
+        tool_defs: list[ToolDefinition] = []
 
-    async def process(
-        self,
-        tool_call: _messages.ToolCallPart,
-        run_context: RunContext[AgentDepsT],
-        trace_context: TraceContext,
-        allow_partial: bool = False,
-        wrap_validation_errors: bool = True,
-    ) -> OutputDataT:
-        """Process an output message.
+        default_name = name or DEFAULT_OUTPUT_TOOL_NAME
+        default_description = description
+        default_strict = strict
 
-        Args:
-            tool_call: The tool call from the LLM to validate.
-            run_context: The current run context.
-            trace_context: The trace context to use for tracing the output processing.
-            allow_partial: If true, allow partial validation.
-            wrap_validation_errors: If true, wrap the validation errors in a retry message.
+        multiple = len(outputs) > 1
+        for output in outputs:
+            name = None
+            description = None
+            strict = None
+            if isinstance(output, ToolOutput):
+                # do we need to error on conflicts here? (DavidM): If this is internal maybe doesn't matter, if public, use overloads
+                name = output.name
+                description = output.description
+                strict = output.strict
 
-        Returns:
-            Either the validated output data (left) or a retry message (right). 
- """ - try: - output = await self.processor.process( - tool_call.args, - run_context, - trace_context.with_call(tool_call), - allow_partial=allow_partial, - wrap_validation_errors=False, + output = output.output + + description = description or default_description + if strict is None: + strict = default_strict + + processor = ObjectOutputProcessor(output=output, description=description, strict=strict) + object_def = processor.object_def + + if name is None: + name = default_name + if multiple: + name += f'_{object_def.name}' + + i = 1 + original_name = name + while name in processors: + i += 1 + name = f'{original_name}_{i}' + + description = object_def.description + if not description: + description = DEFAULT_OUTPUT_TOOL_DESCRIPTION + if multiple: + description = f'{object_def.name}: {description}' + + tool_def = ToolDefinition( + name=name, + description=description, + parameters_json_schema=object_def.json_schema, + strict=object_def.strict, + outer_typed_dict_key=processor.outer_typed_dict_key, + kind='output', ) - except ValidationError as e: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=e.errors(include_url=False, include_context=False), - tool_call_id=tool_call.tool_call_id, - ) - raise ToolRetryError(m) from e - else: - raise # pragma: no cover - except ModelRetry as r: - if wrap_validation_errors: - m = _messages.RetryPromptPart( - tool_name=tool_call.tool_name, - content=r.message, - tool_call_id=tool_call.tool_call_id, - ) - raise ToolRetryError(m) from r - else: - raise # pragma: no cover - else: - return output + processors[name] = processor + tool_defs.append(tool_def) + + return cls(processors=processors, tool_defs=tool_defs) + + def __init__( + self, + tool_defs: list[ToolDefinition], + processors: dict[str, ObjectOutputProcessor[Any]], + max_retries: int = 1, + output_validators: list[OutputValidator[AgentDepsT, Any]] | None = None, + ): + self.processors = processors + self._tool_defs = tool_defs + self.max_retries = max_retries + self.output_validators = output_validators or [] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return { + tool_def.name: ToolsetTool( + toolset=self, + tool_def=tool_def, + max_retries=self.max_retries, + args_validator=self.processors[tool_def.name].validator, + ) + for tool_def in self._tool_defs + } + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + output = await self.processors[name].call(tool_args, ctx) + for validator in self.output_validators: + output = await validator.validate(output, ctx, wrap_validation_errors=False) + return output + + +@overload +def _flatten_output_spec( + output_spec: OutputTypeOrFunction[T] | Sequence[OutputTypeOrFunction[T]], +) -> Sequence[OutputTypeOrFunction[T]]: ... + + +@overload +def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]: ... 
-def _flatten_output_spec(output_spec: T | Sequence[T]) -> list[T]:
-    outputs: Sequence[T]
+def _flatten_output_spec(output_spec: OutputSpec[T]) -> Sequence[_OutputSpecItem[T]]:
+    outputs: Sequence[OutputSpec[T]]
     if isinstance(output_spec, Sequence):
         outputs = output_spec
     else:
         outputs = (output_spec,)
 
-    outputs_flat: list[T] = []
+    outputs_flat: list[_OutputSpecItem[T]] = []
     for output in outputs:
-        if union_types := _utils.get_union_args(output):
+        if isinstance(output, Sequence):
+            outputs_flat.extend(_flatten_output_spec(cast(OutputSpec[T], output)))
+        elif union_types := _utils.get_union_args(output):
             outputs_flat.extend(union_types)
         else:
-            outputs_flat.append(output)
+            outputs_flat.append(cast(_OutputSpecItem[T], output))
     return outputs_flat
diff --git a/pydantic_ai_slim/pydantic_ai/_run_context.py b/pydantic_ai_slim/pydantic_ai/_run_context.py
index bb7f474201..afad0e60e6 100644
--- a/pydantic_ai_slim/pydantic_ai/_run_context.py
+++ b/pydantic_ai_slim/pydantic_ai/_run_context.py
@@ -5,6 +5,7 @@
 from dataclasses import field
 from typing import TYPE_CHECKING, Generic
 
+from opentelemetry.trace import NoOpTracer, Tracer
 from typing_extensions import TypeVar
 
 from . import _utils, messages as _messages
@@ -27,10 +28,16 @@ class RunContext(Generic[AgentDepsT]):
     """The model used in this run."""
     usage: Usage
     """LLM usage associated with the run."""
-    prompt: str | Sequence[_messages.UserContent] | None
+    prompt: str | Sequence[_messages.UserContent] | None = None
     """The original user prompt passed to the run."""
     messages: list[_messages.ModelMessage] = field(default_factory=list)
     """Messages exchanged in the conversation so far."""
+    tracer: Tracer = field(default_factory=NoOpTracer)
+    """The tracer to use for tracing the run."""
+    trace_include_content: bool = False
+    """Whether to include the content of the messages in the trace."""
+    retries: dict[str, int] = field(default_factory=dict)
+    """Number of retries for each tool so far."""
     tool_call_id: str | None = None
     """The ID of the tool call."""
     tool_name: str | None = None
@@ -40,17 +47,4 @@
     run_step: int = 0
     """The current step in the run."""
 
-    def replace_with(
-        self,
-        retry: int | None = None,
-        tool_name: str | None | _utils.Unset = _utils.UNSET,
-    ) -> RunContext[AgentDepsT]:
-        # Create a new `RunContext` with a new `retry` value and `tool_name`.
-        kwargs = {}
-        if retry is not None:
-            kwargs['retry'] = retry
-        if tool_name is not _utils.UNSET:  # pragma: no branch
-            kwargs['tool_name'] = tool_name
-        return dataclasses.replace(self, **kwargs)
-
     __repr__ = _utils.dataclasses_no_defaults_repr
diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
new file mode 100644
index 0000000000..bea4103896
--- /dev/null
+++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py
@@ -0,0 +1,190 @@
+from __future__ import annotations
+
+import json
+from collections.abc import Iterable
+from dataclasses import dataclass, replace
+from typing import Any, Generic
+
+from pydantic import ValidationError
+from typing_extensions import assert_never
+
+from pydantic_ai.output import DeferredToolCalls
+
+from . import messages as _messages
+from ._run_context import AgentDepsT, RunContext
+from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior
+from .messages import ToolCallPart
+from .tools import ToolDefinition
+from .toolsets.abstract import AbstractToolset, ToolsetTool
+
+
+@dataclass
+class ToolManager(Generic[AgentDepsT]):
+    """Manages tools for an agent run step. It caches the agent run's toolset's tool definitions and handles calling tools and retries."""
+
+    ctx: RunContext[AgentDepsT]
+    """The agent run context for a specific run step."""
+    toolset: AbstractToolset[AgentDepsT]
+    """The toolset that provides the tools for this run step."""
+    tools: dict[str, ToolsetTool[AgentDepsT]]
+    """The cached tools for this run step."""
+
+    @classmethod
+    async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
+        """Build a new tool manager for a specific run step."""
+        return cls(
+            ctx=ctx,
+            toolset=toolset,
+            tools=await toolset.get_tools(ctx),
+        )
+
+    async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]:
+        """Build a new tool manager for the next run step, carrying over the retries from the current run step."""
+        return await self.__class__.build(self.toolset, replace(ctx, retries=self.ctx.retries))
+
+    @property
+    def tool_defs(self) -> list[ToolDefinition]:
+        """The tool definitions for the tools in this tool manager."""
+        return [tool.tool_def for tool in self.tools.values()]
+
+    def get_tool_def(self, name: str) -> ToolDefinition | None:
+        """Get the tool definition for a given tool name, or `None` if the tool is unknown."""
+        try:
+            return self.tools[name].tool_def
+        except KeyError:
+            return None
+
+    async def handle_call(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
+        """Handle a tool call by validating the arguments, calling the tool, and handling retries.
+
+        Args:
+            call: The tool call part to handle.
+            allow_partial: Whether to allow partial validation of the tool arguments.
+        """
+        if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output':
+            # Output tool calls are not traced
+            return await self._call_tool(call, allow_partial)
+        else:
+            return await self._call_tool_traced(call, allow_partial)
+
+    async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
+        name = call.tool_name
+        tool = self.tools.get(name)
+        try:
+            if tool is None:
+                if self.tools:
+                    msg = f'Available tools: {", ".join(f"{name!r}" for name in self.tools.keys())}'
+                else:
+                    msg = 'No tools available.'
+                raise ModelRetry(f'Unknown tool name: {name!r}. {msg}')
+
+            ctx = replace(
+                self.ctx,
+                tool_name=name,
+                tool_call_id=call.tool_call_id,
+                retry=self.ctx.retries.get(name, 0),
+            )
+
+            pyd_allow_partial = 'trailing-strings' if allow_partial else 'off'
+            validator = tool.args_validator
+            if isinstance(call.args, str):
+                args_dict = validator.validate_json(call.args or '{}', allow_partial=pyd_allow_partial)
+            else:
+                args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial)
+
+            output = await self.toolset.call_tool(name, args_dict, ctx, tool)
+        except (ValidationError, ModelRetry) as e:
+            max_retries = tool.max_retries if tool is not None else 1
+            current_retry = self.ctx.retries.get(name, 0)
+
+            if current_retry == max_retries:
+                raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e
+            else:
+                if isinstance(e, ValidationError):
+                    m = _messages.RetryPromptPart(
+                        tool_name=name,
+                        content=e.errors(include_url=False, include_context=False),
+                        tool_call_id=call.tool_call_id,
+                    )
+                    e = ToolRetryError(m)
+                elif isinstance(e, ModelRetry):
+                    m = _messages.RetryPromptPart(
+                        tool_name=name,
+                        content=e.message,
+                        tool_call_id=call.tool_call_id,
+                    )
+                    e = ToolRetryError(m)
+                else:
+                    assert_never(e)
+
+                self.ctx.retries[name] = current_retry + 1
+                raise e
+        else:
+            self.ctx.retries.pop(name, None)
+            return output
+
+    async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = False) -> Any:
+        """Like `_call_tool`, but wraps the call in an OpenTelemetry span following the GenAI semantic conventions for tool execution."""
+        span_attributes = {
+            'gen_ai.tool.name': call.tool_name,
+            # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai
+            'gen_ai.tool.call.id': call.tool_call_id,
+            **({'tool_arguments': call.args_as_json_str()} if self.ctx.trace_include_content else {}),
+            'logfire.msg': f'running tool: {call.tool_name}',
+            # add the JSON schema so these attributes are formatted nicely in Logfire
+            'logfire.json_schema': json.dumps(
+                {
+                    'type': 'object',
+                    'properties': {
+                        **(
+                            {
+                                'tool_arguments': {'type': 'object'},
+                                'tool_response': {'type': 'object'},
+                            }
+                            if self.ctx.trace_include_content
+                            else {}
+                        ),
+                        'gen_ai.tool.name': {},
+                        'gen_ai.tool.call.id': {},
+                    },
+                }
+            ),
+        }
+        with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span:
+            try:
+                tool_result = await self._call_tool(call, allow_partial)
+            except ToolRetryError as e:
+                part = e.tool_retry
+                if self.ctx.trace_include_content and span.is_recording():
+                    span.set_attribute('tool_response', part.model_response())
+                raise e
+
+            if self.ctx.trace_include_content and span.is_recording():
+                span.set_attribute(
+                    'tool_response',
+                    tool_result
+                    if isinstance(tool_result, str)
+                    else _messages.tool_return_ta.dump_json(tool_result).decode(),
+                )
+
+            return tool_result
+
+    def get_deferred_tool_calls(self, parts: Iterable[_messages.ModelResponsePart]) -> DeferredToolCalls | None:
+        """Get the deferred tool calls from the model response parts."""
+        deferred_calls_and_defs = [
+            (part, tool_def)
+            for part in parts
+            if isinstance(part, _messages.ToolCallPart)
+            and (tool_def := self.get_tool_def(part.tool_name))
+            and tool_def.kind == 'deferred'
+        ]
+        if not deferred_calls_and_defs:
+            return None
+
+        deferred_calls: list[_messages.ToolCallPart] = []
+        deferred_tool_defs: dict[str, ToolDefinition] = {}
+        for part, tool_def in deferred_calls_and_defs:
+            deferred_calls.append(part)
+            deferred_tool_defs[part.tool_name] = tool_def
+
+        return DeferredToolCalls(deferred_calls, deferred_tool_defs)
diff --git 
a/pydantic_ai_slim/pydantic_ai/_utils.py b/pydantic_ai_slim/pydantic_ai/_utils.py index d3f42a7ee9..88bb30ebe3 100644 --- a/pydantic_ai_slim/pydantic_ai/_utils.py +++ b/pydantic_ai_slim/pydantic_ai/_utils.py @@ -4,8 +4,10 @@ import functools import inspect import re +import sys import time import uuid +import warnings from collections.abc import AsyncIterable, AsyncIterator, Awaitable, Iterator from contextlib import asynccontextmanager, suppress from dataclasses import dataclass, fields, is_dataclass @@ -29,7 +31,7 @@ from typing_inspection import typing_objects from typing_inspection.introspection import is_union_origin -from pydantic_graph._utils import AbstractSpan +from pydantic_graph._utils import AbstractSpan, get_event_loop from . import exceptions @@ -461,3 +463,18 @@ def get_union_args(tp: Any) -> tuple[Any, ...]: return get_args(tp) else: return () + + +# The `asyncio.Lock` `loop` argument was deprecated in 3.8 and removed in 3.10, +# but 3.9 still needs it to have the intended behavior. + +if sys.version_info < (3, 10): + + def get_async_lock() -> asyncio.Lock: # pragma: lax no cover + with warnings.catch_warnings(): + warnings.simplefilter('ignore', DeprecationWarning) + return asyncio.Lock(loop=get_event_loop()) +else: + + def get_async_lock() -> asyncio.Lock: # pragma: lax no cover + return asyncio.Lock() diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 3ff881294c..2b0aeb597e 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -4,6 +4,7 @@ import inspect import json import warnings +from asyncio import Lock from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar @@ -15,7 +16,6 @@ from pydantic.json_schema import GenerateJsonSchema from typing_extensions import Literal, Never, Self, TypeIs, TypeVar, deprecated -from pydantic_ai.profiles import ModelProfile from pydantic_graph import End, Graph, GraphRun, GraphRunContext from pydantic_graph._utils import get_event_loop @@ -31,8 +31,11 @@ usage as _usage, ) from ._agent_graph import HistoryProcessor +from ._output import OutputToolset +from ._tool_manager import ToolManager from .models.instrumented import InstrumentationSettings, InstrumentedModel, instrument_model from .output import OutputDataT, OutputSpec +from .profiles import ModelProfile from .result import FinalResult, StreamedRunResult from .settings import ModelSettings, merge_model_settings from .tools import ( @@ -48,6 +51,10 @@ ToolPrepareFunc, ToolsPrepareFunc, ) +from .toolsets import AbstractToolset +from .toolsets.combined import CombinedToolset +from .toolsets.function import FunctionToolset +from .toolsets.prepared import PreparedToolset # Re-exporting like this improves auto-import behavior in PyCharm capture_run_messages = _agent_graph.capture_run_messages @@ -153,12 +160,17 @@ class Agent(Generic[AgentDepsT, OutputDataT]): _system_prompt_dynamic_functions: dict[str, _system_prompt.SystemPromptRunner[AgentDepsT]] = dataclasses.field( repr=False ) + _function_toolset: FunctionToolset[AgentDepsT] = dataclasses.field(repr=False) + _output_toolset: OutputToolset[AgentDepsT] | None = dataclasses.field(repr=False) + _user_toolsets: Sequence[AbstractToolset[AgentDepsT]] = dataclasses.field(repr=False) _prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) - _function_tools: dict[str, 
Tool[AgentDepsT]] = dataclasses.field(repr=False) - _mcp_servers: Sequence[MCPServer] = dataclasses.field(repr=False) - _default_retries: int = dataclasses.field(repr=False) + _prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = dataclasses.field(repr=False) _max_result_retries: int = dataclasses.field(repr=False) + _enter_lock: Lock = dataclasses.field(repr=False) + _entered_count: int = dataclasses.field(repr=False) + _exit_stack: AsyncExitStack | None = dataclasses.field(repr=False) + @overload def __init__( self, @@ -177,7 +189,8 @@ def __init__( output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -186,7 +199,7 @@ def __init__( @overload @deprecated( - '`result_type`, `result_tool_name`, `result_tool_description` & `result_retries` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' + '`result_type`, `result_tool_name` & `result_tool_description` are deprecated, use `output_type` instead. `result_retries` is deprecated, use `output_retries` instead.' ) def __init__( self, @@ -207,6 +220,36 @@ def __init__( result_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + defer_model_check: bool = False, + end_strategy: EndStrategy = 'early', + instrument: InstrumentationSettings | bool | None = None, + history_processors: Sequence[HistoryProcessor[AgentDepsT]] | None = None, + ) -> None: ... 
+ + @overload + @deprecated('`mcp_servers` is deprecated, use `toolsets` instead.') + def __init__( + self, + model: models.Model | models.KnownModelName | str | None = None, + *, + result_type: type[OutputDataT] = str, + instructions: str + | _system_prompt.SystemPromptFunc[AgentDepsT] + | Sequence[str | _system_prompt.SystemPromptFunc[AgentDepsT]] + | None = None, + system_prompt: str | Sequence[str] = (), + deps_type: type[AgentDepsT] = NoneType, + name: str | None = None, + model_settings: ModelSettings | None = None, + retries: int = 1, + result_tool_name: str = _output.DEFAULT_OUTPUT_TOOL_NAME, + result_tool_description: str | None = None, + result_retries: int | None = None, + tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), + prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, mcp_servers: Sequence[MCPServer] = (), defer_model_check: bool = False, end_strategy: EndStrategy = 'early', @@ -232,7 +275,8 @@ def __init__( output_retries: int | None = None, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = (), prepare_tools: ToolsPrepareFunc[AgentDepsT] | None = None, - mcp_servers: Sequence[MCPServer] = (), + prepare_output_tools: ToolsPrepareFunc[AgentDepsT] | None = None, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, defer_model_check: bool = False, end_strategy: EndStrategy = 'early', instrument: InstrumentationSettings | bool | None = None, @@ -258,14 +302,16 @@ def __init__( when the agent is first run. model_settings: Optional model request settings to use for this agent's runs, by default. retries: The default number of retries to allow before raising an error. - output_retries: The maximum number of retries to allow for result validation, defaults to `retries`. + output_retries: The maximum number of retries to allow for output validation, defaults to `retries`. tools: Tools to register with the agent, you can also register tools via the decorators [`@agent.tool`][pydantic_ai.Agent.tool] and [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain]. - prepare_tools: custom method to prepare the tool definition of all tools for each step. + prepare_tools: Custom function to prepare the tool definition of all tools for each step, except output tools. This is useful if you want to customize the definition of multiple tools or you want to register a subset of tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] - mcp_servers: MCP servers to register with the agent. You should register a [`MCPServer`][pydantic_ai.mcp.MCPServer] - for each server you want the agent to connect to. + prepare_output_tools: Custom function to prepare the tool definition of all output tools for each step. + This is useful if you want to customize the definition of multiple output tools or you want to register + a subset of output tools for a given step. See [`ToolsPrepareFunc`][pydantic_ai.tools.ToolsPrepareFunc] + toolsets: Toolsets to register with the agent, including MCP servers. defer_model_check: by default, if you provide a [named][pydantic_ai.models.KnownModelName] model, it's evaluated to create a [`Model`][pydantic_ai.models.Model] instance immediately, which checks for the necessary environment variables. 
Set this to `false` @@ -329,10 +375,17 @@ def __init__( ) output_retries = result_retries + if mcp_servers := _deprecated_kwargs.pop('mcp_servers', None): + if toolsets is not None: # pragma: no cover + raise TypeError('`mcp_servers` and `toolsets` cannot be set at the same time.') + warnings.warn('`mcp_servers` is deprecated, use `toolsets` instead', DeprecationWarning) + toolsets = mcp_servers + + _utils.validate_empty_kwargs(_deprecated_kwargs) + default_output_mode = ( self.model.profile.default_structured_output_mode if isinstance(self.model, models.Model) else None ) - _utils.validate_empty_kwargs(_deprecated_kwargs) self._output_schema = _output.OutputSchema[OutputDataT].build( output_type, @@ -357,21 +410,28 @@ def __init__( self._system_prompt_functions = [] self._system_prompt_dynamic_functions = {} - self._function_tools = {} - - self._default_retries = retries self._max_result_retries = output_retries if output_retries is not None else retries - self._mcp_servers = mcp_servers self._prepare_tools = prepare_tools + self._prepare_output_tools = prepare_output_tools + + self._output_toolset = self._output_schema.toolset + if self._output_toolset: + self._output_toolset.max_retries = self._max_result_retries + + self._function_toolset = FunctionToolset(tools, max_retries=retries) + self._user_toolsets = toolsets or () + self.history_processors = history_processors or [] - for tool in tools: - if isinstance(tool, Tool): - self._register_tool(tool) - else: - self._register_tool(Tool(tool)) self._override_deps: ContextVar[_utils.Option[AgentDepsT]] = ContextVar('_override_deps', default=None) self._override_model: ContextVar[_utils.Option[models.Model]] = ContextVar('_override_model', default=None) + self._override_toolsets: ContextVar[_utils.Option[Sequence[AbstractToolset[AgentDepsT]]]] = ContextVar( + '_override_toolsets', default=None + ) + + self._enter_lock = _utils.get_async_lock() + self._entered_count = 0 + self._exit_stack = None @staticmethod def instrument_all(instrument: InstrumentationSettings | bool = True) -> None: @@ -391,6 +451,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -406,6 +467,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -422,6 +484,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... async def run( @@ -436,6 +499,7 @@ async def run( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Run the agent with a user prompt in async mode. @@ -466,6 +530,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. Returns: The result of the run. 
@@ -490,6 +555,7 @@ async def main(): model_settings=model_settings, usage_limits=usage_limits, usage=usage, + toolsets=toolsets, ) as agent_run: async for _ in agent_run: pass @@ -510,6 +576,7 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, OutputDataT]]: ... @@ -526,6 +593,7 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, RunOutputDataT]]: ... @@ -543,6 +611,7 @@ def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[AgentRun[AgentDepsT, Any]]: ... @asynccontextmanager @@ -558,6 +627,7 @@ async def iter( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[AgentRun[AgentDepsT, Any]]: """A contextmanager which can be used to iterate over the agent graph's nodes as they are executed. @@ -632,6 +702,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. Returns: The result of the run. @@ -655,6 +726,18 @@ async def main(): output_type_ = output_type or self.output_type + # We consider it a user error if a user tries to restrict the result type while having an output validator that + # may change the result type from the restricted type to something else. Therefore, we consider the following + # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. + output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) + + output_toolset = self._output_toolset + if output_schema != self._output_schema or output_validators: + output_toolset = cast(OutputToolset[AgentDepsT], output_schema.toolset) + if output_toolset: + output_toolset.max_retries = self._max_result_retries + output_toolset.output_validators = output_validators + # Build the graph graph: Graph[_agent_graph.GraphAgentState, _agent_graph.GraphAgentDeps[AgentDepsT, Any], FinalResult[Any]] = ( _agent_graph.build_agent_graph(self.name, self._deps_type, output_type_) @@ -669,22 +752,32 @@ async def main(): run_step=0, ) - # We consider it a user error if a user tries to restrict the result type while having an output validator that - # may change the result type from the restricted type to something else. Therefore, we consider the following - # typecast reasonable, even though it is possible to violate it with otherwise-type-checked code. 
- output_validators = cast(list[_output.OutputValidator[AgentDepsT, RunOutputDataT]], self._output_validators) - - # Merge model settings in order of precedence: run > agent > model - merged_settings = merge_model_settings(model_used.settings, self.model_settings) - model_settings = merge_model_settings(merged_settings, model_settings) - usage_limits = usage_limits or _usage.UsageLimits() - if isinstance(model_used, InstrumentedModel): instrumentation_settings = model_used.instrumentation_settings tracer = model_used.instrumentation_settings.tracer else: instrumentation_settings = None tracer = NoOpTracer() + + run_context = RunContext[AgentDepsT]( + deps=deps, + model=model_used, + usage=usage, + prompt=user_prompt, + messages=state.message_history, + tracer=tracer, + trace_include_content=instrumentation_settings is not None and instrumentation_settings.include_content, + run_step=state.run_step, + ) + + toolset = self._get_toolset(output_toolset=output_toolset, additional_toolsets=toolsets) + # This will raise errors for any name conflicts + run_toolset = await ToolManager[AgentDepsT].build(toolset, run_context) + + # Merge model settings in order of precedence: run > agent > model + merged_settings = merge_model_settings(model_used.settings, self.model_settings) + model_settings = merge_model_settings(merged_settings, model_settings) + usage_limits = usage_limits or _usage.UsageLimits() agent_name = self.name or 'agent' run_span = tracer.start_span( 'agent run', @@ -711,10 +804,6 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: return None return '\n\n'.join(parts).strip() - # Copy the function tools so that retry state is agent-run-specific - # Note that the retry count is reset to 0 when this happens due to the `default=0` and `init=False`. - run_function_tools = {k: dataclasses.replace(v) for k, v in self._function_tools.items()} - graph_deps = _agent_graph.GraphAgentDeps[AgentDepsT, RunOutputDataT]( user_deps=deps, prompt=user_prompt, @@ -727,11 +816,8 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: output_schema=output_schema, output_validators=output_validators, history_processors=self.history_processors, - function_tools=run_function_tools, - mcp_servers=self._mcp_servers, - default_retries=self._default_retries, + tool_manager=run_toolset, tracer=tracer, - prepare_tools=self._prepare_tools, get_instructions=get_instructions, instrumentation_settings=instrumentation_settings, ) @@ -801,6 +887,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[OutputDataT]: ... @overload @@ -816,6 +903,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... @overload @@ -832,6 +920,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AgentRunResult[RunOutputDataT]: ... 
def run_sync( @@ -846,6 +935,7 @@ def run_sync( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AgentRunResult[Any]: """Synchronously run the agent with a user prompt. @@ -875,6 +965,7 @@ def run_sync( usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. Returns: The result of the run. @@ -901,6 +992,7 @@ def run_sync( usage_limits=usage_limits, usage=usage, infer_name=False, + toolsets=toolsets, ) ) @@ -916,6 +1008,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, OutputDataT]]: ... @overload @@ -931,6 +1024,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @overload @@ -947,6 +1041,7 @@ def run_stream( usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, ) -> AbstractAsyncContextManager[result.StreamedRunResult[AgentDepsT, RunOutputDataT]]: ... @asynccontextmanager @@ -962,6 +1057,7 @@ async def run_stream( # noqa C901 usage_limits: _usage.UsageLimits | None = None, usage: _usage.Usage | None = None, infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, **_deprecated_kwargs: Never, ) -> AsyncIterator[result.StreamedRunResult[AgentDepsT, Any]]: """Run the agent with a user prompt in async mode, returning a streamed response. @@ -989,6 +1085,7 @@ async def main(): usage_limits: Optional limits on model request count or token usage. usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional additional toolsets for this run. Returns: The result of the run. 
@@ -1019,6 +1116,7 @@ async def main(): usage_limits=usage_limits, usage=usage, infer_name=False, + toolsets=toolsets, ) as agent_run: first_node = agent_run.next_node # start with the first node assert isinstance(first_node, _agent_graph.UserPromptNode) # the first node should be a user prompt node @@ -1039,15 +1137,17 @@ async def stream_to_final( output_schema, _output.TextOutputSchema ): return FinalResult(s, None, None) - elif isinstance(new_part, _messages.ToolCallPart) and isinstance( - output_schema, _output.ToolOutputSchema - ): # pragma: no branch - for call, _ in output_schema.find_tool([new_part]): - return FinalResult(s, call.tool_name, call.tool_call_id) + elif isinstance(new_part, _messages.ToolCallPart) and ( + tool_def := graph_ctx.deps.tool_manager.get_tool_def(new_part.tool_name) + ): + if tool_def.kind == 'output': + return FinalResult(s, new_part.tool_name, new_part.tool_call_id) + elif tool_def.kind == 'deferred': + return FinalResult(s, None, None) return None - final_result_details = await stream_to_final(streamed_response) - if final_result_details is not None: + final_result = await stream_to_final(streamed_response) + if final_result is not None: if yielded: raise exceptions.AgentRunError('Agent run produced final results') # pragma: no cover yielded = True @@ -1068,17 +1168,13 @@ async def on_complete() -> None: parts: list[_messages.ModelRequestPart] = [] async for _event in _agent_graph.process_function_tools( + graph_ctx.deps.tool_manager, tool_calls, - final_result_details.tool_name, - final_result_details.tool_call_id, + final_result, graph_ctx, parts, ): pass - # TODO: Should we do something here related to the retry count? - # Maybe we should move the incrementing of the retry count to where we actually make a request? - # if any(isinstance(part, _messages.RetryPromptPart) for part in parts): - # ctx.state.increment_retries(ctx.deps.max_result_retries) if parts: messages.append(_messages.ModelRequest(parts)) @@ -1089,10 +1185,10 @@ async def on_complete() -> None: streamed_response, graph_ctx.deps.output_schema, _agent_graph.build_run_context(graph_ctx), - _output.build_trace_context(graph_ctx), graph_ctx.deps.output_validators, - final_result_details.tool_name, + final_result.tool_name, on_complete, + graph_ctx.deps.tool_manager, ) break next_node = await agent_run.next(node) @@ -1111,8 +1207,9 @@ def override( *, deps: AgentDepsT | _utils.Unset = _utils.UNSET, model: models.Model | models.KnownModelName | str | _utils.Unset = _utils.UNSET, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | _utils.Unset = _utils.UNSET, ) -> Iterator[None]: - """Context manager to temporarily override agent dependencies and model. + """Context manager to temporarily override agent dependencies, model, or toolsets. This is particularly useful when testing. You can find an example of this [here](../testing.md#overriding-model-via-pytest-fixtures). @@ -1120,6 +1217,7 @@ def override( Args: deps: The dependencies to use instead of the dependencies passed to the agent run. model: The model to use instead of the model passed to the agent run. + toolsets: The toolsets to use instead of the toolsets passed to the agent constructor and agent run. 
""" if _utils.is_set(deps): deps_token = self._override_deps.set(_utils.Some(deps)) @@ -1131,6 +1229,11 @@ def override( else: model_token = None + if _utils.is_set(toolsets): + toolsets_token = self._override_toolsets.set(_utils.Some(toolsets)) + else: + toolsets_token = None + try: yield finally: @@ -1138,6 +1241,8 @@ def override( self._override_deps.reset(deps_token) if model_token is not None: self._override_model.reset(model_token) + if toolsets_token is not None: + self._override_toolsets.reset(toolsets_token) @overload def instructions( @@ -1423,30 +1528,13 @@ async def spam(ctx: RunContext[str], y: float) -> float: strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. """ - if func is None: - def tool_decorator( - func_: ToolFuncContext[AgentDepsT, ToolParams], - ) -> ToolFuncContext[AgentDepsT, ToolParams]: - # noinspection PyTypeChecker - self._register_function( - func_, - True, - name, - retries, - prepare, - docstring_format, - require_parameter_descriptions, - schema_generator, - strict, - ) - return func_ - - return tool_decorator - else: + def tool_decorator( + func_: ToolFuncContext[AgentDepsT, ToolParams], + ) -> ToolFuncContext[AgentDepsT, ToolParams]: # noinspection PyTypeChecker - self._register_function( - func, + self._function_toolset.add_function( + func_, True, name, retries, @@ -1456,7 +1544,9 @@ def tool_decorator( schema_generator, strict, ) - return func + return func_ + + return tool_decorator if func is None else tool_decorator(func) @overload def tool_plain(self, func: ToolFuncPlain[ToolParams], /) -> ToolFuncPlain[ToolParams]: ... @@ -1532,27 +1622,11 @@ async def spam(ctx: RunContext[str]) -> float: strict: Whether to enforce JSON schema compliance (only affects OpenAI). See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. 
""" - if func is None: - - def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: - # noinspection PyTypeChecker - self._register_function( - func_, - False, - name, - retries, - prepare, - docstring_format, - require_parameter_descriptions, - schema_generator, - strict, - ) - return func_ - return tool_decorator - else: - self._register_function( - func, + def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams]: + # noinspection PyTypeChecker + self._function_toolset.add_function( + func_, False, name, retries, @@ -1562,48 +1636,9 @@ def tool_decorator(func_: ToolFuncPlain[ToolParams]) -> ToolFuncPlain[ToolParams schema_generator, strict, ) - return func - - def _register_function( - self, - func: ToolFuncEither[AgentDepsT, ToolParams], - takes_ctx: bool, - name: str | None, - retries: int | None, - prepare: ToolPrepareFunc[AgentDepsT] | None, - docstring_format: DocstringFormat, - require_parameter_descriptions: bool, - schema_generator: type[GenerateJsonSchema], - strict: bool | None, - ) -> None: - """Private utility to register a function as a tool.""" - retries_ = retries if retries is not None else self._default_retries - tool = Tool[AgentDepsT]( - func, - takes_ctx=takes_ctx, - name=name, - max_retries=retries_, - prepare=prepare, - docstring_format=docstring_format, - require_parameter_descriptions=require_parameter_descriptions, - schema_generator=schema_generator, - strict=strict, - ) - self._register_tool(tool) - - def _register_tool(self, tool: Tool[AgentDepsT]) -> None: - """Private utility to register a tool instance.""" - if tool.max_retries is None: - # noinspection PyTypeChecker - tool = dataclasses.replace(tool, max_retries=self._default_retries) + return func_ - if tool.name in self._function_tools: - raise exceptions.UserError(f'Tool name conflicts with existing tool: {tool.name!r}') - - if tool.name in self._output_schema.tools: - raise exceptions.UserError(f'Tool name conflicts with output tool name: {tool.name!r}') - - self._function_tools[tool.name] = tool + return tool_decorator if func is None else tool_decorator(func) def _get_model(self, model: models.Model | models.KnownModelName | str | None) -> models.Model: """Create a model configured for this agent. @@ -1649,6 +1684,37 @@ def _get_deps(self: Agent[T, OutputDataT], deps: T) -> T: else: return deps + def _get_toolset( + self, + output_toolset: AbstractToolset[AgentDepsT] | None | _utils.Unset = _utils.UNSET, + additional_toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AbstractToolset[AgentDepsT]: + """Get the complete toolset. + + Args: + output_toolset: The output toolset to use instead of the one built at agent construction time. + additional_toolsets: Additional toolsets to add. 
+ """ + if some_user_toolsets := self._override_toolsets.get(): + user_toolsets = some_user_toolsets.value + elif additional_toolsets is not None: + user_toolsets = [*self._user_toolsets, *additional_toolsets] + else: + user_toolsets = self._user_toolsets + + all_toolsets = [self._function_toolset, *user_toolsets] + + if self._prepare_tools: + all_toolsets = [PreparedToolset(CombinedToolset(all_toolsets), self._prepare_tools)] + + output_toolset = output_toolset if _utils.is_set(output_toolset) else self._output_toolset + if output_toolset is not None: + if self._prepare_output_tools: + output_toolset = PreparedToolset(output_toolset, self._prepare_output_tools) + all_toolsets = [output_toolset, *all_toolsets] + + return CombinedToolset(all_toolsets) + def _infer_name(self, function_frame: FrameType | None) -> None: """Infer the agent name from the call frame. @@ -1734,28 +1800,68 @@ def is_end_node( """ return isinstance(node, End) + async def __aenter__(self) -> Self: + """Enter the agent context. + + This will start all [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] registered as `toolsets` so they are ready to be used. + + This is a no-op if the agent has already been entered. + """ + async with self._enter_lock: + if self._entered_count == 0: + self._exit_stack = AsyncExitStack() + toolset = self._get_toolset() + await self._exit_stack.enter_async_context(toolset) + self._entered_count += 1 + return self + + async def __aexit__(self, *args: Any) -> bool | None: + async with self._enter_lock: + self._entered_count -= 1 + if self._entered_count == 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + + def set_mcp_sampling_model(self, model: models.Model | models.KnownModelName | str | None = None) -> None: + """Set the sampling model on all MCP servers registered with the agent. + + If no sampling model is provided, the agent's model will be used. + """ + try: + sampling_model = models.infer_model(model) if model else self._get_model(None) + except exceptions.UserError as e: + raise exceptions.UserError('No sampling model provided and no model set on the agent.') from e + + from .mcp import MCPServer + + def _set_sampling_model(toolset: AbstractToolset[AgentDepsT]) -> None: + if isinstance(toolset, MCPServer): + toolset.sampling_model = sampling_model + + self._get_toolset().apply(_set_sampling_model) + @asynccontextmanager + @deprecated( + '`run_mcp_servers` is deprecated, use `async with agent:` instead. If you need to set a sampling model on all MCP servers, use `agent.set_mcp_sampling_model()`.' + ) async def run_mcp_servers( self, model: models.Model | models.KnownModelName | str | None = None ) -> AsyncIterator[None]: """Run [`MCPServerStdio`s][pydantic_ai.mcp.MCPServerStdio] so they can be used by the agent. + Deprecated: use [`async with agent`][pydantic_ai.agent.Agent.__aenter__] instead. + If you need to set a sampling model on all MCP servers, use [`agent.set_mcp_sampling_model()`][pydantic_ai.agent.Agent.set_mcp_sampling_model]. + Returns: a context manager to start and shutdown the servers. 
""" try: - sampling_model: models.Model | None = self._get_model(model) - except exceptions.UserError: # pragma: no cover - sampling_model = None + self.set_mcp_sampling_model(model) + except exceptions.UserError: + if model is not None: + raise - exit_stack = AsyncExitStack() - try: - for mcp_server in self._mcp_servers: - if sampling_model is not None: # pragma: no branch - mcp_server.sampling_model = sampling_model - await exit_stack.enter_async_context(mcp_server) + async with self: yield - finally: - await exit_stack.aclose() def to_a2a( self, diff --git a/pydantic_ai_slim/pydantic_ai/exceptions.py b/pydantic_ai_slim/pydantic_ai/exceptions.py index 3f57faaf8d..344ab94daf 100644 --- a/pydantic_ai_slim/pydantic_ai/exceptions.py +++ b/pydantic_ai_slim/pydantic_ai/exceptions.py @@ -2,12 +2,16 @@ import json import sys +from typing import TYPE_CHECKING if sys.version_info < (3, 11): from exceptiongroup import ExceptionGroup else: ExceptionGroup = ExceptionGroup +if TYPE_CHECKING: + from .messages import RetryPromptPart + __all__ = ( 'ModelRetry', 'UserError', @@ -113,3 +117,11 @@ def __init__(self, status_code: int, model_name: str, body: object | None = None class FallbackExceptionGroup(ExceptionGroup): """A group of exceptions that can be raised when all fallback models fail.""" + + +class ToolRetryError(Exception): + """Exception used to signal a `ToolRetry` message should be returned to the LLM.""" + + def __init__(self, tool_retry: RetryPromptPart): + self.tool_retry = tool_retry + super().__init__() diff --git a/pydantic_ai_slim/pydantic_ai/ext/aci.py b/pydantic_ai_slim/pydantic_ai/ext/aci.py index 5e5dc49366..6cd43402a1 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/aci.py +++ b/pydantic_ai_slim/pydantic_ai/ext/aci.py @@ -4,11 +4,13 @@ except ImportError as _import_error: raise ImportError('Please install `aci-sdk` to use ACI.dev tools') from _import_error +from collections.abc import Sequence from typing import Any from aci import ACI -from pydantic_ai import Tool +from pydantic_ai.tools import Tool +from pydantic_ai.toolsets.function import FunctionToolset def _clean_schema(schema): @@ -22,10 +24,10 @@ def _clean_schema(schema): def tool_from_aci(aci_function: str, linked_account_owner_id: str) -> Tool: - """Creates a Pydantic AI tool proxy from an ACI function. + """Creates a Pydantic AI tool proxy from an ACI.dev function. Args: - aci_function: The ACI function to wrao. + aci_function: The ACI.dev function to wrap. linked_account_owner_id: The ACI user ID to execute the function on behalf of. Returns: @@ -64,3 +66,10 @@ def implementation(*args: Any, **kwargs: Any) -> str: description=function_description, json_schema=json_schema, ) + + +class ACIToolset(FunctionToolset): + """A toolset that wraps ACI.dev tools.""" + + def __init__(self, aci_functions: Sequence[str], linked_account_owner_id: str): + super().__init__([tool_from_aci(aci_function, linked_account_owner_id) for aci_function in aci_functions]) diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py index 60db763f9d..3fb4079386 100644 --- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py +++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py @@ -3,6 +3,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai.tools import Tool +from pydantic_ai.toolsets.function import FunctionToolset class LangChainTool(Protocol): @@ -23,7 +24,7 @@ def description(self) -> str: ... def run(self, *args: Any, **kwargs: Any) -> str: ... 
-__all__ = ('tool_from_langchain',)
+__all__ = ('tool_from_langchain', 'LangChainToolset')
 
 
 def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
@@ -59,3 +60,10 @@ def proxy(*args: Any, **kwargs: Any) -> str:
         description=function_description,
         json_schema=schema,
     )
+
+
+class LangChainToolset(FunctionToolset):
+    """A toolset that wraps LangChain tools."""
+
+    def __init__(self, tools: list[LangChainTool]):
+        super().__init__([tool_from_langchain(tool) for tool in tools])
diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py
index e1fc10f29d..2ca7950b3e 100644
--- a/pydantic_ai_slim/pydantic_ai/mcp.py
+++ b/pydantic_ai_slim/pydantic_ai/mcp.py
@@ -3,11 +3,11 @@
 import base64
 import functools
 from abc import ABC, abstractmethod
+from asyncio import Lock
 from collections.abc import AsyncIterator, Awaitable, Sequence
 from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager
-from dataclasses import dataclass
+from dataclasses import dataclass, field, replace
 from pathlib import Path
-from types import TracebackType
 from typing import Any, Callable
 
 import anyio
@@ -16,6 +16,11 @@
 from anyio.streams.memory import MemoryObjectReceiveStream, MemoryObjectSendStream
 from typing_extensions import Self, assert_never, deprecated
 
+from pydantic_ai._run_context import RunContext
+from pydantic_ai.tools import ToolDefinition
+
+from .toolsets.abstract import AbstractToolset, ToolsetTool
+
 try:
     from mcp import types as mcp_types
     from mcp.client.session import ClientSession, LoggingFnT
@@ -32,12 +37,18 @@
     ) from _import_error
 
 # after mcp imports so any import error maps to this file, not _mcp.py
-from . import _mcp, exceptions, messages, models, tools
+from . import _mcp, exceptions, messages, models
 
 __all__ = 'MCPServer', 'MCPServerStdio', 'MCPServerHTTP', 'MCPServerSSE', 'MCPServerStreamableHTTP'
 
+TOOL_SCHEMA_VALIDATOR = pydantic_core.SchemaValidator(
+    schema=pydantic_core.core_schema.dict_schema(
+        pydantic_core.core_schema.str_schema(), pydantic_core.core_schema.any_schema()
+    )
+)
+
 
-class MCPServer(ABC):
+class MCPServer(AbstractToolset[Any], ABC):
     """Base class for attaching agents to MCP servers.
 
     See the Model Context Protocol documentation for more information.
@@ -50,15 +61,22 @@ class MCPServer(ABC): timeout: float = 5 process_tool_call: ProcessToolCallback | None = None allow_sampling: bool = True + max_retries: int = 1 + sampling_model: models.Model | None = None # } end of "abstract fields" - _running_count: int = 0 + _enter_lock: Lock = field(compare=False) + _running_count: int + _exit_stack: AsyncExitStack | None _client: ClientSession _read_stream: MemoryObjectReceiveStream[SessionMessage | Exception] _write_stream: MemoryObjectSendStream[SessionMessage] - _exit_stack: AsyncExitStack - sampling_model: models.Model | None = None + + def __post_init__(self): + self._enter_lock = Lock() + self._running_count = 0 + self._exit_stack = None @abstractmethod @asynccontextmanager @@ -74,47 +92,36 @@ async def client_streams( raise NotImplementedError('MCP Server subclasses must implement this method.') yield - def get_prefixed_tool_name(self, tool_name: str) -> str: - """Get the tool name with prefix if `tool_prefix` is set.""" - return f'{self.tool_prefix}_{tool_name}' if self.tool_prefix else tool_name - - def get_unprefixed_tool_name(self, tool_name: str) -> str: - """Get original tool name without prefix for calling tools.""" - return tool_name.removeprefix(f'{self.tool_prefix}_') if self.tool_prefix else tool_name + @property + def name(self) -> str: + return repr(self) @property - def is_running(self) -> bool: - """Check if the MCP server is running.""" - return bool(self._running_count) + def tool_name_conflict_hint(self) -> str: + return 'Consider setting `tool_prefix` to avoid name conflicts.' - async def list_tools(self) -> list[tools.ToolDefinition]: + async def list_tools(self) -> list[mcp_types.Tool]: """Retrieve tools that are currently active on the server. Note: - We don't cache tools as they might change. - We also don't subscribe to the server to avoid complexity. """ - mcp_tools = await self._client.list_tools() - return [ - tools.ToolDefinition( - name=self.get_prefixed_tool_name(tool.name), - description=tool.description, - parameters_json_schema=tool.inputSchema, - ) - for tool in mcp_tools.tools - ] + async with self: # Ensure server is running + result = await self._client.list_tools() + return result.tools - async def call_tool( + async def direct_call_tool( self, - tool_name: str, - arguments: dict[str, Any], + name: str, + args: dict[str, Any], metadata: dict[str, Any] | None = None, ) -> ToolResult: """Call a tool on the server. Args: - tool_name: The name of the tool to call. - arguments: The arguments to pass to the tool. + name: The name of the tool to call. + args: The arguments to pass to the tool. metadata: Request-level metadata (optional) Returns: @@ -123,23 +130,23 @@ async def call_tool( Raises: ModelRetry: If the tool call fails. """ - try: - # meta param is not provided by session yet, so build and can send_request directly. 
- result = await self._client.send_request( - mcp_types.ClientRequest( - mcp_types.CallToolRequest( - method='tools/call', - params=mcp_types.CallToolRequestParams( - name=self.get_unprefixed_tool_name(tool_name), - arguments=arguments, - _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, - ), - ) - ), - mcp_types.CallToolResult, - ) - except McpError as e: - raise exceptions.ModelRetry(e.error.message) + async with self: # Ensure server is running + try: + result = await self._client.send_request( + mcp_types.ClientRequest( + mcp_types.CallToolRequest( + method='tools/call', + params=mcp_types.CallToolRequestParams( + name=name, + arguments=args, + _meta=mcp_types.RequestParams.Meta(**metadata) if metadata else None, + ), + ) + ), + mcp_types.CallToolResult, + ) + except McpError as e: + raise exceptions.ModelRetry(e.error.message) content = [self._map_tool_result_part(part) for part in result.content] @@ -149,36 +156,80 @@ async def call_tool( else: return content[0] if len(content) == 1 else content - async def __aenter__(self) -> Self: - if self._running_count == 0: - self._exit_stack = AsyncExitStack() - - self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(self.client_streams()) - client = ClientSession( - read_stream=self._read_stream, - write_stream=self._write_stream, - sampling_callback=self._sampling_callback if self.allow_sampling else None, - logging_callback=self.log_handler, + async def call_tool( + self, + name: str, + tool_args: dict[str, Any], + ctx: RunContext[Any], + tool: ToolsetTool[Any], + ) -> ToolResult: + if self.tool_prefix: + name = name.removeprefix(f'{self.tool_prefix}_') + ctx = replace(ctx, tool_name=name) + + if self.process_tool_call is not None: + return await self.process_tool_call(ctx, self.direct_call_tool, name, tool_args) + else: + return await self.direct_call_tool(name, tool_args) + + async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: + return { + name: ToolsetTool( + toolset=self, + tool_def=ToolDefinition( + name=name, + description=mcp_tool.description, + parameters_json_schema=mcp_tool.inputSchema, + ), + max_retries=self.max_retries, + args_validator=TOOL_SCHEMA_VALIDATOR, ) - self._client = await self._exit_stack.enter_async_context(client) + for mcp_tool in await self.list_tools() + if (name := f'{self.tool_prefix}_{mcp_tool.name}' if self.tool_prefix else mcp_tool.name) + } + + async def __aenter__(self) -> Self: + """Enter the MCP server context. + + This will initialize the connection to the server. + If this server is an [`MCPServerStdio`][pydantic_ai.mcp.MCPServerStdio], the server will first be started as a subprocess. - with anyio.fail_after(self.timeout): - await self._client.initialize() + This is a no-op if the MCP server has already been entered. 
+ """ + async with self._enter_lock: + if self._running_count == 0: + self._exit_stack = AsyncExitStack() + + self._read_stream, self._write_stream = await self._exit_stack.enter_async_context( + self.client_streams() + ) + client = ClientSession( + read_stream=self._read_stream, + write_stream=self._write_stream, + sampling_callback=self._sampling_callback if self.allow_sampling else None, + logging_callback=self.log_handler, + ) + self._client = await self._exit_stack.enter_async_context(client) + + with anyio.fail_after(self.timeout): + await self._client.initialize() - if log_level := self.log_level: - await self._client.set_logging_level(log_level) - self._running_count += 1 + if log_level := self.log_level: + await self._client.set_logging_level(log_level) + self._running_count += 1 return self - async def __aexit__( - self, - exc_type: type[BaseException] | None, - exc_value: BaseException | None, - traceback: TracebackType | None, - ) -> bool | None: - self._running_count -= 1 - if self._running_count <= 0: - await self._exit_stack.aclose() + async def __aexit__(self, *args: Any) -> bool | None: + async with self._enter_lock: + self._running_count -= 1 + if self._running_count == 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + + @property + def is_running(self) -> bool: + """Check if the MCP server is running.""" + return bool(self._running_count) async def _sampling_callback( self, context: RequestContext[ClientSession, Any], params: mcp_types.CreateMessageRequestParams @@ -271,10 +322,10 @@ class MCPServerStdio(MCPServer): 'stdio', ] ) - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` @@ -327,6 +378,12 @@ async def main(): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + max_retries: int = 1 + """The maximum number of times to retry a tool call.""" + + sampling_model: models.Model | None = None + """The model to use for sampling.""" + @asynccontextmanager async def client_streams( self, @@ -422,6 +479,12 @@ class _MCPServerHTTP(MCPServer): allow_sampling: bool = True """Whether to allow MCP sampling through this client.""" + max_retries: int = 1 + """The maximum number of times to retry a tool call.""" + + sampling_model: models.Model | None = None + """The model to use for sampling.""" + @property @abstractmethod def _transport_client( @@ -503,10 +566,10 @@ class MCPServerSSE(_MCPServerHTTP): from pydantic_ai.mcp import MCPServerSSE server = MCPServerSSE('http://localhost:3001/sse') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` @@ -537,10 +600,10 @@ class MCPServerHTTP(MCPServerSSE): from pydantic_ai.mcp import MCPServerHTTP server = MCPServerHTTP('http://localhost:3001/sse') # (1)! - agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` @@ -566,10 +629,10 @@ class MCPServerStreamableHTTP(_MCPServerHTTP): from pydantic_ai.mcp import MCPServerStreamableHTTP server = MCPServerStreamableHTTP('http://localhost:8000/mcp') # (1)! 
- agent = Agent('openai:gpt-4o', mcp_servers=[server]) + agent = Agent('openai:gpt-4o', toolsets=[server]) async def main(): - async with agent.run_mcp_servers(): # (2)! + async with agent: # (2)! ... ``` """ @@ -586,14 +649,14 @@ def _transport_client(self): | list[Any] | Sequence[str | messages.BinaryContent | dict[str, Any] | list[Any]] ) -"""The result type of a tool call.""" +"""The result type of an MCP tool call.""" CallToolFunc = Callable[[str, dict[str, Any], dict[str, Any] | None], Awaitable[ToolResult]] """A function type that represents a tool call.""" ProcessToolCallback = Callable[ [ - tools.RunContext[Any], + RunContext[Any], CallToolFunc, str, dict[str, Any], diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 9dc7d2ef6b..9c41b535db 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -10,7 +10,8 @@ from typing_extensions import TypeAliasType, TypeVar from . import _utils -from .tools import RunContext +from .messages import ToolCallPart +from .tools import RunContext, ToolDefinition __all__ = ( # classes @@ -330,15 +331,17 @@ def __get_pydantic_json_schema__( return _StructuredDict +_OutputSpecItem = TypeAliasType( + '_OutputSpecItem', + Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], NativeOutput[T_co], PromptedOutput[T_co], TextOutput[T_co]], + type_params=(T_co,), +) + OutputSpec = TypeAliasType( 'OutputSpec', Union[ - OutputTypeOrFunction[T_co], - ToolOutput[T_co], - NativeOutput[T_co], - PromptedOutput[T_co], - TextOutput[T_co], - Sequence[Union[OutputTypeOrFunction[T_co], ToolOutput[T_co], TextOutput[T_co]]], + _OutputSpecItem[T_co], + Sequence['OutputSpec[T_co]'], ], type_params=(T_co,), ) @@ -354,3 +357,11 @@ def __get_pydantic_json_schema__( See [output docs](../output.md) for more information. """ + + +@dataclass +class DeferredToolCalls: + """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools.""" + + tool_calls: list[ToolCallPart] + tool_defs: dict[str, ToolDefinition] diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index f700482662..163189ac0b 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -5,11 +5,13 @@ from copy import copy from dataclasses import dataclass, field from datetime import datetime -from typing import Generic +from typing import Generic, cast from pydantic import ValidationError from typing_extensions import TypeVar, deprecated, overload +from pydantic_ai._tool_manager import ToolManager + from . 
import _utils, exceptions, messages as _messages, models from ._output import ( OutputDataT_inv, @@ -19,7 +21,6 @@ PlainTextOutputSchema, TextOutputSchema, ToolOutputSchema, - TraceContext, ) from ._run_context import AgentDepsT, RunContext from .messages import AgentStreamEvent, FinalResultEvent @@ -47,8 +48,8 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _output_schema: OutputSchema[OutputDataT] _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] - _trace_ctx: TraceContext _usage_limits: UsageLimits | None + _toolset: ToolManager[AgentDepsT] _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) _final_result_event: FinalResultEvent | None = field(default=None, init=False) @@ -97,37 +98,40 @@ async def _validate_response( self, message: _messages.ModelResponse, output_tool_name: str | None, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" - call = None if isinstance(self._output_schema, ToolOutputSchema) and output_tool_name is not None: - match = self._output_schema.find_named_tool(message.parts, output_tool_name) - if match is None: + tool_call = next( + ( + part + for part in message.parts + if isinstance(part, _messages.ToolCallPart) and part.tool_name == output_tool_name + ), + None, + ) + if tool_call is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' + f'Invalid response, unable to find tool call for {output_tool_name!r}' ) - - call, output_tool = match - result_data = await output_tool.process( - call, - self._run_ctx, - self._trace_ctx, - allow_partial=allow_partial, - wrap_validation_errors=False, - ) + return await self._toolset.handle_call(tool_call, allow_partial=allow_partial) + elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + if not self._output_schema.allows_deferred_tool_calls: + raise exceptions.UserError( # pragma: no cover + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' + ) + return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( - text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + for validator in self._output_validators: + result_data = await validator.validate(result_data, self._run_ctx) + return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) - return result_data - def __aiter__(self) -> AsyncIterator[AgentStreamEvent]: """Stream [`AgentStreamEvent`][pydantic_ai.messages.AgentStreamEvent]s. @@ -145,13 +149,19 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. 
"""Return an appropriate FinalResultEvent if `e` corresponds to a part that will produce a final result.""" if isinstance(e, _messages.PartStartEvent): new_part = e.part - if isinstance(new_part, _messages.ToolCallPart) and isinstance(output_schema, ToolOutputSchema): - for call, _ in output_schema.find_tool([new_part]): # pragma: no branch - return _messages.FinalResultEvent(tool_name=call.tool_name, tool_call_id=call.tool_call_id) - elif isinstance(new_part, _messages.TextPart) and isinstance( + if isinstance(new_part, _messages.TextPart) and isinstance( output_schema, TextOutputSchema ): # pragma: no branch return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) + elif isinstance(new_part, _messages.ToolCallPart) and ( + tool_def := self._toolset.get_tool_def(new_part.tool_name) + ): + if tool_def.kind == 'output': + return _messages.FinalResultEvent( + tool_name=new_part.tool_name, tool_call_id=new_part.tool_call_id + ) + elif tool_def.kind == 'deferred': + return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) usage_checking_stream = _get_usage_checking_stream_response( self._raw_stream_response, self._usage_limits, self.usage @@ -183,10 +193,10 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _stream_response: models.StreamedResponse _output_schema: OutputSchema[OutputDataT] _run_ctx: RunContext[AgentDepsT] - _trace_ctx: TraceContext _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] + _toolset: ToolManager[AgentDepsT] _initial_run_ctx_usage: Usage = field(init=False) is_complete: bool = field(default=False, init=False) @@ -420,40 +430,43 @@ async def validate_structured_output( self, message: _messages.ModelResponse, *, allow_partial: bool = False ) -> OutputDataT: """Validate a structured result message.""" - call = None if isinstance(self._output_schema, ToolOutputSchema) and self._output_tool_name is not None: - match = self._output_schema.find_named_tool(message.parts, self._output_tool_name) - if match is None: + tool_call = next( + ( + part + for part in message.parts + if isinstance(part, _messages.ToolCallPart) and part.tool_name == self._output_tool_name + ), + None, + ) + if tool_call is None: raise exceptions.UnexpectedModelBehavior( # pragma: no cover - f'Invalid response, unable to find tool: {self._output_schema.tool_names()}' + f'Invalid response, unable to find tool call for {self._output_tool_name!r}' ) - - call, output_tool = match - result_data = await output_tool.process( - call, - self._run_ctx, - self._trace_ctx, - allow_partial=allow_partial, - wrap_validation_errors=False, - ) + return await self._toolset.handle_call(tool_call, allow_partial=allow_partial) + elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + if not self._output_schema.allows_deferred_tool_calls: + raise exceptions.UserError( + 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' 
+ ) + return cast(OutputDataT, deferred_tool_calls) elif isinstance(self._output_schema, TextOutputSchema): text = '\n\n'.join(x.content for x in message.parts if isinstance(x, _messages.TextPart)) result_data = await self._output_schema.process( - text, self._run_ctx, self._trace_ctx, allow_partial=allow_partial, wrap_validation_errors=False + text, self._run_ctx, allow_partial=allow_partial, wrap_validation_errors=False ) + for validator in self._output_validators: + result_data = await validator.validate(result_data, self._run_ctx) # pragma: no cover + return result_data else: raise exceptions.UnexpectedModelBehavior( # pragma: no cover 'Invalid response, unable to process text output' ) - for validator in self._output_validators: - result_data = await validator.validate(result_data, call, self._run_ctx) # pragma: no cover - return result_data - async def _validate_text_output(self, text: str) -> str: for validator in self._output_validators: - text = await validator.validate(text, None, self._run_ctx) # pragma: no cover + text = await validator.validate(text, self._run_ctx) # pragma: no cover return text async def _marked_completed(self, message: _messages.ModelResponse) -> None: diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index bbc8d83209..4243c02971 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -1,20 +1,15 @@ from __future__ import annotations as _annotations -import dataclasses -import json from collections.abc import Awaitable, Sequence from dataclasses import dataclass, field from typing import Any, Callable, Generic, Literal, Union -from opentelemetry.trace import Tracer -from pydantic import ValidationError from pydantic.json_schema import GenerateJsonSchema, JsonSchemaValue from pydantic_core import SchemaValidator, core_schema from typing_extensions import Concatenate, ParamSpec, Self, TypeAlias, TypeVar -from . import _function_schema, _utils, messages as _messages +from . import _function_schema, _utils from ._run_context import AgentDepsT, RunContext -from .exceptions import ModelRetry, UnexpectedModelBehavior __all__ = ( 'AgentDepsT', @@ -32,7 +27,6 @@ 'ToolDefinition', ) -from .messages import ToolReturnPart ToolParams = ParamSpec('ToolParams', default=...) """Retrieval function param spec.""" @@ -173,12 +167,6 @@ class Tool(Generic[AgentDepsT]): This schema may be modified by the `prepare` function or by the Model class prior to including it in an API request. """ - # TODO: Consider moving this current_retry state to live on something other than the tool. - # We've worked around this for now by copying instances of the tool when creating new runs, - # but this is a bit fragile. Moving the tool retry counts to live on the agent run state would likely clean things - # up, though is also likely a larger effort to refactor. - current_retry: int = field(default=0, init=False) - def __init__( self, function: ToolFuncEither[AgentDepsT], @@ -303,6 +291,15 @@ def from_schema( function_schema=function_schema, ) + @property + def tool_def(self): + return ToolDefinition( + name=self.name, + description=self.description, + parameters_json_schema=self.function_schema.json_schema, + strict=self.strict, + ) + async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition | None: """Get the tool definition. 
@@ -312,113 +309,11 @@ async def prepare_tool_def(self, ctx: RunContext[AgentDepsT]) -> ToolDefinition Returns: return a `ToolDefinition` or `None` if the tools should not be registered for this run. """ - tool_def = ToolDefinition( - name=self.name, - description=self.description, - parameters_json_schema=self.function_schema.json_schema, - strict=self.strict, - ) + base_tool_def = self.tool_def if self.prepare is not None: - return await self.prepare(ctx, tool_def) + return await self.prepare(ctx, base_tool_def) else: - return tool_def - - async def run( - self, - message: _messages.ToolCallPart, - run_context: RunContext[AgentDepsT], - tracer: Tracer, - include_content: bool = False, - ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: - """Run the tool function asynchronously. - - This method wraps `_run` in an OpenTelemetry span. - - See . - """ - span_attributes = { - 'gen_ai.tool.name': self.name, - # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai - 'gen_ai.tool.call.id': message.tool_call_id, - **({'tool_arguments': message.args_as_json_str()} if include_content else {}), - 'logfire.msg': f'running tool: {self.name}', - # add the JSON schema so these attributes are formatted nicely in Logfire - 'logfire.json_schema': json.dumps( - { - 'type': 'object', - 'properties': { - **( - { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, - } - if include_content - else {} - ), - 'gen_ai.tool.name': {}, - 'gen_ai.tool.call.id': {}, - }, - } - ), - } - with tracer.start_as_current_span('running tool', attributes=span_attributes) as span: - response = await self._run(message, run_context) - if include_content and span.is_recording(): - span.set_attribute( - 'tool_response', - response.model_response_str() - if isinstance(response, ToolReturnPart) - else response.model_response(), - ) - - return response - - async def _run( - self, message: _messages.ToolCallPart, run_context: RunContext[AgentDepsT] - ) -> _messages.ToolReturnPart | _messages.RetryPromptPart: - try: - validator = self.function_schema.validator - if isinstance(message.args, str): - args_dict = validator.validate_json(message.args or '{}') - else: - args_dict = validator.validate_python(message.args or {}) - except ValidationError as e: - return self._on_error(e, message) - - ctx = dataclasses.replace( - run_context, - retry=self.current_retry, - tool_name=message.tool_name, - tool_call_id=message.tool_call_id, - ) - try: - response_content = await self.function_schema.call(args_dict, ctx) - except ModelRetry as e: - return self._on_error(e, message) - - self.current_retry = 0 - return _messages.ToolReturnPart( - tool_name=message.tool_name, - content=response_content, - tool_call_id=message.tool_call_id, - ) - - def _on_error( - self, exc: ValidationError | ModelRetry, call_message: _messages.ToolCallPart - ) -> _messages.RetryPromptPart: - self.current_retry += 1 - if self.max_retries is None or self.current_retry > self.max_retries: - raise UnexpectedModelBehavior(f'Tool exceeded max retries count of {self.max_retries}') from exc - else: - if isinstance(exc, ValidationError): - content = exc.errors(include_url=False, include_context=False) - else: - content = exc.message - return _messages.RetryPromptPart( - tool_name=call_message.tool_name, - content=content, - tool_call_id=call_message.tool_call_id, - ) + return base_tool_def ObjectJsonSchema: TypeAlias = dict[str, Any] @@ -429,6 +324,9 @@ def _on_error( With PEP-728 this should be a 
TypedDict with `type: Literal['object']`, and `extra_parts=Any` """ +ToolKind: TypeAlias = Literal['function', 'output', 'deferred'] +"""Kind of tool.""" + @dataclass(repr=False) class ToolDefinition: @@ -440,7 +338,7 @@ class ToolDefinition: name: str """The name of the tool.""" - parameters_json_schema: ObjectJsonSchema + parameters_json_schema: ObjectJsonSchema = field(default_factory=lambda: {'type': 'object', 'properties': {}}) """The JSON schema for the tool's parameters.""" description: str | None = None @@ -464,4 +362,13 @@ class ToolDefinition: Note: this is currently only supported by OpenAI models. """ + kind: ToolKind = field(default='function') + """The kind of tool: + + - `'function'`: a tool that can be executed by Pydantic AI and has its result returned to the model + - `'output'`: a tool that passes through an output value that ends the run + - `'deferred'`: a tool that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via e.g. [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools). + When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call. + """ + __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py new file mode 100644 index 0000000000..f3d3d362dc --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/__init__.py @@ -0,0 +1,22 @@ +from .abstract import AbstractToolset, ToolsetTool +from .combined import CombinedToolset +from .deferred import DeferredToolset +from .filtered import FilteredToolset +from .function import FunctionToolset +from .prefixed import PrefixedToolset +from .prepared import PreparedToolset +from .renamed import RenamedToolset +from .wrapper import WrapperToolset + +__all__ = ( + 'AbstractToolset', + 'ToolsetTool', + 'CombinedToolset', + 'DeferredToolset', + 'FilteredToolset', + 'FunctionToolset', + 'PrefixedToolset', + 'RenamedToolset', + 'PreparedToolset', + 'WrapperToolset', +) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py new file mode 100644 index 0000000000..0f19eec3bc --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py @@ -0,0 +1,165 @@ +from __future__ import annotations + +from abc import ABC, abstractmethod +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, TypeVar + +from pydantic_core import SchemaValidator +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition, ToolsPrepareFunc + +if TYPE_CHECKING: + from .filtered import FilteredToolset + from .prefixed import PrefixedToolset + from .prepared import PreparedToolset + from .renamed import RenamedToolset + from .wrapper import WrapperToolset + +WrapperT = TypeVar('WrapperT', bound='WrapperToolset[Any]') + + +class SchemaValidatorProt(Protocol): + """Protocol for a Pydantic Core `SchemaValidator` or `PluggableSchemaValidator` (which is private but API-compatible).""" + + def validate_json( + self, + input: str | bytes | bytearray, + *, + allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, + **kwargs: Any, + ) -> Any: ... 
+ + def validate_python( + self, input: Any, *, allow_partial: bool | Literal['off', 'on', 'trailing-strings'] = False, **kwargs: Any + ) -> Any: ... + + +@dataclass +class ToolsetTool(Generic[AgentDepsT]): + """Definition of a tool available on a toolset. + + This is a wrapper around a plain tool definition that includes information about: + + - the toolset that provided it, for use in error messages + - the maximum number of retries to attempt if the tool call fails + - the validator for the tool's arguments + """ + + toolset: AbstractToolset[AgentDepsT] + """The toolset that provided this tool, for use in error messages.""" + tool_def: ToolDefinition + """The tool definition for this tool, including the name, description, and parameters.""" + max_retries: int + """The maximum number of retries to attempt if the tool call fails.""" + args_validator: SchemaValidator | SchemaValidatorProt + """The Pydantic Core validator for the tool's arguments. + + For example, a [`pydantic.TypeAdapter(...).validator`](https://docs.pydantic.dev/latest/concepts/type_adapter/) or [`pydantic_core.SchemaValidator`](https://docs.pydantic.dev/latest/api/pydantic_core/#pydantic_core.SchemaValidator). + """ + + +class AbstractToolset(ABC, Generic[AgentDepsT]): + """A toolset is a collection of tools that can be used by an agent. + + It is responsible for: + + - Listing the tools it contains + - Validating the arguments of the tools + - Calling the tools + + See [toolset docs](../toolsets.md) for more information. + """ + + @property + def name(self) -> str: + """The name of the toolset for use in error messages.""" + return self.__class__.__name__.replace('Toolset', ' toolset') + + @property + def tool_name_conflict_hint(self) -> str: + """A hint for how to avoid name conflicts with other toolsets for use in error messages.""" + return 'Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.' + + async def __aenter__(self) -> Self: + """Enter the toolset context. + + This is where you can set up network connections in a concrete implementation. + """ + return self + + async def __aexit__(self, *args: Any) -> bool | None: + """Exit the toolset context. + + This is where you can tear down network connections in a concrete implementation. + """ + return None + + @abstractmethod + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + """The tools that are available in this toolset.""" + raise NotImplementedError() + + @abstractmethod + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + """Call a tool with the given arguments. + + Args: + name: The name of the tool to call. + tool_args: The arguments to pass to the tool. + ctx: The run context. + tool: The tool definition returned by [`get_tools`][pydantic_ai.toolsets.AbstractToolset.get_tools] that was called. + """ + raise NotImplementedError() + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + """Run a visitor function on all concrete toolsets that are not wrappers (i.e. they implement their own tool listing and calling).""" + return visitor(self) + + def filtered( + self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] + ) -> FilteredToolset[AgentDepsT]: + """Returns a new toolset that filters this toolset's tools using a filter function that takes the agent context and the tool definition. 
+ + See [toolset docs](../toolsets.md#filtering-tools) for more information. + """ + from .filtered import FilteredToolset + + return FilteredToolset(self, filter_func) + + def prefixed(self, prefix: str) -> PrefixedToolset[AgentDepsT]: + """Returns a new toolset that prefixes the names of this toolset's tools. + + See [toolset docs](../toolsets.md#prefixing-tool-names) for more information. + """ + from .prefixed import PrefixedToolset + + return PrefixedToolset(self, prefix) + + def prepared(self, prepare_func: ToolsPrepareFunc[AgentDepsT]) -> PreparedToolset[AgentDepsT]: + """Returns a new toolset that prepares this toolset's tools using a prepare function that takes the agent context and the original tool definitions. + + See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information. + """ + from .prepared import PreparedToolset + + return PreparedToolset(self, prepare_func) + + def renamed(self, name_map: dict[str, str]) -> RenamedToolset[AgentDepsT]: + """Returns a new toolset that renames this toolset's tools using a dictionary mapping new names to original names. + + See [toolset docs](../toolsets.md#renaming-tools) for more information. + """ + from .renamed import RenamedToolset + + return RenamedToolset(self, name_map) + + def wrap(self, wrapper_cls: type[WrapperT], *args: Any, **kwargs: Any) -> WrapperT: + """Returns an instance of the provided wrapper class wrapping this toolset, with all arguments passed to the wrapper class constructor. + + See [toolset docs](../toolsets.md#wrapping-a-toolset) for more information. + """ + return wrapper_cls(self, *args, **kwargs) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py new file mode 100644 index 0000000000..a083477196 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -0,0 +1,88 @@ +from __future__ import annotations + +import asyncio +from collections.abc import Sequence +from contextlib import AsyncExitStack +from dataclasses import dataclass, field +from typing import Any, Callable + +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from .._utils import get_async_lock +from ..exceptions import UserError +from .abstract import AbstractToolset, ToolsetTool + + +@dataclass +class _CombinedToolsetTool(ToolsetTool[AgentDepsT]): + """A tool definition for a combined toolset tools that keeps track of the source toolset and tool.""" + + source_toolset: AbstractToolset[AgentDepsT] + source_tool: ToolsetTool[AgentDepsT] + + +@dataclass +class CombinedToolset(AbstractToolset[AgentDepsT]): + """A toolset that combines multiple toolsets. + + See [toolset docs](../toolsets.md#combining-toolsets) for more information. 
+ """ + + toolsets: Sequence[AbstractToolset[AgentDepsT]] + + _enter_lock: asyncio.Lock = field(compare=False, init=False) + _entered_count: int = field(init=False) + _exit_stack: AsyncExitStack | None = field(init=False) + + def __post_init__(self): + self._enter_lock = get_async_lock() + self._entered_count = 0 + self._exit_stack = None + + async def __aenter__(self) -> Self: + async with self._enter_lock: + if self._entered_count == 0: + self._exit_stack = AsyncExitStack() + for toolset in self.toolsets: + await self._exit_stack.enter_async_context(toolset) + self._entered_count += 1 + return self + + async def __aexit__(self, *args: Any) -> bool | None: + async with self._enter_lock: + self._entered_count -= 1 + if self._entered_count == 0 and self._exit_stack is not None: + await self._exit_stack.aclose() + self._exit_stack = None + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + toolsets_tools = await asyncio.gather(*(toolset.get_tools(ctx) for toolset in self.toolsets)) + all_tools: dict[str, ToolsetTool[AgentDepsT]] = {} + + for toolset, tools in zip(self.toolsets, toolsets_tools): + for name, tool in tools.items(): + if existing_tools := all_tools.get(name): + raise UserError( + f'{toolset.name} defines a tool whose name conflicts with existing tool from {existing_tools.toolset.name}: {name!r}. {toolset.tool_name_conflict_hint}' + ) + + all_tools[name] = _CombinedToolsetTool( + toolset=tool.toolset, + tool_def=tool.tool_def, + max_retries=tool.max_retries, + args_validator=tool.args_validator, + source_toolset=toolset, + source_tool=tool, + ) + return all_tools + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + assert isinstance(tool, _CombinedToolsetTool) + return await tool.source_toolset.call_tool(name, tool_args, ctx, tool.source_tool) + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + for toolset in self.toolsets: + toolset.apply(visitor) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py new file mode 100644 index 0000000000..29964e9333 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -0,0 +1,38 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Any + +from pydantic_core import SchemaValidator, core_schema + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from .abstract import AbstractToolset, ToolsetTool + +TOOL_SCHEMA_VALIDATOR = SchemaValidator(schema=core_schema.any_schema()) + + +@dataclass +class DeferredToolset(AbstractToolset[AgentDepsT]): + """A toolset that holds deferred tools that will be called by the upstream service that called the agent. + + See [toolset docs](../toolsets.md#deferred-toolset), [`ToolDefinition.kind`][pydantic_ai.tools.ToolDefinition.kind], and [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] for more information. 
+ """ + + tool_defs: list[ToolDefinition] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return { + tool_def.name: ToolsetTool( + toolset=self, + tool_def=replace(tool_def, kind='deferred'), + max_retries=0, + args_validator=TOOL_SCHEMA_VALIDATOR, + ) + for tool_def in self.tool_defs + } + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + raise NotImplementedError('Deferred tools cannot be called') diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py new file mode 100644 index 0000000000..3ff98c8ec5 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/filtered.py @@ -0,0 +1,24 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Callable + +from .._run_context import AgentDepsT, RunContext +from ..tools import ToolDefinition +from .abstract import ToolsetTool +from .wrapper import WrapperToolset + + +@dataclass +class FilteredToolset(WrapperToolset[AgentDepsT]): + """A toolset that filters the tools it contains using a filter function that takes the agent context and the tool definition. + + See [toolset docs](../toolsets.md#filtering-tools) for more information. + """ + + filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return { + name: tool for name, tool in (await super().get_tools(ctx)).items() if self.filter_func(ctx, tool.tool_def) + } diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/function.py b/pydantic_ai_slim/pydantic_ai/toolsets/function.py new file mode 100644 index 0000000000..63f44a1f0c --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/function.py @@ -0,0 +1,238 @@ +from __future__ import annotations + +from collections.abc import Awaitable, Sequence +from dataclasses import dataclass, field, replace +from typing import Any, Callable, overload + +from pydantic.json_schema import GenerateJsonSchema + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ( + DocstringFormat, + GenerateToolJsonSchema, + Tool, + ToolFuncEither, + ToolParams, + ToolPrepareFunc, +) +from .abstract import AbstractToolset, ToolsetTool + + +@dataclass +class _FunctionToolsetTool(ToolsetTool[AgentDepsT]): + """A tool definition for a function toolset tool that keeps track of the function to call.""" + + call_func: Callable[[dict[str, Any], RunContext[AgentDepsT]], Awaitable[Any]] + + +@dataclass(init=False) +class FunctionToolset(AbstractToolset[AgentDepsT]): + """A toolset that lets Python functions be used as tools. + + See [toolset docs](../toolsets.md#function-toolset) for more information. + """ + + max_retries: int = field(default=1) + tools: dict[str, Tool[Any]] = field(default_factory=dict) + + def __init__(self, tools: Sequence[Tool[AgentDepsT] | ToolFuncEither[AgentDepsT, ...]] = [], max_retries: int = 1): + """Build a new function toolset. + + Args: + tools: The tools to add to the toolset. + max_retries: The maximum number of retries for each tool during a run. + """ + self.max_retries = max_retries + self.tools = {} + for tool in tools: + if isinstance(tool, Tool): + self.add_tool(tool) + else: + self.add_function(tool) + + @overload + def tool(self, func: ToolFuncEither[AgentDepsT, ToolParams], /) -> ToolFuncEither[AgentDepsT, ToolParams]: ... 
+ + @overload + def tool( + self, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Callable[[ToolFuncEither[AgentDepsT, ToolParams]], ToolFuncEither[AgentDepsT, ToolParams]]: ... + + def tool( + self, + func: ToolFuncEither[AgentDepsT, ToolParams] | None = None, + /, + *, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> Any: + """Decorator to register a tool function which takes [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. + + Can decorate a sync or async functions. + + The docstring is inspected to extract both the tool description and description of each parameter, + [learn more](../tools.md#function-tools-and-schema). + + We can't add overloads for every possible signature of tool, since the return type is a recursive union + so the signature of functions decorated with `@toolset.tool` is obscured. + + Example: + ```python + from pydantic_ai import Agent, RunContext + from pydantic_ai.toolsets.function import FunctionToolset + + toolset = FunctionToolset() + + @toolset.tool + def foobar(ctx: RunContext[int], x: int) -> int: + return ctx.deps + x + + @toolset.tool(retries=2) + async def spam(ctx: RunContext[str], y: float) -> float: + return ctx.deps + y + + agent = Agent('test', toolsets=[toolset], deps_type=int) + result = agent.run_sync('foobar', deps=1) + print(result.output) + #> {"foobar":1,"spam":1.0} + ``` + + Args: + func: The tool function to register. + name: The name of the tool, defaults to the function name. + retries: The number of retries to allow for this tool, defaults to the agent's default retries, + which defaults to 1. + prepare: custom method to prepare the tool definition for each step, return `None` to omit this + tool from a given step. This is useful if you want to customise a tool at call time, + or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. + docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. + Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. + require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. + schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. + strict: Whether to enforce JSON schema compliance (only affects OpenAI). + See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. 
+ """ + + def tool_decorator( + func_: ToolFuncEither[AgentDepsT, ToolParams], + ) -> ToolFuncEither[AgentDepsT, ToolParams]: + # noinspection PyTypeChecker + self.add_function( + func_, + None, + name, + retries, + prepare, + docstring_format, + require_parameter_descriptions, + schema_generator, + strict, + ) + return func_ + + return tool_decorator if func is None else tool_decorator(func) + + def add_function( + self, + func: ToolFuncEither[AgentDepsT, ToolParams], + takes_ctx: bool | None = None, + name: str | None = None, + retries: int | None = None, + prepare: ToolPrepareFunc[AgentDepsT] | None = None, + docstring_format: DocstringFormat = 'auto', + require_parameter_descriptions: bool = False, + schema_generator: type[GenerateJsonSchema] = GenerateToolJsonSchema, + strict: bool | None = None, + ) -> None: + """Add a function as a tool to the toolset. + + Can take a sync or async function. + + The docstring is inspected to extract both the tool description and description of each parameter, + [learn more](../tools.md#function-tools-and-schema). + + Args: + func: The tool function to register. + takes_ctx: Whether the function takes a [`RunContext`][pydantic_ai.tools.RunContext] as its first argument. If `None`, this is inferred from the function signature. + name: The name of the tool, defaults to the function name. + retries: The number of retries to allow for this tool, defaults to the agent's default retries, + which defaults to 1. + prepare: custom method to prepare the tool definition for each step, return `None` to omit this + tool from a given step. This is useful if you want to customise a tool at call time, + or omit it completely from a step. See [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc]. + docstring_format: The format of the docstring, see [`DocstringFormat`][pydantic_ai.tools.DocstringFormat]. + Defaults to `'auto'`, such that the format is inferred from the structure of the docstring. + require_parameter_descriptions: If True, raise an error if a parameter description is missing. Defaults to False. + schema_generator: The JSON schema generator class to use for this tool. Defaults to `GenerateToolJsonSchema`. + strict: Whether to enforce JSON schema compliance (only affects OpenAI). + See [`ToolDefinition`][pydantic_ai.tools.ToolDefinition] for more info. + """ + tool = Tool[AgentDepsT]( + func, + takes_ctx=takes_ctx, + name=name, + max_retries=retries, + prepare=prepare, + docstring_format=docstring_format, + require_parameter_descriptions=require_parameter_descriptions, + schema_generator=schema_generator, + strict=strict, + ) + self.add_tool(tool) + + def add_tool(self, tool: Tool[AgentDepsT]) -> None: + """Add a tool to the toolset. + + Args: + tool: The tool to add. 
+ """ + if tool.name in self.tools: + raise UserError(f'Tool name conflicts with existing tool: {tool.name!r}') + if tool.max_retries is None: + tool.max_retries = self.max_retries + self.tools[tool.name] = tool + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + tools: dict[str, ToolsetTool[AgentDepsT]] = {} + for original_name, tool in self.tools.items(): + run_context = replace(ctx, tool_name=original_name, retry=ctx.retries.get(original_name, 0)) + tool_def = await tool.prepare_tool_def(run_context) + if not tool_def: + continue + + new_name = tool_def.name + if new_name in tools: + if new_name != original_name: + raise UserError(f'Renaming tool {original_name!r} to {new_name!r} conflicts with existing tool.') + else: + raise UserError(f'Tool name conflicts with previously renamed tool: {new_name!r}.') + + tools[new_name] = _FunctionToolsetTool( + toolset=self, + tool_def=tool_def, + max_retries=tool.max_retries if tool.max_retries is not None else self.max_retries, + args_validator=tool.function_schema.validator, + call_func=tool.function_schema.call, + ) + return tools + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + assert isinstance(tool, _FunctionToolsetTool) + return await tool.call_func(tool_args, ctx) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py new file mode 100644 index 0000000000..be70ed4f0f --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prefixed.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Any + +from .._run_context import AgentDepsT, RunContext +from .abstract import ToolsetTool +from .wrapper import WrapperToolset + + +@dataclass +class PrefixedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prefixes the names of the tools it contains. + + See [toolset docs](../toolsets.md#prefixing-tool-names) for more information. + """ + + prefix: str + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return { + new_name: replace( + tool, + toolset=self, + tool_def=replace(tool.tool_def, name=new_name), + ) + for name, tool in (await super().get_tools(ctx)).items() + if (new_name := f'{self.prefix}_{name}') + } + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + original_name = name.removeprefix(self.prefix + '_') + ctx = replace(ctx, tool_name=original_name) + tool = replace(tool, tool_def=replace(tool.tool_def, name=original_name)) + return await super().call_tool(original_name, tool_args, ctx, tool) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py new file mode 100644 index 0000000000..af604d4328 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/prepared.py @@ -0,0 +1,36 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace + +from .._run_context import AgentDepsT, RunContext +from ..exceptions import UserError +from ..tools import ToolsPrepareFunc +from .abstract import ToolsetTool +from .wrapper import WrapperToolset + + +@dataclass +class PreparedToolset(WrapperToolset[AgentDepsT]): + """A toolset that prepares the tools it contains using a prepare function that takes the agent context and the original tool definitions. 
+ + See [toolset docs](../toolsets.md#preparing-tool-definitions) for more information. + """ + + prepare_func: ToolsPrepareFunc[AgentDepsT] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + original_tools = await super().get_tools(ctx) + original_tool_defs = [tool.tool_def for tool in original_tools.values()] + prepared_tool_defs_by_name = { + tool_def.name: tool_def for tool_def in (await self.prepare_func(ctx, original_tool_defs) or []) + } + + if len(prepared_tool_defs_by_name.keys() - original_tools.keys()) > 0: + raise UserError( + 'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.' + ) + + return { + name: replace(original_tools[name], tool_def=tool_def) + for name, tool_def in prepared_tool_defs_by_name.items() + } diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/renamed.py b/pydantic_ai_slim/pydantic_ai/toolsets/renamed.py new file mode 100644 index 0000000000..c0d8aff7a0 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/renamed.py @@ -0,0 +1,42 @@ +from __future__ import annotations + +from dataclasses import dataclass, replace +from typing import Any + +from .._run_context import AgentDepsT, RunContext +from .abstract import ToolsetTool +from .wrapper import WrapperToolset + + +@dataclass +class RenamedToolset(WrapperToolset[AgentDepsT]): + """A toolset that renames the tools it contains using a dictionary mapping new names to original names. + + See [toolset docs](../toolsets.md#renaming-tools) for more information. + """ + + name_map: dict[str, str] + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + original_to_new_name_map = {v: k for k, v in self.name_map.items()} + original_tools = await super().get_tools(ctx) + tools: dict[str, ToolsetTool[AgentDepsT]] = {} + for original_name, tool in original_tools.items(): + new_name = original_to_new_name_map.get(original_name, None) + if new_name: + tools[new_name] = replace( + tool, + toolset=self, + tool_def=replace(tool.tool_def, name=new_name), + ) + else: + tools[original_name] = tool + return tools + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + original_name = self.name_map.get(name, name) + ctx = replace(ctx, tool_name=original_name) + tool = replace(tool, tool_def=replace(tool.tool_def, name=original_name)) + return await super().call_tool(original_name, tool_args, ctx, tool) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py new file mode 100644 index 0000000000..1dddd96a51 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -0,0 +1,37 @@ +from __future__ import annotations + +from dataclasses import dataclass +from typing import Any, Callable + +from typing_extensions import Self + +from .._run_context import AgentDepsT, RunContext +from .abstract import AbstractToolset, ToolsetTool + + +@dataclass +class WrapperToolset(AbstractToolset[AgentDepsT]): + """A toolset that wraps another toolset and delegates to it. + + See [toolset docs](../toolsets.md#wrapping-a-toolset) for more information. 
+ """ + + wrapped: AbstractToolset[AgentDepsT] + + async def __aenter__(self) -> Self: + await self.wrapped.__aenter__() + return self + + async def __aexit__(self, *args: Any) -> bool | None: + return await self.wrapped.__aexit__(*args) + + async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[AgentDepsT]]: + return await self.wrapped.get_tools(ctx) + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[AgentDepsT], tool: ToolsetTool[AgentDepsT] + ) -> Any: + return await self.wrapped.call_tool(name, tool_args, ctx, tool) + + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + return self.wrapped.apply(visitor) diff --git a/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml b/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml new file mode 100644 index 0000000000..e33e36f96e --- /dev/null +++ b/tests/cassettes/test_mcp/test_agent_with_server_not_running.yaml @@ -0,0 +1,391 @@ +interactions: +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '2501' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is 0 degrees Celsius in Fahrenheit? + role: user + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + properties: {} + type: object + type: function + - function: + description: Use sampling callback. 
+ name: use_sampling + parameters: + properties: + foo: + type: string + required: + - foo + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1086' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '420' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{"celsius":0}' + name: celsius_to_fahrenheit + id: call_hS0oexgCNI6TneJuPPuwn9jQ + type: function + created: 1751491994 + id: chatcmpl-BozMoBhgfC5D8QBjkiOwz5OxxrwQK + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 18 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 268 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 286 + status: + code: 200 + message: OK +- request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '2748' + content-type: + - application/json + cookie: + - __cf_bm=JOV7WG2Y48FZrZxdh0IZvA9mCj_ljIN3DhGMuC1pw6M-1751491995-1.0.1.1-zGPrLbzYx7y3iZT28xogbHO1KAIej60kPEwQ8ZxGMxv1r.ICtqI0T8WCnlyUccKfLSXB6ZTNQT05xCma8LSvq2pk4X2eEuSkYC1sPqbuLU8; + _cfuvid=LdoyX0uKYwM98NSSSvySlZAiJHCVHz_1krUGKbWmNHg-1751491995391-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: What is 0 degrees Celsius in Fahrenheit? 
+ role: user + - role: assistant + tool_calls: + - function: + arguments: '{"celsius":0}' + name: celsius_to_fahrenheit + id: call_hS0oexgCNI6TneJuPPuwn9jQ + type: function + - content: '32.0' + role: tool + tool_call_id: call_hS0oexgCNI6TneJuPPuwn9jQ + model: gpt-4o + stream: false + tool_choice: auto + tools: + - function: + description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_audio_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object + type: function + - function: + description: "Echo the run context.\n\n Args:\n ctx: Context object containing request and session information.\n\n + \ Returns:\n Dictionary with an echo message and the deps.\n " + name: echo_deps + parameters: + properties: {} + type: object + type: function + - function: + description: Use sampling callback. + name: use_sampling + parameters: + properties: + foo: + type: string + required: + - foo + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '849' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '520' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: 0 degrees Celsius is 32.0 degrees Fahrenheit. 
+ refusal: null + role: assistant + created: 1751491998 + id: chatcmpl-BozMsevK8quJblNOyNCaDQpdtDwI5 + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_a288987b44 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 300 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 312 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/ext/test_langchain.py b/tests/ext/test_langchain.py index 73e7cc0504..926a228194 100644 --- a/tests/ext/test_langchain.py +++ b/tests/ext/test_langchain.py @@ -6,7 +6,7 @@ from pydantic.json_schema import JsonSchemaValue from pydantic_ai import Agent -from pydantic_ai.ext.langchain import tool_from_langchain +from pydantic_ai.ext.langchain import LangChainToolset, tool_from_langchain @dataclass @@ -49,24 +49,26 @@ def get_input_jsonschema(self) -> JsonSchemaValue: } -def test_langchain_tool_conversion(): - langchain_tool = SimulatedLangChainTool( - name='file_search', - description='Recursively search for files in a subdirectory that match the regex pattern', - args={ - 'dir_path': { - 'default': '.', - 'description': 'Subdirectory to search in.', - 'title': 'Dir Path', - 'type': 'string', - }, - 'pattern': { - 'description': 'Unix shell regex, where * matches everything.', - 'title': 'Pattern', - 'type': 'string', - }, +langchain_tool = SimulatedLangChainTool( + name='file_search', + description='Recursively search for files in a subdirectory that match the regex pattern', + args={ + 'dir_path': { + 'default': '.', + 'description': 'Subdirectory to search in.', + 'title': 'Dir Path', + 'type': 'string', }, - ) + 'pattern': { + 'description': 'Unix shell regex, where * matches everything.', + 'title': 'Pattern', + 'type': 'string', + }, + }, +) + + +def test_langchain_tool_conversion(): pydantic_tool = tool_from_langchain(langchain_tool) agent = Agent('test', tools=[pydantic_tool], retries=7) @@ -74,6 +76,13 @@ def test_langchain_tool_conversion(): assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") +def test_langchain_toolset(): + toolset = LangChainToolset([langchain_tool]) + agent = Agent('test', toolsets=[toolset], retries=7) + result = agent.run_sync('foobar') + assert result.output == snapshot("{\"file_search\":\"I was called with {'dir_path': '.', 'pattern': 'a'}\"}") + + def test_langchain_tool_no_additional_properties(): langchain_tool = SimulatedLangChainTool( name='file_search', diff --git a/tests/models/test_anthropic.py b/tests/models/test_anthropic.py index 3891c5108c..77857e8821 100644 --- a/tests/models/test_anthropic.py +++ b/tests/models/test_anthropic.py @@ -1700,7 +1700,7 @@ class CityLocation(BaseModel): agent = Agent(m, output_type=NativeOutput(CityLocation)) - with pytest.raises(UserError, match='Structured output is not supported by the model.'): + with pytest.raises(UserError, match='Native structured output is not supported by the model.'): await agent.run('What is the largest city in the user country?') diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index a84c49d869..0b95bedeb5 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -974,12 +974,47 @@ async def test_stream_text_heterogeneous(get_gemini_client: GetGeminiClient): @agent.tool_plain() def get_location(loc_name: str) -> str: - return f'Location for {loc_name}' 
+ return f'Location for {loc_name}' # pragma: no cover async with agent.run_stream('Hello') as result: data = await result.get_output() assert data == 'Hello foo' + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Hello', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + TextPart(content='Hello foo'), + ToolCallPart( + tool_name='get_location', + args={'loc_name': 'San Fransisco'}, + tool_call_id=IsStr(), + ), + ], + usage=Usage(request_tokens=1, response_tokens=2, total_tokens=3, details={}), + model_name='gemini-1.5-flash', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_location', + content='Tool not executed - a final result was already processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) async def test_empty_text_ignored(): diff --git a/tests/models/test_model_test.py b/tests/models/test_model_test.py index 31635c080d..02aafd259f 100644 --- a/tests/models/test_model_test.py +++ b/tests/models/test_model_test.py @@ -4,6 +4,7 @@ import asyncio import dataclasses +import re from datetime import timezone from typing import Annotated, Any, Literal @@ -157,7 +158,7 @@ def validate_output(ctx: RunContext[None], output: OutputModel) -> OutputModel: call_count += 1 raise ModelRetry('Fail') - with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + with pytest.raises(UnexpectedModelBehavior, match=re.escape('Exceeded maximum retries (2) for output validation')): agent.run_sync('Hello', model=TestModel()) assert call_count == 3 @@ -200,7 +201,7 @@ class ResultModel(BaseModel): agent = Agent('test', output_type=ResultModel, retries=2) - with pytest.raises(UnexpectedModelBehavior, match='Exceeded maximum retries'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(2\) for output validation'): agent.run_sync('Hello', model=TestModel(custom_output_args={'foo': 'a', 'bar': 1})) diff --git a/tests/test_a2a.py b/tests/test_a2a.py index c3d56728f4..535e3b1e91 100644 --- a/tests/test_a2a.py +++ b/tests/test_a2a.py @@ -800,8 +800,15 @@ async def test_a2a_multiple_messages(): } ) - await anyio.sleep(0.1) - task = await a2a_client.get_task(task_id) + task = None + tries = 0 + while tries < 10: # pragma: no branch + await anyio.sleep(0.1) + task = await a2a_client.get_task(task_id) + tries += 1 + if 'result' in task and task['result']['status']['state'] == 'completed': # pragma: no branch + break + assert task == snapshot( { 'jsonrpc': '2.0', diff --git a/tests/test_agent.py b/tests/test_agent.py index 47a9c01a12..c1893fdcb4 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -1,6 +1,7 @@ import json import re import sys +from dataclasses import dataclass from datetime import timezone from typing import Any, Callable, Union @@ -45,6 +46,9 @@ from pydantic_ai.profiles import ModelProfile from pydantic_ai.result import Usage from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.combined import CombinedToolset +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.prefixed import PrefixedToolset from .conftest import IsDatetime, IsNow, IsStr, TestEnv @@ -396,6 +400,7 @@ def test_response_tuple(): 'type': 'object', }, outer_typed_dict_key='response', + kind='output', ) ] ) @@ -469,6 +474,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Foo', 'type': 'object', }, + kind='output', ) ] ) @@ -548,6 +554,7 @@ def validate_output(ctx: 
RunContext[None], o: Any) -> Any: 'title': 'Foo', 'type': 'object', }, + kind='output', ), ToolDefinition( name='final_result_Bar', @@ -558,6 +565,7 @@ def validate_output(ctx: RunContext[None], o: Any) -> Any: 'title': 'Bar', 'type': 'object', }, + kind='output', ), ] ) @@ -589,6 +597,7 @@ class MyOutput(BaseModel): 'title': 'MyOutput', 'type': 'object', }, + kind='output', ) ] ) @@ -635,6 +644,7 @@ class Bar(BaseModel): }, outer_typed_dict_key='response', strict=False, + kind='output', ) ] ) @@ -673,6 +683,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -712,6 +723,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -752,6 +764,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -793,6 +806,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -943,7 +957,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: [[str, str], [str, TextOutput(upcase)], [TextOutput(upcase), TextOutput(upcase)]], ) def test_output_type_multiple_text_output(output_type: OutputSpec[str]): - with pytest.raises(UserError, match='Only one text output is allowed.'): + with pytest.raises(UserError, match='Only one `str` or `TextOutput` is allowed.'): Agent('test', output_type=output_type) @@ -989,6 +1003,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -1027,6 +1042,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ) ] ) @@ -1065,6 +1081,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ), ToolDefinition( name='final_result_Weather', @@ -1075,6 +1092,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'title': 'Weather', 'type': 'object', }, + kind='output', ), ] ) @@ -1251,6 +1269,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'required': ['city'], 'type': 'object', }, + kind='output', ), ToolDefinition( name='return_weather', @@ -1261,6 +1280,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'title': 'Weather', 'type': 'object', }, + kind='output', ), ] ) @@ -1322,6 +1342,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'type': 'object', }, description='A person', + kind='output', ), ToolDefinition( name='final_result_Animal', @@ -1332,6 +1353,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'type': 'object', }, description='An animal', + kind='output', ), ] ) @@ -1998,7 +2020,7 @@ def empty(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: agent = Agent(FunctionModel(empty)) with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): agent.run_sync('Hello') assert messages == snapshot( [ @@ -2350,12 +2372,6 @@ def another_tool(y: int) -> int: tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), - 
RetryPromptPart( - tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", - timestamp=IsNow(tz=timezone.utc), - tool_call_id=IsStr(), - ), ToolReturnPart( tool_name='regular_tool', content=42, @@ -2365,6 +2381,12 @@ def another_tool(y: int) -> int: ToolReturnPart( tool_name='another_tool', content=2, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc) ), + RetryPromptPart( + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", + tool_name='unknown_tool', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ), ] ), ] @@ -2428,16 +2450,16 @@ def another_tool(y: int) -> int: # pragma: no cover ModelRequest( parts=[ ToolReturnPart( - tool_name='regular_tool', - content='Tool not executed - a final result was already processed.', + tool_name='final_result', + content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ), ToolReturnPart( - tool_name='final_result', - content='Final result processed.', + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), - timestamp=IsNow(tz=timezone.utc), + timestamp=IsDatetime(), ), ToolReturnPart( tool_name='another_tool', @@ -2447,7 +2469,7 @@ def another_tool(y: int) -> int: # pragma: no cover ), RetryPromptPart( tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), @@ -2494,11 +2516,13 @@ def return_model(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: # Verify we got appropriate tool returns assert result.new_messages()[-1].parts == snapshot( [ - ToolReturnPart( + RetryPromptPart( + content=[ + {'type': 'missing', 'loc': ('value',), 'msg': 'Field required', 'input': {'bad_value': 'first'}} + ], tool_name='final_result', tool_call_id='first', - content='Output tool not used - result failed validation.', - timestamp=IsNow(tz=timezone.utc), + timestamp=IsDatetime(), ), ToolReturnPart( tool_name='final_result', @@ -3247,7 +3271,7 @@ def hello(_: list[ModelMessage], _info: AgentInfo) -> ModelResponse: agent = Agent(model, output_type=NativeOutput(Foo)) - with pytest.raises(UserError, match='Structured output is not supported by the model.'): + with pytest.raises(UserError, match='Native structured output is not supported by the model.'): agent.run_sync('Hello') agent = Agent(model, output_type=ToolOutput(Foo)) @@ -3435,7 +3459,7 @@ def analyze_data() -> list[Any]: with pytest.raises( UserError, - match="analyze_data's return contains invalid nested ToolReturn objects. ToolReturn should be used directly.", + match="The return value of tool 'analyze_data' contains invalid nested `ToolReturn` objects. `ToolReturn` should be used directly.", ): agent.run_sync('Please analyze the data') @@ -3469,7 +3493,7 @@ def analyze_data() -> ToolReturn: with pytest.raises( UserError, - match="analyze_data's `return_value` contains invalid nested MultiModalContentTypes objects. Please use `content` instead.", + match="The `return_value` of tool 'analyze_data' contains invalid nested `MultiModalContentTypes` objects. 
Please use `content` instead.", ): agent.run_sync('Please analyze the data') @@ -3534,6 +3558,19 @@ def test_deprecated_kwargs_still_work(): assert issubclass(w[0].category, DeprecationWarning) assert '`result_retries` is deprecated' in str(w[0].message) + try: + from pydantic_ai.mcp import MCPServerStdio + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter('always') + + agent = Agent('test', mcp_servers=[MCPServerStdio('python', ['-m', 'tests.mcp_server'])]) # type: ignore[call-arg] + assert len(w) == 1 + assert issubclass(w[0].category, DeprecationWarning) + assert '`mcp_servers` is deprecated' in str(w[0].message) + except ImportError: + pass + def test_deprecated_kwargs_mixed_valid_invalid(): """Test that mix of valid deprecated and invalid kwargs raises error for invalid ones.""" @@ -3548,3 +3585,272 @@ def test_deprecated_kwargs_mixed_valid_invalid(): with warnings.catch_warnings(): warnings.simplefilter('ignore', DeprecationWarning) # Ignore the deprecation warning for result_tool_name Agent('test', result_tool_name='test', foo='value1', bar='value2') # type: ignore[call-arg] + + +def test_override_toolsets(): + foo_toolset = FunctionToolset() + + @foo_toolset.tool + def foo() -> str: + return 'Hello from foo' + + available_tools: list[list[str]] = [] + + async def prepare_tools(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + nonlocal available_tools + available_tools.append([tool_def.name for tool_def in tool_defs]) + return tool_defs + + agent = Agent('test', toolsets=[foo_toolset], prepare_tools=prepare_tools) + + @agent.tool_plain + def baz() -> str: + return 'Hello from baz' + + result = agent.run_sync('Hello') + assert available_tools[-1] == snapshot(['baz', 'foo']) + assert result.output == snapshot('{"baz":"Hello from baz","foo":"Hello from foo"}') + + bar_toolset = FunctionToolset() + + @bar_toolset.tool + def bar() -> str: + return 'Hello from bar' + + with agent.override(toolsets=[bar_toolset]): + result = agent.run_sync('Hello') + assert available_tools[-1] == snapshot(['baz', 'bar']) + assert result.output == snapshot('{"baz":"Hello from baz","bar":"Hello from bar"}') + + with agent.override(toolsets=[]): + result = agent.run_sync('Hello') + assert available_tools[-1] == snapshot(['baz']) + assert result.output == snapshot('{"baz":"Hello from baz"}') + + result = agent.run_sync('Hello', toolsets=[bar_toolset]) + assert available_tools[-1] == snapshot(['baz', 'foo', 'bar']) + assert result.output == snapshot('{"baz":"Hello from baz","foo":"Hello from foo","bar":"Hello from bar"}') + + with agent.override(toolsets=[]): + result = agent.run_sync('Hello', toolsets=[bar_toolset]) + assert available_tools[-1] == snapshot(['baz']) + assert result.output == snapshot('{"baz":"Hello from baz"}') + + +def test_adding_tools_during_run(): + toolset = FunctionToolset() + + def foo() -> str: + return 'Hello from foo' + + @toolset.tool + def add_foo_tool() -> str: + toolset.add_function(foo) + return 'foo tool added' + + def respond(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart('add_foo_tool')]) + elif len(messages) == 3: + return ModelResponse(parts=[ToolCallPart('foo')]) + else: + return ModelResponse(parts=[TextPart('Done')]) + + agent = Agent(FunctionModel(respond), toolsets=[toolset]) + result = agent.run_sync('Add the foo tool and run it') + assert result.output == snapshot('Done') + assert result.all_messages() == snapshot( + [ + 
ModelRequest( + parts=[ + UserPromptPart( + content='Add the foo tool and run it', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='add_foo_tool', tool_call_id=IsStr())], + usage=Usage(requests=1, request_tokens=57, response_tokens=2, total_tokens=59), + model_name='function:respond:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='add_foo_tool', + content='foo tool added', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ToolCallPart(tool_name='foo', tool_call_id=IsStr())], + usage=Usage(requests=1, request_tokens=60, response_tokens=4, total_tokens=64), + model_name='function:respond:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='foo', + content='Hello from foo', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[TextPart(content='Done')], + usage=Usage(requests=1, request_tokens=63, response_tokens=5, total_tokens=68), + model_name='function:respond:', + timestamp=IsDatetime(), + ), + ] + ) + + +def test_prepare_output_tools(): + @dataclass + class AgentDeps: + plan_presented: bool = False + + async def present_plan(ctx: RunContext[AgentDeps], plan: str) -> str: + """ + Present the plan to the user. + """ + ctx.deps.plan_presented = True + return plan + + async def run_sql(ctx: RunContext[AgentDeps], purpose: str, query: str) -> str: + """ + Run an SQL query. + """ + return 'SQL query executed successfully' + + async def only_if_plan_presented( + ctx: RunContext[AgentDeps], tool_defs: list[ToolDefinition] + ) -> list[ToolDefinition]: + return tool_defs if ctx.deps.plan_presented else [] + + agent = Agent( + model='test', + deps_type=AgentDeps, + tools=[present_plan], + output_type=[ToolOutput(run_sql, name='run_sql')], + prepare_output_tools=only_if_plan_presented, + ) + + result = agent.run_sync('Hello', deps=AgentDeps()) + assert result.output == snapshot('SQL query executed successfully') + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='Hello', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='present_plan', + args={'plan': 'a'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=51, response_tokens=5, total_tokens=56), + model_name='test', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='present_plan', + content='a', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='run_sql', + args={'purpose': 'a', 'query': 'a'}, + tool_call_id=IsStr(), + ) + ], + usage=Usage(requests=1, request_tokens=52, response_tokens=12, total_tokens=64), + model_name='test', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='run_sql', + content='Final result processed.', + tool_call_id=IsStr(), + timestamp=IsDatetime(), + ) + ] + ), + ] + ) + + +async def test_context_manager(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: # pragma: lax no cover + pytest.skip('mcp is not installed') + + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) + agent = Agent('test', toolsets=[toolset]) + + async with agent: + assert server1.is_running + assert server2.is_running + + async with 
agent: + assert server1.is_running + assert server2.is_running + + +def test_set_mcp_sampling_model(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: # pragma: lax no cover + pytest.skip('mcp is not installed') + + test_model = TestModel() + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server'], sampling_model=test_model) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) + agent = Agent(None, toolsets=[toolset]) + + with pytest.raises(UserError, match='No sampling model provided and no model set on the agent.'): + agent.set_mcp_sampling_model() + assert server1.sampling_model is None + assert server2.sampling_model is test_model + + agent.model = test_model + agent.set_mcp_sampling_model() + assert server1.sampling_model is test_model + assert server2.sampling_model is test_model + + function_model = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Hello')])) + with agent.override(model=function_model): + agent.set_mcp_sampling_model() + assert server1.sampling_model is function_model + assert server2.sampling_model is function_model + + function_model2 = FunctionModel(lambda messages, info: ModelResponse(parts=[TextPart('Goodbye')])) + agent.set_mcp_sampling_model(function_model2) + assert server1.sampling_model is function_model2 + assert server2.sampling_model is function_model2 diff --git a/tests/test_examples.py b/tests/test_examples.py index 5735a11ffa..0fbe64bc3a 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -21,6 +21,7 @@ from rich.console import Console from pydantic_ai import ModelHTTPError +from pydantic_ai._run_context import RunContext from pydantic_ai._utils import group_by_temporal from pydantic_ai.exceptions import UnexpectedModelBehavior from pydantic_ai.messages import ( @@ -36,6 +37,8 @@ from pydantic_ai.models.fallback import FallbackModel from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel +from pydantic_ai.toolsets import AbstractToolset +from pydantic_ai.toolsets.abstract import ToolsetTool from .conftest import ClientWithHandler, TestEnv, try_import @@ -259,18 +262,20 @@ def rich_prompt_ask(prompt: str, *_args: Any, **_kwargs: Any) -> str: raise ValueError(f'Unexpected prompt: {prompt}') -class MockMCPServer: - is_running = True - +class MockMCPServer(AbstractToolset[Any]): async def __aenter__(self) -> MockMCPServer: return self async def __aexit__(self, *args: Any) -> None: pass - @staticmethod - async def list_tools() -> list[None]: - return [] + async def get_tools(self, ctx: RunContext[Any]) -> dict[str, ToolsetTool[Any]]: + return {} + + async def call_tool( + self, name: str, tool_args: dict[str, Any], ctx: RunContext[Any], tool: ToolsetTool[Any] + ) -> Any: + return None # pragma: lax no cover text_responses: dict[str, str | ToolCallPart] = { @@ -553,6 +558,21 @@ async def model_logic( # noqa: C901 ) ] ) + elif m.content == 'Greet the user in a personalized way': + if any(t.name == 'get_preferred_language' for t in info.function_tools): + part = ToolCallPart( + tool_name='get_preferred_language', + args={'default_language': 'en-US'}, + tool_call_id='pyd_ai_tool_call_id', + ) + else: + part = ToolCallPart( + tool_name='final_result', + args={'greeting': 'Hello, David!', 'language_code': 'en-US'}, + tool_call_id='pyd_ai_tool_call_id', + ) + + return ModelResponse(parts=[part]) elif response := 
text_responses.get(m.content): if isinstance(response, str): return ModelResponse(parts=[TextPart(response)]) @@ -697,6 +717,16 @@ async def model_logic( # noqa: C901 ) elif isinstance(m, ToolReturnPart) and m.tool_name == 'image_generator': return ModelResponse(parts=[TextPart('Image file written to robot_punk.svg.')]) + elif isinstance(m, ToolReturnPart) and m.tool_name == 'get_preferred_language': + return ModelResponse( + parts=[ + ToolCallPart( + tool_name='final_result', + args={'greeting': 'Hola, David! Espero que tengas un gran día!', 'language_code': 'es-MX'}, + tool_call_id='pyd_ai_tool_call_id', + ) + ] + ) else: sys.stdout.write(str(debug.format(messages, info))) raise RuntimeError(f'Unexpected message: {m}') diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 97ba871cc9..799724179a 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -11,6 +11,7 @@ from pydantic_ai import Agent from pydantic_ai._utils import get_traceparent +from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior from pydantic_ai.messages import ModelMessage, ModelResponse, TextPart, ToolCallPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.instrumented import InstrumentationSettings, InstrumentedModel @@ -294,6 +295,7 @@ async def my_ret(x: int) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ], 'output_mode': 'text', @@ -641,10 +643,11 @@ async def test_feedback(capfire: CaptureLogfire) -> None: @pytest.mark.skipif(not logfire_installed, reason='logfire not installed') -@pytest.mark.parametrize('include_content', [True, False]) +@pytest.mark.parametrize('include_content,tool_error', [(True, False), (True, True), (False, False), (False, True)]) def test_include_tool_args_span_attributes( get_logfire_summary: Callable[[], LogfireSummary], include_content: bool, + tool_error: bool, ) -> None: """Test that tool arguments are included/excluded in span attributes based on instrumentation settings.""" @@ -655,61 +658,119 @@ def test_include_tool_args_span_attributes( @my_agent.tool_plain async def add_numbers(x: int, y: int) -> int: """Add two numbers together.""" + if tool_error: + raise ModelRetry('Tool error') return x + y - result = my_agent.run_sync('Add 42 and 42') - assert result.output == snapshot('{"add_numbers":84}') + try: + result = my_agent.run_sync('Add 42 and 42') + assert result.output == snapshot('{"add_numbers":84}') + except UnexpectedModelBehavior: + if not tool_error: + raise # pragma: no cover summary = get_logfire_summary() - [tool_attributes] = [ + tool_attributes = next( attributes for attributes in summary.attributes.values() if attributes.get('gen_ai.tool.name') == 'add_numbers' - ] + ) if include_content: - assert tool_attributes == snapshot( - { - 'gen_ai.tool.name': 'add_numbers', - 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"x":42,"y":42}', - 'tool_response': '84', - 'logfire.msg': 'running tool: add_numbers', - 'logfire.json_schema': IsJson( - snapshot( - { - 'type': 'object', - 'properties': { - 'tool_arguments': {'type': 'object'}, - 'tool_response': {'type': 'object'}, - 'gen_ai.tool.name': {}, - 'gen_ai.tool.call.id': {}, - }, - } - ) - ), - 'logfire.span_type': 'span', - } - ) + if tool_error: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"x":42,"y":42}', + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 
'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'tool_response': """\ +Tool error + +Fix the errors and try again.\ +""", + 'logfire.level_num': 17, + } + ) + else: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'tool_arguments': '{"x":42,"y":42}', + 'tool_response': '84', + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'tool_arguments': {'type': 'object'}, + 'tool_response': {'type': 'object'}, + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + } + ) else: - assert tool_attributes == snapshot( - { - 'gen_ai.tool.name': 'add_numbers', - 'gen_ai.tool.call.id': IsStr(), - 'logfire.msg': 'running tool: add_numbers', - 'logfire.json_schema': IsJson( - snapshot( - { - 'type': 'object', - 'properties': { - 'gen_ai.tool.name': {}, - 'gen_ai.tool.call.id': {}, - }, - } - ) - ), - 'logfire.span_type': 'span', - } - ) + if tool_error: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + 'logfire.level_num': 17, + } + ) + else: + assert tool_attributes == snapshot( + { + 'gen_ai.tool.name': 'add_numbers', + 'gen_ai.tool.call.id': IsStr(), + 'logfire.msg': 'running tool: add_numbers', + 'logfire.json_schema': IsJson( + snapshot( + { + 'type': 'object', + 'properties': { + 'gen_ai.tool.name': {}, + 'gen_ai.tool.call.id': {}, + }, + } + ) + ), + 'logfire.span_type': 'span', + } + ) class WeatherInfo(BaseModel): @@ -750,7 +811,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -811,7 +872,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -881,7 +942,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'logfire.msg': 'running output function: final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "New York City"}', + 'tool_arguments': '{"city":"New York City"}', 'logfire.json_schema': IsJson( snapshot( { @@ -900,7 +961,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: 'gen_ai.tool.name': 'final_result', 'logfire.msg': 'running output function: final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.json_schema': IsJson( snapshot( { @@ -968,7 +1029,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'get_weather', 
'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: get_weather', 'logfire.json_schema': IsJson( snapshot( @@ -1034,7 +1095,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -1101,7 +1162,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -1163,7 +1224,7 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: { 'gen_ai.tool.name': 'final_result', 'gen_ai.tool.call.id': IsStr(), - 'tool_arguments': '{"city": "Mexico City"}', + 'tool_arguments': '{"city":"Mexico City"}', 'logfire.msg': 'running output function: final_result', 'logfire.json_schema': IsJson( snapshot( @@ -1299,7 +1360,7 @@ def call_tool(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: assert output_function_attributes == snapshot( { 'gen_ai.tool.name': 'upcase_text', - 'tool_arguments': '{"text": "hello world"}', + 'tool_arguments': '{"text":"hello world"}', 'logfire.msg': 'running output function: upcase_text', 'logfire.json_schema': IsJson( snapshot( diff --git a/tests/test_mcp.py b/tests/test_mcp.py index fe092d9dd7..e2c5bd0989 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -1,5 +1,7 @@ """Tests for the MCP (Model Context Protocol) server implementation.""" +from __future__ import annotations + import base64 import re from datetime import timezone @@ -23,6 +25,7 @@ ToolReturnPart, UserPromptPart, ) +from pydantic_ai.models import Model from pydantic_ai.models.test import TestModel from pydantic_ai.tools import RunContext from pydantic_ai.usage import Usage @@ -48,23 +51,36 @@ @pytest.fixture -def agent(openai_api_key: str): - server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) - return Agent(model, mcp_servers=[server]) +def mcp_server() -> MCPServerStdio: + return MCPServerStdio('python', ['-m', 'tests.mcp_server']) + + +@pytest.fixture +def model(openai_api_key: str) -> Model: + return OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) + +@pytest.fixture +def agent(model: Model, mcp_server: MCPServerStdio) -> Agent: + return Agent(model, toolsets=[mcp_server]) -async def test_stdio_server(): + +@pytest.fixture +def run_context(model: Model) -> RunContext[int]: + return RunContext(deps=0, model=model, usage=Usage()) + + +async def test_stdio_server(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) async with server: - tools = await server.list_tools() + tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] assert len(tools) == snapshot(13) assert tools[0].name == 'celsius_to_fahrenheit' assert isinstance(tools[0].description, str) assert tools[0].description.startswith('Convert Celsius to Fahrenheit.') # Test calling the temperature conversion tool - result = await server.call_tool('celsius_to_fahrenheit', {'celsius': 
0}) + result = await server.direct_call_tool('celsius_to_fahrenheit', {'celsius': 0}) assert result == snapshot('32.0') @@ -75,38 +91,43 @@ async def test_reentrant_context_manager(): pass -async def test_stdio_server_with_tool_prefix(): +async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo') async with server: - tools = await server.list_tools() - assert all(tool.name.startswith('foo_') for tool in tools) + tools = await server.get_tools(run_context) + assert all(name.startswith('foo_') for name in tools.keys()) + + result = await server.call_tool( + 'foo_celsius_to_fahrenheit', {'celsius': 0}, run_context, tools['foo_celsius_to_fahrenheit'] + ) + assert result == snapshot('32.0') -async def test_stdio_server_with_cwd(): +async def test_stdio_server_with_cwd(run_context: RunContext[int]): test_dir = Path(__file__).parent server = MCPServerStdio('python', ['mcp_server.py'], cwd=test_dir) async with server: - tools = await server.list_tools() + tools = await server.get_tools(run_context) assert len(tools) == snapshot(13) -async def test_process_tool_call() -> None: +async def test_process_tool_call(run_context: RunContext[int]) -> int: called: bool = False async def process_tool_call( ctx: RunContext[int], call_tool: CallToolFunc, - tool_name: str, - args: dict[str, Any], + name: str, + tool_args: dict[str, Any], ) -> ToolResult: """A process_tool_call that sets a flag and sends deps as metadata.""" nonlocal called called = True - return await call_tool(tool_name, args, {'deps': ctx.deps}) + return await call_tool(name, tool_args, {'deps': ctx.deps}) server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], process_tool_call=process_tool_call) async with server: - agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), mcp_servers=[server]) + agent = Agent(deps_type=int, model=TestModel(call_tools=['echo_deps']), toolsets=[server]) result = await agent.run('Echo with deps set to 42', deps=42) assert result.output == snapshot('{"echo_deps":{"echo":"This is an echo message","deps":42}}') assert called, 'process_tool_call should have been called' @@ -135,7 +156,7 @@ def test_sse_server_with_header_and_timeout(): @pytest.mark.vcr() async def test_agent_with_stdio_server(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') assert result.output == snapshot('0 degrees Celsius is equal to 32 degrees Fahrenheit.') assert result.all_messages() == snapshot( @@ -212,11 +233,11 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_mcp_servers(): + async with agent: with pytest.raises( UserError, match=re.escape( - "MCP Server 'MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None)' defines a tool whose name conflicts with existing tool: 'get_none'. Consider using `tool_prefix` to avoid name conflicts." + "MCPServerStdio(command='python', args=['-m', 'tests.mcp_server'], tool_prefix=None) defines a tool whose name conflicts with existing tool from Function toolset: 'get_none'. Consider setting `tool_prefix` to avoid name conflicts." 
), ): await agent.run('Get me a conflict') @@ -227,7 +248,7 @@ async def test_agent_with_prefix_tool_name(openai_api_key: str): model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) agent = Agent( model, - mcp_servers=[server], + toolsets=[server], ) @agent.tool_plain @@ -235,43 +256,41 @@ def get_none() -> None: # pragma: no cover """Return nothing""" return None - async with agent.run_mcp_servers(): + async with agent: # This means that we passed the _prepare_request_parameters check and there is no conflict in the tool name with pytest.raises(RuntimeError, match='Model requests are not allowed, since ALLOW_MODEL_REQUESTS is False'): await agent.run('No conflict') -async def test_agent_with_server_not_running(openai_api_key: str): - server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) - model = OpenAIModel('gpt-4o', provider=OpenAIProvider(api_key=openai_api_key)) - agent = Agent(model, mcp_servers=[server]) - with pytest.raises(UserError, match='MCP server is not running'): - await agent.run('What is 0 degrees Celsius in Fahrenheit?') +@pytest.mark.vcr() +async def test_agent_with_server_not_running(agent: Agent, allow_model_requests: None): + result = await agent.run('What is 0 degrees Celsius in Fahrenheit?') + assert result.output == snapshot('0 degrees Celsius is 32.0 degrees Fahrenheit.') -async def test_log_level_unset(): +async def test_log_level_unset(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) assert server.log_level is None async with server: - tools = await server.list_tools() + tools = [tool.tool_def for tool in (await server.get_tools(run_context)).values()] assert len(tools) == snapshot(13) assert tools[10].name == 'get_log_level' - result = await server.call_tool('get_log_level', {}) + result = await server.direct_call_tool('get_log_level', {}) assert result == snapshot('unset') -async def test_log_level_set(): +async def test_log_level_set(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], log_level='info') assert server.log_level == 'info' async with server: - result = await server.call_tool('get_log_level', {}) + result = await server.direct_call_tool('get_log_level', {}) assert result == snapshot('info') @pytest.mark.vcr() async def test_tool_returning_str(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('What is the weather in Mexico City?') assert result.output == snapshot( 'The weather in Mexico City is currently sunny with a temperature of 26 degrees Celsius.' @@ -350,7 +369,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_text_resource(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me the product name') assert result.output == snapshot('The product name is "PydanticAI".') assert result.all_messages() == snapshot( @@ -423,7 +442,7 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A @pytest.mark.vcr() async def test_tool_returning_image_resource(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me the image resource') assert result.output == snapshot( 'This is an image of a sliced kiwi with a vibrant green interior and black seeds.' 
@@ -506,7 +525,7 @@ async def test_tool_returning_audio_resource( allow_model_requests: None, agent: Agent, audio_content: BinaryContent, gemini_api_key: str ): model = GoogleModel('gemini-2.5-pro-preview-03-25', provider=GoogleProvider(api_key=gemini_api_key)) - async with agent.run_mcp_servers(): + async with agent: result = await agent.run("What's the content of the audio resource?", model=model) assert result.output == snapshot('The audio resource contains a voice saying "Hello, my name is Marcelo."') assert result.all_messages() == snapshot( @@ -557,7 +576,7 @@ async def test_tool_returning_audio_resource( @pytest.mark.vcr() async def test_tool_returning_image(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me an image') assert result.output == snapshot('Here is an image of a sliced kiwi on a white background.') assert result.all_messages() == snapshot( @@ -637,7 +656,7 @@ async def test_tool_returning_image(allow_model_requests: None, agent: Agent, im @pytest.mark.vcr() async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me a dict, respond on one line') assert result.output == snapshot('{"foo":"bar","baz":123}') assert result.all_messages() == snapshot( @@ -704,7 +723,7 @@ async def test_tool_returning_dict(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_error(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me an error, pass False as a value, unless the tool tells you otherwise') assert result.output == snapshot( 'I called the tool with the correct parameter, and it returned: "This is not an error."' @@ -818,7 +837,7 @@ async def test_tool_returning_error(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_none(allow_model_requests: None, agent: Agent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Call the none tool and say Hello') assert result.output == snapshot('Hello! How can I assist you today?') assert result.all_messages() == snapshot( @@ -885,7 +904,7 @@ async def test_tool_returning_none(allow_model_requests: None, agent: Agent): @pytest.mark.vcr() async def test_tool_returning_multiple_items(allow_model_requests: None, agent: Agent, image_content: BinaryContent): - async with agent.run_mcp_servers(): + async with agent: result = await agent.run('Get me multiple items and summarize in one sentence') assert result.output == snapshot( 'The data includes two strings, a dictionary with a key-value pair, and an image of a sliced kiwi.' 
@@ -974,11 +993,11 @@ async def test_tool_returning_multiple_items(allow_model_requests: None, agent: ) -async def test_client_sampling(): +async def test_client_sampling(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) server.sampling_model = TestModel(custom_output_text='sampling model response') async with server: - result = await server.call_tool('use_sampling', {'foo': 'bar'}) + result = await server.direct_call_tool('use_sampling', {'foo': 'bar'}) assert result == snapshot( { 'meta': None, @@ -990,27 +1009,27 @@ async def test_client_sampling(): ) -async def test_client_sampling_disabled(): +async def test_client_sampling_disabled(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], allow_sampling=False) server.sampling_model = TestModel(custom_output_text='sampling model response') async with server: with pytest.raises(ModelRetry, match='Error executing tool use_sampling: Sampling not supported'): - await server.call_tool('use_sampling', {'foo': 'bar'}) - + await server.direct_call_tool('use_sampling', {'foo': 'bar'}) -async def test_mcp_server_raises_mcp_error(allow_model_requests: None, agent: Agent) -> None: - server = agent._mcp_servers[0] # pyright: ignore[reportPrivateUsage] +async def test_mcp_server_raises_mcp_error( + allow_model_requests: None, mcp_server: MCPServerStdio, agent: Agent, run_context: RunContext[int] +) -> None: mcp_error = McpError(error=ErrorData(code=400, message='Test MCP error conversion')) - async with agent.run_mcp_servers(): + async with agent: with patch.object( - server._client, # pyright: ignore[reportPrivateUsage] + mcp_server._client, # pyright: ignore[reportPrivateUsage] 'send_request', new=AsyncMock(side_effect=mcp_error), ): with pytest.raises(ModelRetry, match='Test MCP error conversion'): - await server.call_tool('test_tool', {}) + await mcp_server.direct_call_tool('test_tool', {}) def test_map_from_mcp_params_model_request(): diff --git a/tests/test_streaming.py b/tests/test_streaming.py index c7753f1fad..dbdcd71f32 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -5,6 +5,7 @@ import re from collections.abc import AsyncIterator from copy import deepcopy +from dataclasses import replace from datetime import timezone from typing import Any, Union @@ -12,14 +13,16 @@ from inline_snapshot import snapshot from pydantic import BaseModel -from pydantic_ai import Agent, UnexpectedModelBehavior, UserError, capture_run_messages +from pydantic_ai import Agent, RunContext, UnexpectedModelBehavior, UserError, capture_run_messages from pydantic_ai.agent import AgentRun from pydantic_ai.messages import ( + FinalResultEvent, FunctionToolCallEvent, FunctionToolResultEvent, ModelMessage, ModelRequest, ModelResponse, + PartStartEvent, RetryPromptPart, TextPart, ToolCallPart, @@ -28,8 +31,9 @@ ) from pydantic_ai.models.function import AgentInfo, DeltaToolCall, DeltaToolCalls, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import PromptedOutput, TextOutput +from pydantic_ai.output import DeferredToolCalls, PromptedOutput, TextOutput from pydantic_ai.result import AgentStream, FinalResult, Usage +from pydantic_ai.tools import ToolDefinition from pydantic_graph import End from .conftest import IsInt, IsNow, IsStr @@ -272,7 +276,7 @@ async def text_stream(_messages: list[ModelMessage], _: AgentInfo) -> AsyncItera agent = Agent(FunctionModel(stream_function=text_stream), output_type=tuple[str, str]) - with 
pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(1\) for output validation'): async with agent.run_stream(''): pass @@ -407,7 +411,7 @@ async def ret_a(x: str) -> str: # pragma: no cover return x with capture_run_messages() as messages: - with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for result validation'): + with pytest.raises(UnexpectedModelBehavior, match=r'Exceeded maximum retries \(0\) for output validation'): async with agent.run_stream('hello'): pass @@ -613,18 +617,18 @@ def another_tool(y: int) -> int: timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr(), ), - RetryPromptPart( - tool_name='unknown_tool', - content="Unknown tool name: 'unknown_tool'. Available tools: regular_tool, another_tool, final_result", - timestamp=IsNow(tz=timezone.utc), - tool_call_id=IsStr(), - ), ToolReturnPart( tool_name='regular_tool', content=42, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() ), ToolReturnPart( tool_name='another_tool', content=2, timestamp=IsNow(tz=timezone.utc), tool_call_id=IsStr() ), + RetryPromptPart( + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", + tool_name='unknown_tool', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), ] ), ] @@ -712,15 +716,15 @@ def another_tool(y: int) -> int: # pragma: no cover ModelRequest( parts=[ ToolReturnPart( - tool_name='regular_tool', - content='Tool not executed - a final result was already processed.', + tool_name='final_result', + content='Final result processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), part_kind='tool-return', ), ToolReturnPart( - tool_name='final_result', - content='Final result processed.', + tool_name='regular_tool', + content='Tool not executed - a final result was already processed.', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), part_kind='tool-return', @@ -733,10 +737,7 @@ def another_tool(y: int) -> int: # pragma: no cover part_kind='tool-return', ), RetryPromptPart( - content='Unknown tool name: ' - "'unknown_tool'. Available tools: " - 'regular_tool, another_tool, ' - 'final_result', + content="Unknown tool name: 'unknown_tool'. Available tools: 'final_result', 'regular_tool', 'another_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=datetime.timezone.utc), @@ -975,6 +976,13 @@ def known_tool(x: int) -> int: assert event_parts == snapshot( [ + FunctionToolCallEvent( + part=ToolCallPart( + tool_name='known_tool', + args={'x': 5}, + tool_call_id=IsStr(), + ) + ), FunctionToolCallEvent( part=ToolCallPart( tool_name='unknown_tool', @@ -984,14 +992,11 @@ def known_tool(x: int) -> int: ), FunctionToolResultEvent( result=RetryPromptPart( - content="Unknown tool name: 'unknown_tool'. Available tools: known_tool", + content="Unknown tool name: 'unknown_tool'. 
Available tools: 'known_tool'", tool_name='unknown_tool', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), - ) - ), - FunctionToolCallEvent( - part=ToolCallPart(tool_name='known_tool', args={'x': 5}, tool_call_id=IsStr()), + ), ), FunctionToolResultEvent( result=ToolReturnPart( @@ -999,13 +1004,6 @@ def known_tool(x: int) -> int: content=10, tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), - ) - ), - FunctionToolCallEvent( - part=ToolCallPart( - tool_name='unknown_tool', - args={'arg': 'value'}, - tool_call_id=IsStr(), ), ), ] @@ -1027,15 +1025,15 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf agent = Agent(FunctionModel(call_final_result_with_bad_data), output_type=OutputType) - event_parts: list[Any] = [] + events: list[Any] = [] async with agent.iter('test') as agent_run: async for node in agent_run: if Agent.is_call_tools_node(node): async with node.stream(agent_run.ctx) as event_stream: async for event in event_stream: - event_parts.append(event) + events.append(event) - assert event_parts == snapshot( + assert events == snapshot( [ FunctionToolCallEvent( part=ToolCallPart( @@ -1045,9 +1043,16 @@ def call_final_result_with_bad_data(messages: list[ModelMessage], info: AgentInf ), ), FunctionToolResultEvent( - result=ToolReturnPart( + result=RetryPromptPart( + content=[ + { + 'type': 'missing', + 'loc': ('value',), + 'msg': 'Field required', + 'input': {'bad_value': 'invalid'}, + } + ], tool_name='final_result', - content='Output tool not used - result failed validation.', tool_call_id=IsStr(), timestamp=IsNow(tz=timezone.utc), ) @@ -1118,3 +1123,93 @@ def test_function_tool_event_tool_call_id_properties(): # The event should expose the same `tool_call_id` as the result part assert result_event.tool_call_id == return_part.tool_call_id == 'return_id_456' + + +async def test_deferred_tool(): + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) + + async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: + return replace(tool_def, kind='deferred') + + @agent.tool_plain(prepare=prepare_tool) + def my_tool(x: int) -> int: + return x + 1 # pragma: no cover + + async with agent.run_stream('Hello') as result: + assert not result.is_complete + output = await result.get_output() + assert output == snapshot( + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + kind='deferred', + ) + }, + ) + ) + assert result.is_complete + + +async def test_deferred_tool_iter(): + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls]) + + async def prepare_tool(ctx: RunContext[None], tool_def: ToolDefinition) -> ToolDefinition: + return replace(tool_def, kind='deferred') + + @agent.tool_plain(prepare=prepare_tool) + def my_tool(x: int) -> int: + return x + 1 # pragma: no cover + + outputs: list[str | DeferredToolCalls] = [] + events: list[Any] = [] + + async with agent.iter('test') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + async for event in stream: + events.append(event) + async for output in stream.stream_output(debounce_by=None): + outputs.append(output) + if agent.is_call_tools_node(node): + async with node.stream(run.ctx) as stream: + async for event in stream: + 
events.append(event) + + assert outputs == snapshot( + [ + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + kind='deferred', + ) + }, + ) + ] + ) + assert events == snapshot( + [ + PartStartEvent( + index=0, + part=ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr()), + ), + FinalResultEvent(tool_name=None, tool_call_id=None), + FunctionToolCallEvent(part=ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())), + ] + ) diff --git a/tests/test_tools.py b/tests/test_tools.py index 00f3f5bcc0..c72cd1e086 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -12,11 +12,17 @@ from typing_extensions import TypedDict from pydantic_ai import Agent, RunContext, Tool, UserError +from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel -from pydantic_ai.output import ToolOutput +from pydantic_ai.output import DeferredToolCalls, ToolOutput from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.deferred import DeferredToolset +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.prefixed import PrefixedToolset + +from .conftest import IsStr def test_tool_no_ctx(): @@ -105,6 +111,7 @@ def test_docstring_google(docstring_format: Literal['google', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -136,6 +143,7 @@ def test_docstring_sphinx(docstring_format: Literal['sphinx', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -175,6 +183,7 @@ def test_docstring_numpy(docstring_format: Literal['numpy', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -214,6 +223,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -251,6 +261,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -294,6 +305,7 @@ def my_tool(x: int) -> str: # pragma: no cover }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -325,6 +337,7 @@ def test_only_returns_type(): 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -347,6 +360,7 @@ def test_docstring_unknown(): 'parameters_json_schema': {'properties': {}, 'type': 'object'}, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -387,6 +401,7 @@ def test_docstring_google_no_body(docstring_format: Literal['google', 'auto']): }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -420,6 +435,7 @@ def takes_just_model(model: Foo) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -462,6 +478,7 @@ def takes_just_model(model: Foo, z: int) -> str: }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -481,15 +498,15 @@ def plain_tool(x: int) -> int: result = agent.run_sync('foobar') assert 
result.output == snapshot('{"plain_tool":1}') assert call_args == snapshot([0]) - assert agent._function_tools['plain_tool'].takes_ctx is False - assert agent._function_tools['plain_tool'].max_retries == 7 + assert agent._function_toolset.tools['plain_tool'].takes_ctx is False + assert agent._function_toolset.tools['plain_tool'].max_retries == 7 agent_infer = Agent('test', tools=[plain_tool], retries=7) result = agent_infer.run_sync('foobar') assert result.output == snapshot('{"plain_tool":1}') assert call_args == snapshot([0, 0]) - assert agent_infer._function_tools['plain_tool'].takes_ctx is False - assert agent_infer._function_tools['plain_tool'].max_retries == 7 + assert agent_infer._function_toolset.tools['plain_tool'].takes_ctx is False + assert agent_infer._function_toolset.tools['plain_tool'].max_retries == 7 def ctx_tool(ctx: RunContext[int], x: int) -> int: @@ -501,13 +518,13 @@ def test_init_tool_ctx(): agent = Agent('test', tools=[Tool(ctx_tool, takes_ctx=True, max_retries=3)], deps_type=int, retries=7) result = agent.run_sync('foobar', deps=5) assert result.output == snapshot('{"ctx_tool":5}') - assert agent._function_tools['ctx_tool'].takes_ctx is True - assert agent._function_tools['ctx_tool'].max_retries == 3 + assert agent._function_toolset.tools['ctx_tool'].takes_ctx is True + assert agent._function_toolset.tools['ctx_tool'].max_retries == 3 agent_infer = Agent('test', tools=[ctx_tool], deps_type=int) result = agent_infer.run_sync('foobar', deps=6) assert result.output == snapshot('{"ctx_tool":6}') - assert agent_infer._function_tools['ctx_tool'].takes_ctx is True + assert agent_infer._function_toolset.tools['ctx_tool'].takes_ctx is True def test_repeat_tool_by_rename(): @@ -557,18 +574,40 @@ def foo(x: int, y: str) -> str: # pragma: no cover def bar(x: int, y: str) -> str: # pragma: no cover return f'{x} {y}' - with pytest.raises(UserError, match=r"Tool name conflicts with existing tool: 'bar'."): + with pytest.raises(UserError, match="Tool name conflicts with previously renamed tool: 'bar'."): agent.run_sync('') def test_tool_return_conflict(): # this is okay - Agent('test', tools=[ctx_tool], deps_type=int) + Agent('test', tools=[ctx_tool], deps_type=int).run_sync('', deps=0) # this is also okay - Agent('test', tools=[ctx_tool], deps_type=int, output_type=int) + Agent('test', tools=[ctx_tool], deps_type=int, output_type=int).run_sync('', deps=0) # this raises an error - with pytest.raises(UserError, match="Tool name conflicts with output tool name: 'ctx_tool'"): - Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')) + with pytest.raises( + UserError, + match="Function toolset defines a tool whose name conflicts with existing tool from Output toolset: 'ctx_tool'. Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + ): + Agent('test', tools=[ctx_tool], deps_type=int, output_type=ToolOutput(int, name='ctx_tool')).run_sync( + '', deps=0 + ) + + +def test_tool_name_conflict_hint(): + with pytest.raises( + UserError, + match="Prefixed toolset defines a tool whose name conflicts with existing tool from Function toolset: 'foo_tool'. 
Rename the tool or wrap the toolset in a `PrefixedToolset` to avoid name conflicts.", + ): + + def tool(x: int) -> int: + return x + 1 # pragma: no cover + + def foo_tool(x: str) -> str: + return x + 'foo' # pragma: no cover + + function_toolset = FunctionToolset([tool]) + prefixed_toolset = PrefixedToolset(function_toolset, 'foo') + Agent('test', tools=[foo_tool], toolsets=[prefixed_toolset]).run_sync('') def test_init_ctx_tool_invalid(): @@ -798,6 +837,7 @@ def test_suppress_griffe_logging(caplog: LogCaptureFixture): 'outer_typed_dict_key': None, 'parameters_json_schema': {'additionalProperties': False, 'properties': {}, 'type': 'object'}, 'strict': None, + 'kind': 'function', } ) @@ -867,6 +907,7 @@ def my_tool_plain(*, a: int = 1, b: int) -> int: 'type': 'object', }, 'strict': None, + 'kind': 'function', }, { 'description': None, @@ -879,6 +920,7 @@ def my_tool_plain(*, a: int = 1, b: int) -> int: 'type': 'object', }, 'strict': None, + 'kind': 'function', }, ] ) @@ -963,6 +1005,7 @@ def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = 'type': 'object', }, 'strict': None, + 'kind': 'function', }, { 'description': None, @@ -973,6 +1016,7 @@ def my_tool(x: Annotated[Union[str, None], WithJsonSchema({'type': 'string'})] = 'type': 'object', }, 'strict': None, + 'kind': 'function', }, ] ) @@ -1008,6 +1052,7 @@ def get_score(data: Data) -> int: ... # pragma: no branch }, 'outer_typed_dict_key': None, 'strict': None, + 'kind': 'function', } ) @@ -1039,7 +1084,7 @@ def foobar(ctx: RunContext[int], x: int, y: str) -> str: with agent.override(model=FunctionModel(get_json_schema)): result = agent.run_sync('', deps=21) json_schema = json.loads(result.output) - assert agent._function_tools['foobar'].strict is None + assert agent._function_toolset.tools['foobar'].strict is None assert json_schema['strict'] is True result = agent.run_sync('', deps=1) @@ -1066,8 +1111,8 @@ def function(*args: Any, **kwargs: Any) -> str: agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') assert result.output == snapshot('{"foobar":"I like being called like this"}') - assert agent._function_tools['foobar'].takes_ctx is False - assert agent._function_tools['foobar'].max_retries == 0 + assert agent._function_toolset.tools['foobar'].takes_ctx is False + assert agent._function_toolset.tools['foobar'].max_retries == 0 def test_function_tool_inconsistent_with_schema(): @@ -1113,5 +1158,146 @@ async def function(*args: Any, **kwargs: Any) -> str: agent = Agent('test', tools=[pydantic_tool], retries=0) result = agent.run_sync('foobar') assert result.output == snapshot('{"foobar":"I like being called like this"}') - assert agent._function_tools['foobar'].takes_ctx is False - assert agent._function_tools['foobar'].max_retries == 0 + assert agent._function_toolset.tools['foobar'].takes_ctx is False + assert agent._function_toolset.tools['foobar'].max_retries == 0 + + +def test_tool_retries(): + prepare_tools_retries: list[int] = [] + prepare_retries: list[int] = [] + call_retries: list[int] = [] + + async def prepare_tool_defs( + ctx: RunContext[None], tool_defs: list[ToolDefinition] + ) -> Union[list[ToolDefinition], None]: + nonlocal prepare_tools_retries + retry = ctx.retries.get('infinite_retry_tool', 0) + prepare_tools_retries.append(retry) + return tool_defs + + agent = Agent(TestModel(), retries=3, prepare_tools=prepare_tool_defs) + + async def prepare_tool_def(ctx: RunContext[None], tool_def: ToolDefinition) -> Union[ToolDefinition, None]: + nonlocal 
prepare_retries + prepare_retries.append(ctx.retry) + return tool_def + + @agent.tool(retries=5, prepare=prepare_tool_def) + def infinite_retry_tool(ctx: RunContext[None]) -> int: + nonlocal call_retries + call_retries.append(ctx.retry) + raise ModelRetry('Please try again.') + + with pytest.raises(UnexpectedModelBehavior, match="Tool 'infinite_retry_tool' exceeded max retries count of 5"): + agent.run_sync('Begin infinite retry loop!') + + # There are extra 0s here because the toolset is prepared once ahead of the graph run, before the user prompt part is added in. + assert prepare_tools_retries == [0, 0, 1, 2, 3, 4, 5] + assert prepare_retries == [0, 0, 1, 2, 3, 4, 5] + assert call_retries == [0, 1, 2, 3, 4, 5] + + +def test_deferred_tool(): + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(), output_type=[str, DeferredToolCalls], toolsets=[deferred_toolset]) + + result = agent.run_sync('Hello') + assert result.output == snapshot( + DeferredToolCalls( + tool_calls=[ToolCallPart(tool_name='my_tool', args={'x': 0}, tool_call_id=IsStr())], + tool_defs={ + 'my_tool': ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={ + 'type': 'object', + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + }, + kind='deferred', + ) + }, + ) + ) + + +def test_deferred_tool_with_output_type(): + class MyModel(BaseModel): + foo: str + + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(call_tools=[]), output_type=[MyModel, DeferredToolCalls], toolsets=[deferred_toolset]) + + result = agent.run_sync('Hello') + assert result.output == snapshot(MyModel(foo='a')) + + +def test_deferred_tool_with_tool_output_type(): + class MyModel(BaseModel): + foo: str + + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent( + TestModel(call_tools=[]), + output_type=[[ToolOutput(MyModel), ToolOutput(MyModel)], DeferredToolCalls], + toolsets=[deferred_toolset], + ) + + result = agent.run_sync('Hello') + assert result.output == snapshot(MyModel(foo='a')) + + +async def test_deferred_tool_without_output_type(): + deferred_toolset = DeferredToolset( + [ + ToolDefinition( + name='my_tool', + description='', + parameters_json_schema={'type': 'object', 'properties': {'x': {'type': 'integer'}}, 'required': ['x']}, + ), + ] + ) + agent = Agent(TestModel(), toolsets=[deferred_toolset]) + + msg = 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' 
+ + with pytest.raises(UserError, match=msg): + await agent.run('Hello') + + with pytest.raises(UserError, match=msg): + async with agent.run_stream('Hello') as result: + await result.get_output() + + +def test_output_type_deferred_tool_calls_by_itself(): + with pytest.raises(UserError, match='At least one output type must be provided other than `DeferredToolCalls`.'): + Agent(TestModel(), output_type=DeferredToolCalls) + + +def test_output_type_empty(): + with pytest.raises(UserError, match='At least one output type must be provided.'): + Agent(TestModel(), output_type=[]) diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py new file mode 100644 index 0000000000..eac0dc78a7 --- /dev/null +++ b/tests/test_toolsets.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +import re +from dataclasses import dataclass, replace +from typing import TypeVar + +import pytest +from inline_snapshot import snapshot + +from pydantic_ai._run_context import RunContext +from pydantic_ai._tool_manager import ToolManager +from pydantic_ai.exceptions import UserError +from pydantic_ai.messages import ToolCallPart +from pydantic_ai.models.test import TestModel +from pydantic_ai.tools import ToolDefinition +from pydantic_ai.toolsets.combined import CombinedToolset +from pydantic_ai.toolsets.filtered import FilteredToolset +from pydantic_ai.toolsets.function import FunctionToolset +from pydantic_ai.toolsets.prefixed import PrefixedToolset +from pydantic_ai.toolsets.prepared import PreparedToolset +from pydantic_ai.usage import Usage + +pytestmark = pytest.mark.anyio + +T = TypeVar('T') + + +def build_run_context(deps: T) -> RunContext[T]: + return RunContext( + deps=deps, + model=TestModel(), + usage=Usage(), + prompt=None, + messages=[], + run_step=0, + ) + + +async def test_function_toolset(): + @dataclass + class PrefixDeps: + prefix: str | None = None + + toolset = FunctionToolset[PrefixDeps]() + + async def prepare_add_prefix(ctx: RunContext[PrefixDeps], tool_def: ToolDefinition) -> ToolDefinition | None: + if ctx.deps.prefix is None: + return tool_def + + return replace(tool_def, name=f'{ctx.deps.prefix}_{tool_def.name}') + + @toolset.tool(prepare=prepare_add_prefix) + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + no_prefix_context = build_run_context(PrefixDeps()) + no_prefix_toolset = await ToolManager[PrefixDeps].build(toolset, no_prefix_context) + assert no_prefix_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='add', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + description='Add two numbers', + ) + ] + ) + assert await no_prefix_toolset.handle_call(ToolCallPart(tool_name='add', args={'a': 1, 'b': 2})) == 3 + + foo_context = build_run_context(PrefixDeps(prefix='foo')) + foo_toolset = await ToolManager[PrefixDeps].build(toolset, foo_context) + assert foo_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='foo_add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ) + ] + ) + assert await foo_toolset.handle_call(ToolCallPart(tool_name='foo_add', args={'a': 1, 'b': 2})) == 3 + + @toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b # pragma: lax no cover + + bar_context = 
build_run_context(PrefixDeps(prefix='bar')) + bar_toolset = await ToolManager[PrefixDeps].build(toolset, bar_context) + assert bar_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='bar_add', + description='Add two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='subtract', + description='Subtract two numbers', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ] + ) + assert await bar_toolset.handle_call(ToolCallPart(tool_name='bar_add', args={'a': 1, 'b': 2})) == 3 + + +async def test_prepared_toolset_user_error_add_new_tools(): + """Test that PreparedToolset raises UserError when prepare function tries to add new tools.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b # pragma: no cover + + async def prepare_add_new_tool(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Try to add a new tool that wasn't in the original set + new_tool = ToolDefinition( + name='new_tool', + description='A new tool', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'x': {'type': 'integer'}}, + 'required': ['x'], + 'type': 'object', + }, + ) + return tool_defs + [new_tool] + + prepared_toolset = PreparedToolset(base_toolset, prepare_add_new_tool) + + with pytest.raises( + UserError, + match=re.escape( + 'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.' + ), + ): + await ToolManager[None].build(prepared_toolset, context) + + +async def test_prepared_toolset_user_error_change_tool_names(): + """Test that PreparedToolset raises UserError when prepare function tries to change tool names.""" + context = build_run_context(None) + base_toolset = FunctionToolset[None]() + + @base_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b # pragma: no cover + + @base_toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b # pragma: no cover + + async def prepare_change_names(ctx: RunContext[None], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Try to change the name of an existing tool + modified_tool_defs: list[ToolDefinition] = [] + for tool_def in tool_defs: + if tool_def.name == 'add': + modified_tool_defs.append(replace(tool_def, name='modified_add')) + else: + modified_tool_defs.append(tool_def) + return modified_tool_defs + + prepared_toolset = PreparedToolset(base_toolset, prepare_change_names) + + with pytest.raises( + UserError, + match=re.escape( + 'Prepare function cannot add or rename tools. Use `FunctionToolset.add_function()` or `RenamedToolset` instead.' 
+ ), + ): + await ToolManager[None].build(prepared_toolset, context) + + +async def test_comprehensive_toolset_composition(): + """Test that all toolsets can be composed together and work correctly.""" + + @dataclass + class TestDeps: + user_role: str = 'user' + enable_advanced: bool = True + + # Create first FunctionToolset with basic math operations + math_toolset = FunctionToolset[TestDeps]() + + @math_toolset.tool + def add(a: int, b: int) -> int: + """Add two numbers""" + return a + b + + @math_toolset.tool + def subtract(a: int, b: int) -> int: + """Subtract two numbers""" + return a - b # pragma: no cover + + @math_toolset.tool + def multiply(a: int, b: int) -> int: + """Multiply two numbers""" + return a * b # pragma: no cover + + # Create second FunctionToolset with string operations + string_toolset = FunctionToolset[TestDeps]() + + @string_toolset.tool + def concat(s1: str, s2: str) -> str: + """Concatenate two strings""" + return s1 + s2 + + @string_toolset.tool + def uppercase(text: str) -> str: + """Convert text to uppercase""" + return text.upper() # pragma: no cover + + @string_toolset.tool + def reverse(text: str) -> str: + """Reverse a string""" + return text[::-1] # pragma: no cover + + # Create third FunctionToolset with advanced operations + advanced_toolset = FunctionToolset[TestDeps]() + + @advanced_toolset.tool + def power(base: int, exponent: int) -> int: + """Calculate base raised to the power of exponent""" + return base**exponent # pragma: no cover + + # Step 1: Prefix each FunctionToolset individually + prefixed_math = PrefixedToolset(math_toolset, 'math') + prefixed_string = PrefixedToolset(string_toolset, 'str') + prefixed_advanced = PrefixedToolset(advanced_toolset, 'adv') + + # Step 2: Combine the prefixed toolsets + combined_prefixed_toolset = CombinedToolset([prefixed_math, prefixed_string, prefixed_advanced]) + + # Step 3: Filter tools based on user role and advanced flag, now using prefixed names + def filter_tools(ctx: RunContext[TestDeps], tool_def: ToolDefinition) -> bool: + # Only allow advanced tools if enable_advanced is True + if tool_def.name.startswith('adv_') and not ctx.deps.enable_advanced: + return False + # Only allow string operations for admin users (simulating role-based access) + if tool_def.name.startswith('str_') and ctx.deps.user_role != 'admin': + return False + return True + + filtered_toolset = FilteredToolset[TestDeps](combined_prefixed_toolset, filter_tools) + + # Step 4: Apply prepared toolset to modify descriptions (add user role annotation) + async def prepare_add_context(ctx: RunContext[TestDeps], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + # Annotate each tool description with the user role + role = ctx.deps.user_role + return [replace(td, description=f'{td.description} (role: {role})') for td in tool_defs] + + prepared_toolset = PreparedToolset(filtered_toolset, prepare_add_context) + + # Step 5: Test the fully composed toolset + # Test with regular user context + regular_deps = TestDeps(user_role='user', enable_advanced=True) + regular_context = build_run_context(regular_deps) + final_toolset = await ToolManager[TestDeps].build(prepared_toolset, regular_context) + # Tool definitions should have role annotation + assert final_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='math_add', + description='Add two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 
'object', + }, + ), + ToolDefinition( + name='math_subtract', + description='Subtract two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_multiply', + description='Multiply two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='adv_power', + description='Calculate base raised to the power of exponent (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'base': {'type': 'integer'}, 'exponent': {'type': 'integer'}}, + 'required': ['base', 'exponent'], + 'type': 'object', + }, + ), + ] + ) + # Call a tool and check result + result = await final_toolset.handle_call(ToolCallPart(tool_name='math_add', args={'a': 5, 'b': 3})) + assert result == 8 + + # Test with admin user context (should have string tools) + admin_deps = TestDeps(user_role='admin', enable_advanced=True) + admin_context = build_run_context(admin_deps) + admin_final_toolset = await ToolManager[TestDeps].build(prepared_toolset, admin_context) + assert admin_final_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='math_add', + description='Add two numbers (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_subtract', + description='Subtract two numbers (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_multiply', + description='Multiply two numbers (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='str_concat', + description='Concatenate two strings (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'s1': {'type': 'string'}, 's2': {'type': 'string'}}, + 'required': ['s1', 's2'], + 'type': 'object', + }, + ), + ToolDefinition( + name='str_uppercase', + description='Convert text to uppercase (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'text': {'type': 'string'}}, + 'required': ['text'], + 'type': 'object', + }, + ), + ToolDefinition( + name='str_reverse', + description='Reverse a string (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'text': {'type': 'string'}}, + 'required': ['text'], + 'type': 'object', + }, + ), + ToolDefinition( + name='adv_power', + description='Calculate base raised to the power of exponent (role: admin)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'base': {'type': 'integer'}, 'exponent': {'type': 'integer'}}, + 'required': ['base', 'exponent'], + 'type': 'object', + }, + ), + ] + ) + result = await admin_final_toolset.handle_call( + ToolCallPart(tool_name='str_concat', args={'s1': 'Hello', 's2': 'World'}) + ) + assert result == 'HelloWorld' + + # Test with advanced features disabled + basic_deps = 
TestDeps(user_role='user', enable_advanced=False) + basic_context = build_run_context(basic_deps) + basic_final_toolset = await ToolManager[TestDeps].build(prepared_toolset, basic_context) + assert basic_final_toolset.tool_defs == snapshot( + [ + ToolDefinition( + name='math_add', + description='Add two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_subtract', + description='Subtract two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ToolDefinition( + name='math_multiply', + description='Multiply two numbers (role: user)', + parameters_json_schema={ + 'additionalProperties': False, + 'properties': {'a': {'type': 'integer'}, 'b': {'type': 'integer'}}, + 'required': ['a', 'b'], + 'type': 'object', + }, + ), + ] + ) + + +async def test_context_manager(): + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: # pragma: lax no cover + pytest.skip('mcp is not installed') + + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + toolset = CombinedToolset([server1, PrefixedToolset(server2, 'prefix')]) + + async with toolset: + assert server1.is_running + assert server2.is_running + + async with toolset: + assert server1.is_running + assert server2.is_running diff --git a/tests/typed_agent.py b/tests/typed_agent.py index 6ea4c4c223..3e1171c076 100644 --- a/tests/typed_agent.py +++ b/tests/typed_agent.py @@ -10,7 +10,7 @@ from pydantic_ai import Agent, ModelRetry, RunContext, Tool from pydantic_ai.agent import AgentRunResult -from pydantic_ai.output import StructuredDict, TextOutput, ToolOutput +from pydantic_ai.output import DeferredToolCalls, StructuredDict, TextOutput, ToolOutput from pydantic_ai.tools import ToolDefinition # Define here so we can check `if MYPY` below. 
This will not be executed, MYPY will always set it to True @@ -222,6 +222,14 @@ def my_method(self) -> bool: assert_type( complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] ) + + complex_deferred_output_agent = Agent[ + None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls + ](output_type=[complex_output_agent.output_type, DeferredToolCalls]) + assert_type( + complex_deferred_output_agent, + Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls], + ) else: # pyright is able to correctly infer the type here async_int_function_agent = Agent(output_type=foobar_plain) @@ -241,6 +249,12 @@ def my_method(self) -> bool: complex_output_agent, Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str]] ) + complex_deferred_output_agent = Agent(output_type=[complex_output_agent.output_type, DeferredToolCalls]) + assert_type( + complex_deferred_output_agent, + Agent[None, Foo | Bar | Decimal | int | bool | tuple[str, int] | str | re.Pattern[str] | DeferredToolCalls], + ) + Tool(foobar_ctx, takes_ctx=True) Tool(foobar_ctx) From 64b64a50a22eb7afe85edfa2494faa289345f635 Mon Sep 17 00:00:00 2001 From: naveen-corpusant <131562836+naveen-corpusant@users.noreply.github.com> Date: Wed, 16 Jul 2025 12:11:03 -0700 Subject: [PATCH 21/89] [openai] Remove incorrect tool call id from tool call delta (#2210) Co-authored-by: Douwe Maan --- pydantic_ai_slim/pydantic_ai/models/openai.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index b968ac61fd..92d79c6340 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1051,7 +1051,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: vendor_part_id=chunk.item_id, tool_name=None, args=chunk.delta, - tool_call_id=chunk.item_id, + tool_call_id=None, ) if maybe_event is not None: # pragma: no branch yield maybe_event From 4eff63cdfe1c8493c932b6870f36165d8e978747 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 16 Jul 2025 14:09:30 -0600 Subject: [PATCH 22/89] Remove old Google models (#2220) --- .../pydantic_ai_examples/stream_markdown.py | 2 +- .../pydantic_ai/models/__init__.py | 20 ++----------------- pydantic_ai_slim/pydantic_ai/models/gemini.py | 10 +--------- pydantic_ai_slim/pydantic_ai/models/google.py | 10 +--------- 4 files changed, 5 insertions(+), 37 deletions(-) diff --git a/examples/pydantic_ai_examples/stream_markdown.py b/examples/pydantic_ai_examples/stream_markdown.py index 53f61737b4..cf9335cd51 100644 --- a/examples/pydantic_ai_examples/stream_markdown.py +++ b/examples/pydantic_ai_examples/stream_markdown.py @@ -26,7 +26,7 @@ # models to try, and the appropriate env var models: list[tuple[KnownModelName, str]] = [ - ('google-gla:gemini-1.5-flash', 'GEMINI_API_KEY'), + ('google-gla:gemini-2.0-flash', 'GEMINI_API_KEY'), ('openai:gpt-4o-mini', 'OPENAI_API_KEY'), ('groq:llama-3.3-70b-versatile', 'GROQ_API_KEY'), ] diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 811c128379..11ec50f85b 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -134,31 +134,15 @@ 'cohere:command-r7b-12-2024', 'deepseek:deepseek-chat', 'deepseek:deepseek-reasoner', 
- 'google-gla:gemini-1.5-flash', - 'google-gla:gemini-1.5-flash-8b', - 'google-gla:gemini-1.5-pro', - 'google-gla:gemini-1.0-pro', 'google-gla:gemini-2.0-flash', - 'google-gla:gemini-2.0-flash-lite-preview-02-05', - 'google-gla:gemini-2.0-pro-exp-02-05', - 'google-gla:gemini-2.5-flash-preview-05-20', + 'google-gla:gemini-2.0-flash-lite', 'google-gla:gemini-2.5-flash', 'google-gla:gemini-2.5-flash-lite-preview-06-17', - 'google-gla:gemini-2.5-pro-exp-03-25', - 'google-gla:gemini-2.5-pro-preview-05-06', 'google-gla:gemini-2.5-pro', - 'google-vertex:gemini-1.5-flash', - 'google-vertex:gemini-1.5-flash-8b', - 'google-vertex:gemini-1.5-pro', - 'google-vertex:gemini-1.0-pro', 'google-vertex:gemini-2.0-flash', - 'google-vertex:gemini-2.0-flash-lite-preview-02-05', - 'google-vertex:gemini-2.0-pro-exp-02-05', - 'google-vertex:gemini-2.5-flash-preview-05-20', + 'google-vertex:gemini-2.0-flash-lite', 'google-vertex:gemini-2.5-flash', 'google-vertex:gemini-2.5-flash-lite-preview-06-17', - 'google-vertex:gemini-2.5-pro-exp-03-25', - 'google-vertex:gemini-2.5-pro-preview-05-06', 'google-vertex:gemini-2.5-pro', 'gpt-3.5-turbo', 'gpt-3.5-turbo-0125', diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 99aa99a301..b5e58c43ab 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -48,18 +48,10 @@ ) LatestGeminiModelNames = Literal[ - 'gemini-1.5-flash', - 'gemini-1.5-flash-8b', - 'gemini-1.5-pro', - 'gemini-1.0-pro', 'gemini-2.0-flash', - 'gemini-2.0-flash-lite-preview-02-05', - 'gemini-2.0-pro-exp-02-05', - 'gemini-2.5-flash-preview-05-20', + 'gemini-2.0-flash-lite', 'gemini-2.5-flash', 'gemini-2.5-flash-lite-preview-06-17', - 'gemini-2.5-pro-exp-03-25', - 'gemini-2.5-pro-preview-05-06', 'gemini-2.5-pro', ] """Latest Gemini models.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 3755cc16e8..8b9af7ba2f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -73,18 +73,10 @@ ) from _import_error LatestGoogleModelNames = Literal[ - 'gemini-1.5-flash', - 'gemini-1.5-flash-8b', - 'gemini-1.5-pro', - 'gemini-1.0-pro', 'gemini-2.0-flash', - 'gemini-2.0-flash-lite-preview-02-05', - 'gemini-2.0-pro-exp-02-05', - 'gemini-2.5-flash-preview-05-20', + 'gemini-2.0-flash-lite', 'gemini-2.5-flash', 'gemini-2.5-flash-lite-preview-06-17', - 'gemini-2.5-pro-exp-03-25', - 'gemini-2.5-pro-preview-05-06', 'gemini-2.5-pro', ] """Latest Gemini models.""" From 4193208b684683189612c4b4530a11dd36fa4324 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Wed, 16 Jul 2025 14:42:47 -0600 Subject: [PATCH 23/89] Add Douwe to maintainers (#2221) --- clai/pyproject.toml | 9 +++++---- examples/pyproject.toml | 8 +++++++- pydantic_ai_slim/pyproject.toml | 9 +++++---- pydantic_evals/pyproject.toml | 4 ++++ pydantic_graph/pyproject.toml | 8 +++++++- pyproject.toml | 1 + 6 files changed, 29 insertions(+), 10 deletions(-) diff --git a/clai/pyproject.toml b/clai/pyproject.toml index 6b62863d47..af578c8552 100644 --- a/clai/pyproject.toml +++ b/clai/pyproject.toml @@ -15,10 +15,11 @@ name = "clai" dynamic = ["version", "dependencies"] description = "PydanticAI CLI: command line interface to chat to LLMs" authors = [ - { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, - { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, - { name = "David Montague", email = "david@pydantic.dev" }, - { 
name = "Alex Hall", email = "alex@pydantic.dev" }, + { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, + { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, + { name = "David Montague", email = "david@pydantic.dev" }, + { name = "Alex Hall", email = "alex@pydantic.dev" }, + { name = "Douwe Maan", email = "douwe@pydantic.dev" }, ] license = "MIT" readme = "README.md" diff --git a/examples/pyproject.toml b/examples/pyproject.toml index bb5dcd9ef9..770c9ba3b2 100644 --- a/examples/pyproject.toml +++ b/examples/pyproject.toml @@ -14,7 +14,13 @@ bump = true name = "pydantic-ai-examples" dynamic = ["version", "dependencies"] description = "Examples of how to use PydanticAI and what it can do." -authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }] +authors = [ + { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, + { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, + { name = "David Montague", email = "david@pydantic.dev" }, + { name = "Alex Hall", email = "alex@pydantic.dev" }, + { name = "Douwe Maan", email = "douwe@pydantic.dev" }, +] license = "MIT" readme = "README.md" classifiers = [ diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 2705ca9144..4b62e40d98 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -15,10 +15,11 @@ name = "pydantic-ai-slim" dynamic = ["version", "dependencies", "optional-dependencies"] description = "Agent Framework / shim to use Pydantic with LLMs, slim package" authors = [ - { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, - { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, - { name = "David Montague", email = "david@pydantic.dev" }, - { name = "Alex Hall", email = "alex@pydantic.dev" }, + { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, + { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, + { name = "David Montague", email = "david@pydantic.dev" }, + { name = "Alex Hall", email = "alex@pydantic.dev" }, + { name = "Douwe Maan", email = "douwe@pydantic.dev" }, ] license = "MIT" readme = "README.md" diff --git a/pydantic_evals/pyproject.toml b/pydantic_evals/pyproject.toml index c471c8738c..5b51cf2a18 100644 --- a/pydantic_evals/pyproject.toml +++ b/pydantic_evals/pyproject.toml @@ -15,7 +15,11 @@ name = "pydantic-evals" dynamic = ["version", "dependencies"] description = "Framework for evaluating stochastic code execution, especially code making use of LLMs" authors = [ + { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, + { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, { name = "David Montague", email = "david@pydantic.dev" }, + { name = "Alex Hall", email = "alex@pydantic.dev" }, + { name = "Douwe Maan", email = "douwe@pydantic.dev" }, ] license = "MIT" readme = "README.md" diff --git a/pydantic_graph/pyproject.toml b/pydantic_graph/pyproject.toml index 82723fcc0b..a24d819c53 100644 --- a/pydantic_graph/pyproject.toml +++ b/pydantic_graph/pyproject.toml @@ -14,7 +14,13 @@ bump = true name = "pydantic-graph" dynamic = ["version"] description = "Graph and state machine library" -authors = [{ name = "Samuel Colvin", email = "samuel@pydantic.dev" }] +authors = [ + { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, + { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, + { name = "David Montague", email = "david@pydantic.dev" }, + { name = "Alex Hall", email = "alex@pydantic.dev" }, + { name = "Douwe Maan", email = 
"douwe@pydantic.dev" }, +] license = "MIT" readme = "README.md" classifiers = [ diff --git a/pyproject.toml b/pyproject.toml index 534f156db7..b2b82aa867 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -19,6 +19,7 @@ authors = [ { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, { name = "David Montague", email = "david@pydantic.dev" }, { name = "Alex Hall", email = "alex@pydantic.dev" }, + { name = "Douwe Maan", email = "douwe@pydantic.dev" }, ] license = "MIT" readme = "README.md" From 4a475965856734e2debb6c6580369228b7067a1e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?c=C3=A9lina?= Date: Thu, 17 Jul 2025 09:03:35 +0100 Subject: [PATCH 24/89] [Docs] List Hugging Face Inference Provider (#2229) --- docs/models/index.md | 1 + 1 file changed, 1 insertion(+) diff --git a/docs/models/index.md b/docs/models/index.md index 67069f2a4a..c384d30a7d 100644 --- a/docs/models/index.md +++ b/docs/models/index.md @@ -9,6 +9,7 @@ PydanticAI is model-agnostic and has built-in support for multiple model provide * [Mistral](mistral.md) * [Cohere](cohere.md) * [Bedrock](bedrock.md) +* [Hugging Face](huggingface.md) ## OpenAI-compatible Providers From 479b346c813868766b8a7ca8d866ca9be141bdbd Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Thu, 17 Jul 2025 01:20:58 -0700 Subject: [PATCH 25/89] change clai default model to gpt-4.1 (#2227) --- clai/README.md | 2 +- docs/cli.md | 8 ++++---- pydantic_ai_slim/pydantic_ai/_cli.py | 8 +++++--- 3 files changed, 10 insertions(+), 8 deletions(-) diff --git a/clai/README.md b/clai/README.md index 8899a82ffe..6b7de53790 100644 --- a/clai/README.md +++ b/clai/README.md @@ -68,7 +68,7 @@ positional arguments: options: -h, --help show this help message and exit -m [MODEL], --model [MODEL] - Model to use, in format ":" e.g. "openai:gpt-4o" or "anthropic:claude-3-7-sonnet-latest". Defaults to "openai:gpt-4o". + Model to use, in format ":" e.g. "openai:gpt-4.1" or "anthropic:claude-sonnet-4-0". Defaults to "openai:gpt-4.1". -a AGENT, --agent AGENT Custom Agent to use, in format "module:variable", e.g. 
"mymodule.submodule:my_agent" -l, --list-models List all available models and exit diff --git a/docs/cli.md b/docs/cli.md index 2cd373c000..01b49b9bad 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -60,7 +60,7 @@ uvx clai --help You can specify which model to use with the `--model` flag: ```bash -uvx clai --model anthropic:claude-3-7-sonnet-latest +uvx clai --model anthropic:claude-sonnet-4-0 ``` (a full list of models available can be printed with `uvx clai --list-models`) @@ -72,7 +72,7 @@ You can specify a custom agent using the `--agent` flag with a module path and v ```python {title="custom_agent.py" test="skip"} from pydantic_ai import Agent -agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') +agent = Agent('openai:gpt-4.1', instructions='You always respond in Italian.') ``` Then run: @@ -92,7 +92,7 @@ Additionally, you can directly launch CLI mode from an `Agent` instance using `A ```python {title="agent_to_cli_sync.py" test="skip" hl_lines=4} from pydantic_ai import Agent -agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') +agent = Agent('openai:gpt-4.1', instructions='You always respond in Italian.') agent.to_cli_sync() ``` @@ -101,7 +101,7 @@ You can also use the async interface with `Agent.to_cli()`: ```python {title="agent_to_cli.py" test="skip" hl_lines=6} from pydantic_ai import Agent -agent = Agent('openai:gpt-4o', instructions='You always respond in Italian.') +agent = Agent('openai:gpt-4.1', instructions='You always respond in Italian.') async def main(): await agent.to_cli() diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index 894f630c11..c4e3d63fca 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -101,7 +101,9 @@ def cli_exit(prog_name: str = 'pai'): # pragma: no cover sys.exit(cli(prog_name=prog_name)) -def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> int: # noqa: C901 +def cli( # noqa: C901 + args_list: Sequence[str] | None = None, *, prog_name: str = 'pai', default_model: str = 'openai:gpt-4.1' +) -> int: """Run the CLI and return the exit code for the process.""" parser = argparse.ArgumentParser( prog=prog_name, @@ -120,7 +122,7 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in '-m', '--model', nargs='?', - help='Model to use, in format ":" e.g. "openai:gpt-4o" or "anthropic:claude-3-7-sonnet-latest". Defaults to "openai:gpt-4o".', + help=f'Model to use, in format ":" e.g. "openai:gpt-4.1" or "anthropic:claude-sonnet-4-0". Defaults to "{default_model}".', ) # we don't want to autocomplete or list models that don't include the provider, # e.g. 
we want to show `openai:gpt-4o` but not `gpt-4o` @@ -179,7 +181,7 @@ def cli(args_list: Sequence[str] | None = None, *, prog_name: str = 'pai') -> in model_arg_set = args.model is not None if agent.model is None or model_arg_set: try: - agent.model = infer_model(args.model or 'openai:gpt-4o') + agent.model = infer_model(args.model or default_model) except UserError as e: console.print(f'Error initializing [magenta]{args.model}[/magenta]:\n[red]{e}[/red]') return 1 From 420166d0f1f5dbe4db16298ae6b81b443982166f Mon Sep 17 00:00:00 2001 From: Binal Patel Date: Thu, 17 Jul 2025 01:44:18 -0700 Subject: [PATCH 26/89] fix: allow empty content for TextPart on GoogleModel (#2203) --- pydantic_ai_slim/pydantic_ai/models/google.py | 3 +- ...google_model_empty_assistant_response.yaml | 60 +++++++++++++++++++ tests/models/test_google.py | 15 +++++ 3 files changed, 76 insertions(+), 2 deletions(-) create mode 100644 tests/models/cassettes/test_google/test_google_model_empty_assistant_response.yaml diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index 8b9af7ba2f..9ec1260d4e 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -484,8 +484,7 @@ def _content_model_response(m: ModelResponse) -> ContentDict: function_call = FunctionCallDict(name=item.tool_name, args=item.args_as_dict(), id=item.tool_call_id) parts.append({'function_call': function_call}) elif isinstance(item, TextPart): - if item.content: # pragma: no branch - parts.append({'text': item.content}) + parts.append({'text': item.content}) elif isinstance(item, ThinkingPart): # pragma: no cover # NOTE: We don't send ThinkingPart to the providers yet. If you are unsatisfied with this, # please open an issue. The below code is the code to send thinking to the provider. diff --git a/tests/models/cassettes/test_google/test_google_model_empty_assistant_response.yaml b/tests/models/cassettes/test_google/test_google_model_empty_assistant_response.yaml new file mode 100644 index 0000000000..53ec4054a6 --- /dev/null +++ b/tests/models/cassettes/test_google/test_google_model_empty_assistant_response.yaml @@ -0,0 +1,60 @@ +interactions: +- request: + headers: + content-type: + - application/json + method: post + parsed_body: + contents: + - parts: + - text: Hi + role: user + - parts: + - text: '' + role: model + - parts: + - text: Empty? + role: user + generationConfig: {} + uri: https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent + response: + headers: + alt-svc: + - h3=":443"; ma=2592000,h3-29=":443"; ma=2592000 + content-length: + - '724' + content-type: + - application/json; charset=UTF-8 + server-timing: + - gfet4t7; dur=387 + transfer-encoding: + - chunked + vary: + - Origin + - X-Origin + - Referer + parsed_body: + candidates: + - avgLogprobs: -0.02701732359434429 + content: + parts: + - text: | + Yes, your previous message was empty. Is there anything I can help you with? 
+ role: model + finishReason: STOP + modelVersion: gemini-1.5-flash + responseId: NHt4aPycNfzcnvgP1q6EgQw + usageMetadata: + candidatesTokenCount: 19 + candidatesTokensDetails: + - modality: TEXT + tokenCount: 19 + promptTokenCount: 3 + promptTokensDetails: + - modality: TEXT + tokenCount: 3 + totalTokenCount: 22 + status: + code: 200 + message: OK +version: 1 diff --git a/tests/models/test_google.py b/tests/models/test_google.py index 493b2c9b5e..7e1f372bcc 100644 --- a/tests/models/test_google.py +++ b/tests/models/test_google.py @@ -604,6 +604,21 @@ async def test_google_model_empty_user_prompt(allow_model_requests: None, google assert result.output == snapshot("I'm ready to assist you. Please tell me what you need.\n") +async def test_google_model_empty_assistant_response(allow_model_requests: None, google_provider: GoogleProvider): + m = GoogleModel('gemini-1.5-flash', provider=google_provider) + agent = Agent(m) + + result = await agent.run( + 'Empty?', + message_history=[ + ModelRequest(parts=[UserPromptPart(content='Hi')]), + ModelResponse(parts=[TextPart(content='')]), + ], + ) + + assert result.output == snapshot('Yes, your previous message was empty. Is there anything I can help you with?\n') + + async def test_google_model_thinking_part(allow_model_requests: None, google_provider: GoogleProvider): m = GoogleModel('gemini-2.5-pro-preview-03-25', provider=google_provider) settings = GoogleModelSettings(google_thinking_config={'include_thoughts': True}) From 7bd03fc8d51137b02988eb4dfe26ecb42b40f94b Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Thu, 17 Jul 2025 12:22:05 +0200 Subject: [PATCH 27/89] chore: add experimetal GitHub agent (#2232) --- .github/workflows/agent.yml | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) create mode 100644 .github/workflows/agent.yml diff --git a/.github/workflows/agent.yml b/.github/workflows/agent.yml new file mode 100644 index 0000000000..60706f7b5f --- /dev/null +++ b/.github/workflows/agent.yml @@ -0,0 +1,19 @@ +name: GitHub Agent + +on: + issues: + types: [opened] + +jobs: + agent: + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + - name: View the github context + run: echo "$GITHUB_CONTEXT" + env: + GITHUB_CONTEXT: ${{ toJson(github) }} + - name: Run GitHub Agent + uses: kludex/github-agent@main + env: + ANTHROPIC_API_KEY: ${{ secrets.ANTHROPIC_API_KEY }} From 883e1ea25f7b8a6645e3bfc0e715f3e4b79d4631 Mon Sep 17 00:00:00 2001 From: Herman Semykozov Date: Thu, 17 Jul 2025 15:34:09 +0300 Subject: [PATCH 28/89] update logos, favicon and brand names (#2193) --- .github/workflows/ci.yml | 4 +- README.md | 62 +- clai/README.md | 4 +- clai/pyproject.toml | 2 +- clai/update_readme.py | 2 +- docs/.overrides/.icons/logfire/logo.svg | 8 +- docs/.partials/index-header.html | 21 +- docs/a2a.md | 17 +- docs/agents.md | 25 +- docs/api/models/test.md | 2 +- docs/changelog.md | 11 +- docs/cli.md | 5 +- docs/common-tools.md | 2 +- docs/contributing.md | 20 +- docs/dependencies.md | 6 +- docs/evals.md | 2 +- docs/examples/bank-support.md | 8 +- docs/examples/index.md | 2 +- docs/examples/pydantic-model.md | 4 +- docs/examples/rag.md | 10 +- docs/examples/slack-lead-qualifier.md | 1 - docs/examples/sql-gen.md | 11 +- docs/examples/weather-agent.md | 16 +- docs/favicon.ico | Bin 4286 -> 4286 bytes docs/graph.md | 286 +++++--- docs/help.md | 6 +- docs/img/logo-white.svg | 18 +- docs/img/pydantic-ai-dark.svg | 36 +- docs/img/pydantic-ai-light.svg | 36 +- docs/index.md | 65 +- docs/input.md | 13 +- docs/install.md | 8 +- 
docs/logfire.md | 39 +- docs/mcp/client.md | 10 +- docs/mcp/index.md | 18 +- docs/mcp/run-python.md | 4 +- docs/mcp/server.md | 8 +- docs/message-history.md | 13 +- docs/models/gemini.md | 6 +- docs/models/index.md | 74 +- docs/models/openai.md | 6 +- docs/multi-agent-applications.md | 4 +- docs/output.md | 8 +- docs/testing.md | 14 +- docs/thinking.md | 4 +- docs/tools.md | 24 +- docs/troubleshooting.md | 6 +- examples/README.md | 4 +- examples/pydantic_ai_examples/bank_support.py | 2 +- .../pydantic_ai_examples/pydantic_model.py | 2 +- .../pydantic_ai_examples/roulette_wheel.py | 2 +- examples/pydantic_ai_examples/sql_gen.py | 2 +- .../pydantic_ai_examples/weather_agent.py | 2 +- examples/pyproject.toml | 2 +- mcp-run-python/README.md | 2 +- mkdocs.yml | 36 +- pydantic_ai_slim/README.md | 4 +- pydantic_ai_slim/pydantic_ai/_cli.py | 8 +- .../pydantic_ai/_parts_manager.py | 4 +- pydantic_ai_slim/pydantic_ai/messages.py | 10 +- .../pydantic_ai/models/anthropic.py | 2 +- .../pydantic_ai/profiles/google.py | 2 +- pydantic_ai_slim/pydantic_ai/usage.py | 2 +- pydantic_evals/README.md | 10 +- pydantic_evals/pydantic_evals/generation.py | 4 +- pydantic_graph/README.md | 6 +- .../test_tool_returning_text_resource.yaml | 676 +++++++++--------- tests/example_modules/mcp_server.py | 2 +- tests/mcp_server.py | 4 +- .../test_download_item_no_content_type.yaml | 16 +- tests/test_cli.py | 4 +- tests/test_mcp.py | 6 +- 72 files changed, 880 insertions(+), 885 deletions(-) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index ca6620e03d..db7ea8430f 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -207,7 +207,7 @@ jobs: - run: rm coverage/.coverage.*-py3.9-* # Exclude 3.9 coverage as it gets the wrong line numbers, causing invalid failures. - run: uv run coverage combine coverage - - run: uv run coverage html --show-contexts --title "PydanticAI coverage for ${{ github.sha }}" + - run: uv run coverage html --show-contexts --title "Pydantic AI coverage for ${{ github.sha }}" - name: Store coverage html uses: actions/upload-artifact@v4 @@ -417,7 +417,7 @@ jobs: env: VERSION: ${{ needs.release.outputs.package-version }} TWEET: | - PydanticAI version {version} is out! 🎉 + Pydantic AI version {version} is out! 🎉 https://github.com/pydantic/pydantic-ai/releases/tag/v{version} TWITTER_CONSUMER_KEY: ${{ secrets.TWITTER_CONSUMER_KEY }} diff --git a/README.md b/README.md index a7430b7bd2..8514f027e8 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ - PydanticAI + Pydantic AI @@ -24,47 +24,47 @@ --- -PydanticAI is a Python agent framework designed to make it less painful to build production grade applications with Generative AI. +Pydantic AI is a Python agent framework designed to make it less painful to build production grade applications with Generative AI. -FastAPI revolutionized web development by offering an innovative and ergonomic design, built on the foundation of [Pydantic](https://docs.pydantic.dev). +FastAPI revolutionized web development by offering an innovative and ergonomic design, built on the foundation of [Pydantic Validation](https://docs.pydantic.dev). -Similarly, virtually every agent framework and LLM library in Python uses Pydantic, yet when we began to use LLMs in [Pydantic Logfire](https://pydantic.dev/logfire), we couldn't find anything that gave us the same feeling. 
+Similarly, virtually every agent framework and LLM library in Python uses Pydantic Validation, yet when we began to use LLMs in [Pydantic Logfire](https://pydantic.dev/logfire), we couldn't find anything that gave us the same feeling. -We built PydanticAI with one simple aim: to bring that FastAPI feeling to GenAI app development. +We built Pydantic AI with one simple aim: to bring that FastAPI feeling to GenAI app development. -## Why use PydanticAI +## Why use Pydantic AI -* __Built by the Pydantic Team__ -Built by the team behind [Pydantic](https://docs.pydantic.dev/latest/) (the validation layer of the OpenAI SDK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more). +- **Built by the Pydantic Team** + Built by the team behind [Pydantic Validation](https://docs.pydantic.dev/latest/) (the validation layer of the OpenAI SDK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more). -* __Model-agnostic__ -Supports OpenAI, Anthropic, Gemini, Deepseek, Ollama, Groq, Cohere, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/). +- **Model-agnostic** + Supports OpenAI, Anthropic, Gemini, Deepseek, Ollama, Groq, Cohere, and Mistral, and there is a simple interface to implement support for [other models](https://ai.pydantic.dev/models/). -* __Pydantic Logfire Integration__ -Seamlessly [integrates](https://ai.pydantic.dev/logfire/) with [Pydantic Logfire](https://pydantic.dev/logfire) for real-time debugging, performance monitoring, and behavior tracking of your LLM-powered applications. +- **Pydantic Logfire Integration** + Seamlessly [integrates](https://ai.pydantic.dev/logfire/) with [Pydantic Logfire](https://pydantic.dev/logfire) for real-time debugging, performance monitoring, and behavior tracking of your LLM-powered applications. -* __Type-safe__ -Designed to make [type checking](https://ai.pydantic.dev/agents/#static-type-checking) as powerful and informative as possible for you. +- **Type-safe** + Designed to make [type checking](https://ai.pydantic.dev/agents/#static-type-checking) as powerful and informative as possible for you. -* __Python-centric Design__ -Leverages Python's familiar control flow and agent composition to build your AI-driven projects, making it easy to apply standard Python best practices you'd use in any other (non-AI) project. +- **Python-centric Design** + Leverages Python's familiar control flow and agent composition to build your AI-driven projects, making it easy to apply standard Python best practices you'd use in any other (non-AI) project. -* __Structured Responses__ -Harnesses the power of [Pydantic](https://docs.pydantic.dev/latest/) to [validate and structure](https://ai.pydantic.dev/output/#structured-output) model outputs, ensuring responses are consistent across runs. +- **Structured Responses** + Harnesses the power of [Pydantic Validation](https://docs.pydantic.dev/latest/) to [validate and structure](https://ai.pydantic.dev/output/#structured-output) model outputs, ensuring responses are consistent across runs. -* __Dependency Injection System__ -Offers an optional [dependency injection](https://ai.pydantic.dev/dependencies/) system to provide data and services to your agent's [system prompts](https://ai.pydantic.dev/agents/#system-prompts), [tools](https://ai.pydantic.dev/tools/) and [output validators](https://ai.pydantic.dev/output/#output-validator-functions). 
-This is useful for testing and eval-driven iterative development. +- **Dependency Injection System** + Offers an optional [dependency injection](https://ai.pydantic.dev/dependencies/) system to provide data and services to your agent's [system prompts](https://ai.pydantic.dev/agents/#system-prompts), [tools](https://ai.pydantic.dev/tools/) and [output validators](https://ai.pydantic.dev/output/#output-validator-functions). + This is useful for testing and eval-driven iterative development. -* __Streamed Responses__ -Provides the ability to [stream](https://ai.pydantic.dev/output/#streamed-results) LLM outputs continuously, with immediate validation, ensuring rapid and accurate outputs. +- **Streamed Responses** + Provides the ability to [stream](https://ai.pydantic.dev/output/#streamed-results) LLM outputs continuously, with immediate validation, ensuring rapid and accurate outputs. -* __Graph Support__ -[Pydantic Graph](https://ai.pydantic.dev/graph) provides a powerful way to define graphs using typing hints, this is useful in complex applications where standard control flow can degrade to spaghetti code. +- **Graph Support** + [Pydantic Graph](https://ai.pydantic.dev/graph) provides a powerful way to define graphs using typing hints, this is useful in complex applications where standard control flow can degrade to spaghetti code. ## Hello World Example -Here's a minimal example of PydanticAI: +Here's a minimal example of Pydantic AI: ```python from pydantic_ai import Agent @@ -78,7 +78,7 @@ agent = Agent( ) # Run the agent synchronously, conducting a conversation with the LLM. -# Here the exchange should be very short: PydanticAI will send the system prompt and the user query to the LLM, +# Here the exchange should be very short: Pydantic AI will send the system prompt and the user query to the LLM, # the model will return a text response. See below for a more complex run. result = agent.run_sync('Where does "hello world" come from?') print(result.output) @@ -93,7 +93,7 @@ Not very interesting yet, but we can easily add "tools", dynamic system prompts, ## Tools & Dependency Injection Example -Here is a concise example using PydanticAI to build a support agent for a bank: +Here is a concise example using Pydantic AI to build a support agent for a bank: **(Better documented example [in the docs](https://ai.pydantic.dev/#tools-dependency-injection-example))** @@ -187,8 +187,8 @@ async def main(): ## Next Steps -To try PydanticAI yourself, follow the instructions [in the examples](https://ai.pydantic.dev/examples/). +To try Pydantic AI yourself, follow the instructions [in the examples](https://ai.pydantic.dev/examples/). -Read the [docs](https://ai.pydantic.dev/agents/) to learn more about building applications with PydanticAI. +Read the [docs](https://ai.pydantic.dev/agents/) to learn more about building applications with Pydantic AI. -Read the [API Reference](https://ai.pydantic.dev/api/agent/) to understand PydanticAI's interface. +Read the [API Reference](https://ai.pydantic.dev/api/agent/) to understand Pydantic AI's interface. diff --git a/clai/README.md b/clai/README.md index 6b7de53790..635e381fc0 100644 --- a/clai/README.md +++ b/clai/README.md @@ -8,7 +8,7 @@ (pronounced "clay") -Command line interface to chat to LLMs, part of the [PydanticAI project](https://github.com/pydantic/pydantic-ai). +Command line interface to chat to LLMs, part of the [Pydantic AI project](https://github.com/pydantic/pydantic-ai). 
## Usage @@ -55,7 +55,7 @@ Either way, running `clai` will start an interactive session where you can chat ``` usage: clai [-h] [-m [MODEL]] [-a AGENT] [-l] [-t [CODE_THEME]] [--no-stream] [--version] [prompt] -PydanticAI CLI v... +Pydantic AI CLI v... Special prompts: * `/exit` - exit the interactive mode (ctrl-c and ctrl-d also work) diff --git a/clai/pyproject.toml b/clai/pyproject.toml index af578c8552..411e616f06 100644 --- a/clai/pyproject.toml +++ b/clai/pyproject.toml @@ -13,7 +13,7 @@ bump = true [project] name = "clai" dynamic = ["version", "dependencies"] -description = "PydanticAI CLI: command line interface to chat to LLMs" +description = "Pydantic AI CLI: command line interface to chat to LLMs" authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, diff --git a/clai/update_readme.py b/clai/update_readme.py index 1d88d51c4a..859ef69f3f 100644 --- a/clai/update_readme.py +++ b/clai/update_readme.py @@ -17,7 +17,7 @@ def test_cli_help(capfd: pytest.CaptureFixture[str]): help_output = capfd.readouterr().out.strip() # TODO change when we reach v1 - help_output = re.sub(r'(PydanticAI CLI v).+', r'\1...', help_output) + help_output = re.sub(r'(Pydantic AI CLI v).+', r'\1...', help_output) this_dir = Path(__file__).parent readme = this_dir / 'README.md' diff --git a/docs/.overrides/.icons/logfire/logo.svg b/docs/.overrides/.icons/logfire/logo.svg index 60ce5aa4b4..a88cecadd9 100644 --- a/docs/.overrides/.icons/logfire/logo.svg +++ b/docs/.overrides/.icons/logfire/logo.svg @@ -1,4 +1,6 @@ - - - + + + + + diff --git a/docs/.partials/index-header.html b/docs/.partials/index-header.html index cdb95a9378..e6ad634be1 100644 --- a/docs/.partials/index-header.html +++ b/docs/.partials/index-header.html @@ -1,18 +1,25 @@
[HTML markup stripped from this hunk of docs/.partials/index-header.html; visible text changes only:]
-      PydanticAI
+      Pydantic AI
-      PydanticAI
+      Pydantic AI
       Agent Framework / shim to use Pydantic with LLMs
-      CI
+      CI
-      Coverage
+      Coverage
       PyPI
@@ -24,11 +31,11 @@
       license
-      Join Slack
+      Join Slack
-      PydanticAI is a Python agent framework designed to make it less painful to
-      build production grade applications with Generative AI.
+      Pydantic AI is a Python agent framework designed to make it less painful to build production grade
+      applications with Generative AI.

diff --git a/docs/a2a.md b/docs/a2a.md index df5cd6efb1..b465ed860d 100644 --- a/docs/a2a.md +++ b/docs/a2a.md @@ -5,7 +5,7 @@ communication and interoperability between AI agents, regardless of the framewor At Pydantic, we built the [FastA2A](#fasta2a) library to make it easier to implement the A2A protocol in Python. -We also built a convenience method that expose PydanticAI agents as A2A servers - let's have a quick look at how to use it: +We also built a convenience method that expose Pydantic AI agents as A2A servers - let's have a quick look at how to use it: ```py {title="agent_to_a2a.py" hl_lines="4"} from pydantic_ai import Agent @@ -18,12 +18,12 @@ _You can run the example with `uvicorn agent_to_a2a:app --host 0.0.0.0 --port 80 This will expose the agent as an A2A server, and you can start sending requests to it. -See more about [exposing PydanticAI agents as A2A servers](#pydanticai-agent-to-a2a-server). +See more about [exposing Pydantic AI agents as A2A servers](#pydantic-ai-agent-to-a2a-server). ## FastA2A **FastA2A** is an agentic framework agnostic implementation of the A2A protocol in Python. -The library is designed to be used with any agentic framework, and is **not exclusive to PydanticAI**. +The library is designed to be used with any agentic framework, and is **not exclusive to Pydantic AI**. ### Design @@ -75,8 +75,7 @@ The [`Storage`][fasta2a.Storage] component serves two purposes: This design allows for agents to store rich internal state (e.g., tool calls, reasoning traces) as well as store task-specific A2A-formatted messages and artifacts. -For example, a PydanticAI agent might store its complete internal message format (including tool calls and responses) in the context storage, while storing only the A2A-compliant messages in the task history. - +For example, a Pydantic AI agent might store its complete internal message format (including tool calls and responses) in the context storage, while storing only the A2A-compliant messages in the task history. ### Installation @@ -92,15 +91,15 @@ The only dependencies are: - [pydantic](https://pydantic.dev): to validate the request/response messages - [opentelemetry-api](https://opentelemetry-python.readthedocs.io/en/latest): to provide tracing capabilities -You can install PydanticAI with the `a2a` extra to include **FastA2A**: +You can install Pydantic AI with the `a2a` extra to include **FastA2A**: ```bash pip/uv-add 'pydantic-ai-slim[a2a]' ``` -### PydanticAI Agent to A2A Server +### Pydantic AI Agent to A2A Server -To expose a PydanticAI agent as an A2A server, you can use the `to_a2a` method: +To expose a Pydantic AI agent as an A2A server, you can use the `to_a2a` method: ```python {title="agent_to_a2a.py"} from pydantic_ai import Agent @@ -117,7 +116,7 @@ uvicorn agent_to_a2a:app --host 0.0.0.0 --port 8000 Since the goal of `to_a2a` is to be a convenience method, it accepts the same arguments as the [`FastA2A`][fasta2a.FastA2A] constructor. -When using `to_a2a()`, PydanticAI automatically: +When using `to_a2a()`, Pydantic AI automatically: - Stores the complete conversation history (including tool calls and responses) in the context storage - Ensures that subsequent messages with the same `context_id` have access to the full conversation history diff --git a/docs/agents.md b/docs/agents.md index 5b332e0b27..91e2602e3e 100644 --- a/docs/agents.md +++ b/docs/agents.md @@ -1,6 +1,6 @@ ## Introduction -Agents are PydanticAI's primary interface for interacting with LLMs. 
+Agents are Pydantic AI's primary interface for interacting with LLMs. In some use cases a single Agent will control an entire application or component, but multiple agents can also interact to embody more complex workflows. @@ -8,7 +8,7 @@ but multiple agents can also interact to embody more complex workflows. The [`Agent`][pydantic_ai.Agent] class has full API documentation, but conceptually you can think of an agent as a container for: | **Component** | **Description** | -|-----------------------------------------------|-----------------------------------------------------------------------------------------------------------| +| --------------------------------------------- | --------------------------------------------------------------------------------------------------------- | | [System prompt(s)](#system-prompts) | A set of instructions for the LLM written by the developer. | | [Function tool(s)](tools.md) | Functions that the LLM may call to get information while generating a response. | | [Structured output type](output.md) | The structured datatype the LLM must return at the end of a run, if specified. | @@ -16,7 +16,7 @@ The [`Agent`][pydantic_ai.Agent] class has full API documentation, but conceptua | [LLM model](api/models/base.md) | Optional default LLM model associated with the agent. Can also be specified when running the agent. | | [Model Settings](#additional-configuration) | Optional default model settings to help fine tune requests. Can also be specified when running the agent. | -In typing terms, agents are generic in their dependency and output types, e.g., an agent which required dependencies of type `#!python Foobar` and produced outputs of type `#!python list[str]` would have type `Agent[Foobar, list[str]]`. In practice, you shouldn't need to care about this, it should just mean your IDE can tell you when you have the right type, and if you choose to use [static type checking](#static-type-checking) it should work well with PydanticAI. +In typing terms, agents are generic in their dependency and output types, e.g., an agent which required dependencies of type `#!python Foobar` and produced outputs of type `#!python list[str]` would have type `Agent[Foobar, list[str]]`. In practice, you shouldn't need to care about this, it should just mean your IDE can tell you when you have the right type, and if you choose to use [static type checking](#static-type-checking) it should work well with Pydantic AI. Here's a toy example of an agent that simulates a roulette wheel: @@ -56,7 +56,6 @@ print(result.output) 3. In reality, you might want to use a random number here e.g. `random.randint(0, 36)`. 4. `result.output` will be a boolean indicating if the square is a winner. Pydantic performs the output validation, and it'll be typed as a `bool` since its type is derived from the `output_type` generic parameter of the agent. - !!! tip "Agents are designed for reuse, like FastAPI Apps" Agents are intended to be instantiated once (frequently as module globals) and reused throughout your application, similar to a small [FastAPI][fastapi.FastAPI] app or an [APIRouter][fastapi.APIRouter]. @@ -90,16 +89,16 @@ async def main(): print(await response.get_output()) #> London ``` + _(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ You can also pass messages from previous runs to continue a conversation or provide context, as described in [Messages and Chat History](message-history.md). 
- ### Iterating Over an Agent's Graph -Under the hood, each `Agent` in PydanticAI uses **pydantic-graph** to manage its execution flow. **pydantic-graph** is a generic, type-centric library for building and running finite state machines in Python. It doesn't actually depend on PydanticAI — you can use it standalone for workflows that have nothing to do with GenAI — but PydanticAI makes use of it to orchestrate the handling of model requests and model responses in an agent's run. +Under the hood, each `Agent` in Pydantic AI uses **pydantic-graph** to manage its execution flow. **pydantic-graph** is a generic, type-centric library for building and running finite state machines in Python. It doesn't actually depend on Pydantic AI — you can use it standalone for workflows that have nothing to do with GenAI — but Pydantic AI makes use of it to orchestrate the handling of model requests and model responses in an agent's run. -In many scenarios, you don't need to worry about pydantic-graph at all; calling `agent.run(...)` simply traverses the underlying graph from start to finish. However, if you need deeper insight or control — for example to capture each tool invocation, or to inject your own logic at specific stages — PydanticAI exposes the lower-level iteration process via [`Agent.iter`][pydantic_ai.Agent.iter]. This method returns an [`AgentRun`][pydantic_ai.agent.AgentRun], which you can async-iterate over, or manually drive node-by-node via the [`next`][pydantic_ai.agent.AgentRun.next] method. Once the agent's graph returns an [`End`][pydantic_graph.nodes.End], you have the final result along with a detailed history of all steps. +In many scenarios, you don't need to worry about pydantic-graph at all; calling `agent.run(...)` simply traverses the underlying graph from start to finish. However, if you need deeper insight or control — for example to capture each tool invocation, or to inject your own logic at specific stages — Pydantic AI exposes the lower-level iteration process via [`Agent.iter`][pydantic_ai.Agent.iter]. This method returns an [`AgentRun`][pydantic_ai.agent.AgentRun], which you can async-iterate over, or manually drive node-by-node via the [`next`][pydantic_ai.agent.AgentRun.next] method. Once the agent's graph returns an [`End`][pydantic_graph.nodes.End], you have the final result along with a detailed history of all steps. #### `async for` iteration @@ -222,7 +221,7 @@ async def main(): 1. We start by grabbing the first node that will be run in the agent's graph. 2. The agent run is finished once an `End` node has been produced; instances of `End` cannot be passed to `next`. -3. When you call `await agent_run.next(node)`, it executes that node in the agent's graph, updates the run's history, and returns the *next* node to run. +3. When you call `await agent_run.next(node)`, it executes that node in the agent's graph, updates the run's history, and returns the _next_ node to run. 4. You could also inspect or mutate the new `node` here as needed. #### Accessing usage and the final output @@ -381,7 +380,7 @@ if __name__ == '__main__': #### Usage Limits -PydanticAI offers a [`UsageLimits`][pydantic_ai.usage.UsageLimits] structure to help you limit your +Pydantic AI offers a [`UsageLimits`][pydantic_ai.usage.UsageLimits] structure to help you limit your usage (tokens and/or requests) on model runs. You can apply these settings by passing the `usage_limits` argument to the `run{_sync,_stream}` functions. 
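As a rough sketch of the idea (assuming the `request_limit` and `response_tokens_limit` fields shown here; the model and prompt are illustrative):

```python
from pydantic_ai import Agent
from pydantic_ai.usage import UsageLimits

agent = Agent('anthropic:claude-3-5-sonnet-latest')

result = agent.run_sync(
    'What is the capital of Italy? Answer with just the city.',
    # Cap both the number of model requests and the size of the response.
    usage_limits=UsageLimits(request_limit=5, response_tokens_limit=20),
)
print(result.output)
#> Rome
```

If a limit is exceeded, the run raises an exception such as `UsageLimitExceeded`, as in the example below.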
@@ -462,7 +461,7 @@ except UsageLimitExceeded as e: #### Model (Run) Settings -PydanticAI offers a [`settings.ModelSettings`][pydantic_ai.settings.ModelSettings] structure to help you fine tune your requests. +Pydantic AI offers a [`settings.ModelSettings`][pydantic_ai.settings.ModelSettings] structure to help you fine tune your requests. This structure allows you to configure common parameters that influence the model's behavior, such as `temperature`, `max_tokens`, `timeout`, and more. @@ -573,12 +572,12 @@ _(This example is complete, it can be run "as is")_ ## Type safe by design {#static-type-checking} -PydanticAI is designed to work well with static type checkers, like mypy and pyright. +Pydantic AI is designed to work well with static type checkers, like mypy and pyright. !!! tip "Typing is (somewhat) optional" - PydanticAI is designed to make type checking as useful as possible for you if you choose to use it, but you don't have to use types everywhere all the time. + Pydantic AI is designed to make type checking as useful as possible for you if you choose to use it, but you don't have to use types everywhere all the time. - That said, because PydanticAI uses Pydantic, and Pydantic uses type hints as the definition for schema and validation, some types (specifically type hints on parameters to tools, and the `output_type` arguments to [`Agent`][pydantic_ai.Agent]) are used at runtime. + That said, because Pydantic AI uses Pydantic, and Pydantic uses type hints as the definition for schema and validation, some types (specifically type hints on parameters to tools, and the `output_type` arguments to [`Agent`][pydantic_ai.Agent]) are used at runtime. We (the library developers) have messed up if type hints are confusing you more than helping you, if you find this, please create an [issue](https://github.com/pydantic/pydantic-ai/issues) explaining what's annoying you! diff --git a/docs/api/models/test.md b/docs/api/models/test.md index cdbdcb33ec..120acc4c21 100644 --- a/docs/api/models/test.md +++ b/docs/api/models/test.md @@ -1,6 +1,6 @@ # `pydantic_ai.models.test` -Utility model for quickly testing apps built with PydanticAI. +Utility model for quickly testing apps built with Pydantic AI. Here's a minimal example: diff --git a/docs/changelog.md b/docs/changelog.md index 421eac77be..571e4c8fd8 100644 --- a/docs/changelog.md +++ b/docs/changelog.md @@ -1,6 +1,6 @@ # Upgrade Guide -PydanticAI is still pre-version 1, so breaking changes will occur, however: +Pydantic AI is still pre-version 1, so breaking changes will occur, however: - We try to minimize them as much as possible. - We use minor version bumps to signify breaking changes. @@ -10,7 +10,7 @@ PydanticAI is still pre-version 1, so breaking changes will occur, however: ## Breaking Changes !!! note - Here's a filtered list of the breaking changes for each version to help you upgrade PydanticAI. + Here's a filtered list of the breaking changes for each version to help you upgrade Pydantic AI. ### v0.4.0 (2025-07-08) @@ -23,7 +23,7 @@ See [#1507](https://github.com/pydantic/pydantic-ai/pull/1507) - The `ToolDefini See [#1142](https://github.com/pydantic/pydantic-ai/pull/1142) — Adds support for thinking parts. We now convert the thinking blocks (`"...""`) in provider specific text parts to -PydanticAI `ThinkingPart`s. Also, as part of this release, we made the choice to not send back the +Pydantic AI `ThinkingPart`s. 
Also, as part of this release, we made the choice to not send back the `ThinkingPart`s to the provider - the idea is to save costs on behalf of the user. In the future, we intend to add a setting to customize this behavior. @@ -31,9 +31,8 @@ intend to add a setting to customize this behavior. See [#1647](https://github.com/pydantic/pydantic-ai/pull/1647) — usage makes sense as part of `ModelResponse`, and could be really useful in "messages" (really a sequence of requests and response). In this PR: -* Adds `usage` to `ModelResponse` (field has a default factory of `Usage()` so it'll work to load data that doesn't have usage) -* changes the return type of `Model.request` to just `ModelResponse` instead of `tuple[ModelResponse, Usage]` - +- Adds `usage` to `ModelResponse` (field has a default factory of `Usage()` so it'll work to load data that doesn't have usage) +- changes the return type of `Model.request` to just `ModelResponse` instead of `tuple[ModelResponse, Usage]` ### v0.1.0 (2025-04-15) diff --git a/docs/cli.md b/docs/cli.md index 01b49b9bad..9f5b94d961 100644 --- a/docs/cli.md +++ b/docs/cli.md @@ -1,9 +1,9 @@ # Command Line Interface (CLI) -**PydanticAI** comes with a CLI, `clai` (pronounced "clay") which you can use to interact with various LLMs from the command line. +**Pydantic AI** comes with a CLI, `clai` (pronounced "clay") which you can use to interact with various LLMs from the command line. It provides a convenient way to chat with language models and quickly get answers right in the terminal. -We originally developed this CLI for our own use, but found ourselves using it so frequently that we decided to share it as part of the PydanticAI package. +We originally developed this CLI for our own use, but found ourselves using it so frequently that we decided to share it as part of the Pydantic AI package. We plan to continue adding new features, such as interaction with MCP servers, access to tools, and more. @@ -86,7 +86,6 @@ The format must be `module:variable` where: - `module` is the importable Python module path - `variable` is the name of the Agent instance in that module - Additionally, you can directly launch CLI mode from an `Agent` instance using `Agent.to_cli_sync()`: ```python {title="agent_to_cli_sync.py" test="skip" hl_lines=4} diff --git a/docs/common-tools.md b/docs/common-tools.md index 3cb4196184..08964e458d 100644 --- a/docs/common-tools.md +++ b/docs/common-tools.md @@ -1,6 +1,6 @@ # Common Tools -PydanticAI ships with native tools that can be used to enhance your agent's capabilities. +Pydantic AI ships with native tools that can be used to enhance your agent's capabilities. ## DuckDuckGo Search Tool diff --git a/docs/contributing.md b/docs/contributing.md index 71b55ef4b4..b376712895 100644 --- a/docs/contributing.md +++ b/docs/contributing.md @@ -1,4 +1,4 @@ -We'd love you to contribute to PydanticAI! +We'd love you to contribute to Pydantic AI! 
## Installation and Setup @@ -11,9 +11,9 @@ cd pydantic-ai Install `uv` (version 0.4.30 or later), `pre-commit` and `deno`: -* [`uv` install docs](https://docs.astral.sh/uv/getting-started/installation/) -* [`pre-commit` install docs](https://pre-commit.com/#install) -* [`deno` install docs](https://docs.deno.com/runtime/getting_started/installation/) +- [`uv` install docs](https://docs.astral.sh/uv/getting-started/installation/) +- [`pre-commit` install docs](https://pre-commit.com/#install) +- [`deno` install docs](https://docs.deno.com/runtime/getting_started/installation/) To install `pre-commit` you can run the following command: @@ -59,13 +59,13 @@ To run the documentation page locally, run: uv run mkdocs serve ``` -## Rules for adding new models to PydanticAI {#new-model-rules} +## Rules for adding new models to Pydantic AI {#new-model-rules} -To avoid an excessive workload for the maintainers of PydanticAI, we can't accept all model contributions, so we're setting the following rules for when we'll accept new models and when we won't. This should hopefully reduce the chances of disappointment and wasted work. +To avoid an excessive workload for the maintainers of Pydantic AI, we can't accept all model contributions, so we're setting the following rules for when we'll accept new models and when we won't. This should hopefully reduce the chances of disappointment and wasted work. -* To add a new model with an extra dependency, that dependency needs > 500k monthly downloads from PyPI consistently over 3 months or more -* To add a new model which uses another models logic internally and has no extra dependencies, that model's GitHub org needs > 20k stars in total -* For any other model that's just a custom URL and API key, we're happy to add a one-paragraph description with a link and instructions on the URL to use -* For any other model that requires more logic, we recommend you release your own Python package `pydantic-ai-xxx`, which depends on [`pydantic-ai-slim`](install.md#slim-install) and implements a model that inherits from our [`Model`][pydantic_ai.models.Model] ABC +- To add a new model with an extra dependency, that dependency needs > 500k monthly downloads from PyPI consistently over 3 months or more +- To add a new model which uses another models logic internally and has no extra dependencies, that model's GitHub org needs > 20k stars in total +- For any other model that's just a custom URL and API key, we're happy to add a one-paragraph description with a link and instructions on the URL to use +- For any other model that requires more logic, we recommend you release your own Python package `pydantic-ai-xxx`, which depends on [`pydantic-ai-slim`](install.md#slim-install) and implements a model that inherits from our [`Model`][pydantic_ai.models.Model] ABC If you're unsure about adding a model, please [create an issue](https://github.com/pydantic/pydantic-ai/issues). diff --git a/docs/dependencies.md b/docs/dependencies.md index b3cd44ae86..c5e3874f12 100644 --- a/docs/dependencies.md +++ b/docs/dependencies.md @@ -1,8 +1,8 @@ # Dependencies -PydanticAI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [tools](tools.md) and [output validators](output.md#output-validator-functions). +Pydantic AI uses a dependency injection system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [tools](tools.md) and [output validators](output.md#output-validator-functions). 
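A minimal sketch of the idea (the `MyDeps` dataclass, the prompt text, and the model are illustrative):

```python
from dataclasses import dataclass

from pydantic_ai import Agent, RunContext


@dataclass
class MyDeps:
    customer_name: str


agent = Agent('openai:gpt-4o', deps_type=MyDeps)


@agent.system_prompt
async def add_customer_name(ctx: RunContext[MyDeps]) -> str:
    # Dependencies are available as ctx.deps in system prompts, tools and output validators.
    return f"The customer's name is {ctx.deps.customer_name!r}."


result = agent.run_sync('Write a one-line greeting.', deps=MyDeps(customer_name='Alice'))
print(result.output)
```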
-Matching PydanticAI's design philosophy, our dependency system tries to use existing best practice in Python development rather than inventing esoteric "magic", this should make dependencies type-safe, understandable easier to test and ultimately easier to deploy in production. +Matching Pydantic AI's design philosophy, our dependency system tries to use existing best practice in Python development rather than inventing esoteric "magic", this should make dependencies type-safe, understandable easier to test and ultimately easier to deploy in production. ## Defining Dependencies @@ -299,7 +299,7 @@ async def test_application_code(): ## Examples -The following examples demonstrate how to use dependencies in PydanticAI: +The following examples demonstrate how to use dependencies in Pydantic AI: - [Weather Agent](examples/weather-agent.md) - [SQL Generation](examples/sql-gen.md) diff --git a/docs/evals.md b/docs/evals.md index a8348f20d2..e1820b88c1 100644 --- a/docs/evals.md +++ b/docs/evals.md @@ -732,7 +732,7 @@ This can be especially helpful when attempting to write evaluators that make use This allows you to write evaluations that depend on information about which code paths were executed during the call to the task function without needing to manually instrument the code being evaluated, as long as the code being evaluated -is already adequately instrumented with OpenTelemetry. In the case of PydanticAI agents, for example, this can be used +is already adequately instrumented with OpenTelemetry. In the case of Pydantic AI agents, for example, this can be used to ensure specific tools are (or are not) called during the execution of specific cases. Using OpenTelemetry in this way also means that all data used to evaluate the task executions will be accessible in diff --git a/docs/examples/bank-support.md b/docs/examples/bank-support.md index 5d378444cd..8956983bb6 100644 --- a/docs/examples/bank-support.md +++ b/docs/examples/bank-support.md @@ -1,10 +1,10 @@ -Small but complete example of using PydanticAI to build a support agent for a bank. +Small but complete example of using Pydantic AI to build a support agent for a bank. Demonstrates: -* [dynamic system prompt](../agents.md#system-prompts) -* [structured `output_type`](../output.md#structured-output) -* [tools](../tools.md) +- [dynamic system prompt](../agents.md#system-prompts) +- [structured `output_type`](../output.md#structured-output) +- [tools](../tools.md) ## Running the Example diff --git a/docs/examples/index.md b/docs/examples/index.md index ab14bfe75b..b7293cfa23 100644 --- a/docs/examples/index.md +++ b/docs/examples/index.md @@ -1,6 +1,6 @@ # Examples -Examples of how to use PydanticAI and what it can do. +Examples of how to use Pydantic AI and what it can do. ## Usage diff --git a/docs/examples/pydantic-model.md b/docs/examples/pydantic-model.md index 6060611e20..da61595d7b 100644 --- a/docs/examples/pydantic-model.md +++ b/docs/examples/pydantic-model.md @@ -1,10 +1,10 @@ # Pydantic Model -Simple example of using PydanticAI to construct a Pydantic model from a text input. +Simple example of using Pydantic AI to construct a Pydantic model from a text input. Demonstrates: -* [structured `output_type`](../output.md#structured-output) +- [structured `output_type`](../output.md#structured-output) ## Running the Example diff --git a/docs/examples/rag.md b/docs/examples/rag.md index 7591407638..2d4b3300dd 100644 --- a/docs/examples/rag.md +++ b/docs/examples/rag.md @@ -4,12 +4,12 @@ RAG search example. 
This demo allows you to ask question of the [logfire](https: Demonstrates: -* [tools](../tools.md) -* [agent dependencies](../dependencies.md) -* RAG search +- [tools](../tools.md) +- [agent dependencies](../dependencies.md) +- RAG search This is done by creating a database containing each section of the markdown documentation, then registering -the search tool with the PydanticAI agent. +the search tool with the Pydantic AI agent. Logic for extracting sections from markdown files and a JSON file with that data is available in [this gist](https://gist.github.com/samuelcolvin/4b5bb9bb163b1122ff17e29e48c10992). @@ -34,7 +34,7 @@ With that running and [dependencies installed and environment variables set](./i python/uv-run -m pydantic_ai_examples.rag build ``` -(Note building the database doesn't use PydanticAI right now, instead it uses the OpenAI SDK directly.) +(Note building the database doesn't use Pydantic AI right now, instead it uses the OpenAI SDK directly.) You can then ask the agent a question with: diff --git a/docs/examples/slack-lead-qualifier.md b/docs/examples/slack-lead-qualifier.md index d010ab75a8..b119e6e4bd 100644 --- a/docs/examples/slack-lead-qualifier.md +++ b/docs/examples/slack-lead-qualifier.md @@ -107,7 +107,6 @@ Now when someone new (possibly you with a throwaway email) joins the Slack works ``` !!! note "Deploying to production" - If you'd like to deploy this app into your Modal workspace in a persistent fashion, you can use this command: ```bash diff --git a/docs/examples/sql-gen.md b/docs/examples/sql-gen.md index 243c249779..e1518e883b 100644 --- a/docs/examples/sql-gen.md +++ b/docs/examples/sql-gen.md @@ -1,13 +1,13 @@ # SQL Generation -Example demonstrating how to use PydanticAI to generate SQL queries based on user input. +Example demonstrating how to use Pydantic AI to generate SQL queries based on user input. Demonstrates: -* [dynamic system prompt](../agents.md#system-prompts) -* [structured `output_type`](../output.md#structured-output) -* [output validation](../output.md#output-validator-functions) -* [agent dependencies](../dependencies.md) +- [dynamic system prompt](../agents.md#system-prompts) +- [structured `output_type`](../output.md#structured-output) +- [output validation](../output.md#output-validator-functions) +- [agent dependencies](../dependencies.md) ## Running the Example @@ -16,6 +16,7 @@ The resulting SQL is validated by running it as an `EXPLAIN` query on PostgreSQL ```bash docker run --rm -e POSTGRES_PASSWORD=postgres -p 54320:5432 postgres ``` + _(we run postgres on port `54320` to avoid conflicts with any other postgres instances you may have running)_ With [dependencies installed and environment variables set](./index.md#usage), run: diff --git a/docs/examples/weather-agent.md b/docs/examples/weather-agent.md index c94cce0d69..7961ecf62c 100644 --- a/docs/examples/weather-agent.md +++ b/docs/examples/weather-agent.md @@ -1,11 +1,11 @@ -Example of PydanticAI with multiple tools which the LLM needs to call in turn to answer a question. +Example of Pydantic AI with multiple tools which the LLM needs to call in turn to answer a question. 
Demonstrates: -* [tools](../tools.md) -* [agent dependencies](../dependencies.md) -* [streaming text responses](../output.md#streaming-text) -* Building a [Gradio](https://www.gradio.app/) UI for the agent +- [tools](../tools.md) +- [agent dependencies](../dependencies.md) +- [streaming text responses](../output.md#streaming-text) +- Building a [Gradio](https://www.gradio.app/) UI for the agent In this case the idea is a "weather" agent — the user can ask for the weather in multiple locations, the agent will use the `get_lat_lng` tool to get the latitude and longitude of the locations, then use @@ -15,8 +15,8 @@ the `get_weather` tool to get the weather for those locations. To run this example properly, you might want to add two extra API keys **(Note if either key is missing, the code will fall back to dummy data, so they're not required)**: -* A weather API key from [tomorrow.io](https://www.tomorrow.io/weather-api/) set via `WEATHER_API_KEY` -* A geocoding API key from [geocode.maps.co](https://geocode.maps.co/) set via `GEO_API_KEY` +- A weather API key from [tomorrow.io](https://www.tomorrow.io/weather-api/) set via `WEATHER_API_KEY` +- A geocoding API key from [geocode.maps.co](https://geocode.maps.co/) set via `GEO_API_KEY` With [dependencies installed and environment variables set](./index.md#usage), run: @@ -25,7 +25,6 @@ python/uv-run -m pydantic_ai_examples.weather_agent ``` ## Example Code - ```snippet {path="/examples/pydantic_ai_examples/weather_agent.py"}``` ## Running the UI @@ -44,5 +43,4 @@ python/uv-run -m pydantic_ai_examples.weather_agent_gradio ``` ## UI Code - ```snippet {path="/examples/pydantic_ai_examples/weather_agent_gradio.py"}``` diff --git a/docs/favicon.ico b/docs/favicon.ico index 74301be3a0610cc999e388d00e207be25ff0235e..76d55dd0f192414f18e4604f4a5240521a606a5c 100644 GIT binary patch literal 4286 zcmeHJO=w+36rPyWND*I@YAdF1?t2Z=TBL=vR3&%?TMC6zy6U1oWa+}vWhp2I0*1y_ zvS>G5RTr*A5haVbcp@TcL4<@t+UBos=G~h-l9!4mHaD4m=iWOv^M-emzWl80eaxFP z=bZVzGw00Ay^J;B=Qxc1*|#l>y}=mUj|LU{2#xgN1KT$*MlDi2|KDf^B$o$#M_i>w z@cL#k*WkQY&1dtQV=<;Y6to+AS221EQf2ax+iU1YQKf>=Z^VW!0b{R=Z$*&gPW%loxBk@Mfc zeZ}Yg9*WoF#rG4xZ1DrqO|?$F2Hhp#{OEJ}Ut(X4bBf~ja>e7|YzlF*;yiGr?Th(~ z5l*SIlAU~&e9fm8ziFpAJ}-5~u`Js7SdXAs(f?`K{)POVhnQybvl$c5me=Dq{RxXO z2^)<#;u?uz=Q1hSk zevE1*C&`9-%t=S`G+wQ*m%rThs1^9-n4>7(7P~YS_0x|0Fg{xfdHEHwHJWx?u7{t_ z7tZo85w}?Ovc->WJQ`!|J>(O|_*pFPJjHp{yzB6hw*YbgIe@cZ77H4thM;bH>@YU4}cMK&l9$Z zktW0cV#uGI)37FG#C5xrE1p*EYKP4a&PICoAl^RYo9>pXd)Uj#3u%0{^tiXawlR`% zl$QZL2IqiYr`A>Vz^cWumag1Zeg)^@L4w1@5Z3=V8C&bqdieZ}*oTraFbHcct~A?B zULFBnw0H2%CmhN@)d8)dIne(T-t)N6=lFOFk;D(t>5Z?oR7bRb^aA{(f5(_=ur55ntW4;5%H$jxz)TTo|pID1^+^{ zUNM5T565d?Ae-sCU9y?)m|OWd@#%K)*|;_KTh)HF{kLN*+vMxGkM7`(86tT*IQ|cq zNN($|SusbMqUrgQ2x&HGJHYL&;LP#_ME~#bxY?-^E#bXzgnc)%f58 z_^>j^JePfFAMyWNAkQ3c+P}({4f&wDVENO0qt==G?PO^@h)wYr=X2-=&kysRmTuQz za@~Ei*HHssgy>UHjI8-=-M=jx_1Q20k0{YU+!fD2_5; z;O)E{TX6>xk7>sFy@vbl#}6Ik9yWuCb@fi~65{3L5kD(BWEY^hNsw{}7vcKB14tzlZ1S`QYjLMjgHv z{YvScq+T0|URC|w7xkH`i&+)y1Nf1T6_7j)%|O_^bUm=4A=nL4P5ySj@%W zIMeg-<E2*(8`*9qc}Jea#+T2@xN@yaQ3TM zEeGxJeOcZ{jnSfUVq zq*@erbZqdRi~gYZ_vHh92Azr6!@Awg=&d}kYW#J%$!YG9hvCR&7dFF@D=zY6&M5eg zI#`oyF5-dpK>oIjI%}w7nk}#mVwqg6_BUmN&kolY6a zed_Hz$MT!#_xs+Ru~ECoKj*B5r(*$~qAZM}rmjV=t80<#I{qb+;moh&dN1i(f)4dj zTl~PE50}R4xH`KoML|SEQn?51ow;=wDIq{sa&z9z7x09 zhdoo}vB*?;QQXxs@Uu^h*$Ru#o4uRa0N0a+G|e~nRL`e+KA`@2ouaw0Bz|^j=A=y@ znqf8LG-Yfu-anDupNKN$2I8%{1XI1KXS3je;m0?qoMf3^RXWRJ)PjCSj zcjQ`!-?;A%aW(Uk?Sr{A?;G6H%unGCd$#}P4BGH_D_??j$@;C{K8~mV=C;z=PtF)- 
ztb#UbH>=c>6G~NuQkAg3{=T5p;-XUL&Z*r{E?WuZ!nsf`WW8|FdM1=hMJ;QT1eK2} z;ob$SGqe!Ob*x*~k;+0xxMEr1>T`Xf+Sw(lU7eyj`n;&FU!ZkFR2Mo#)v`o2loQo( KMLfywp7;kK bool: # (3)! 3. `run_node` is a pure function that doesn't need access to any other process state to run the next node of the graph, except the ID of the run. 4. Call [`graph.iter_from_persistence()`][pydantic_graph.graph.Graph.iter_from_persistence] create a [`GraphRun`][pydantic_graph.graph.GraphRun] object that will run the next node of the graph from the state stored in persistence. This will return either a node or an `End` object. 5. [`graph.run()`][pydantic_graph.graph.Graph.run] will return either a [node][pydantic_graph.nodes.BaseNode] or an [`End`][pydantic_graph.nodes.End] object. -5. Check if the node is an [`End`][pydantic_graph.nodes.End] object, if it is, the graph run is complete. +6. Check if the node is an [`End`][pydantic_graph.nodes.End] object, if it is, the graph run is complete. _(This example is complete, it can be run "as is" with Python 3.10+ — you'll need to add `asyncio.run(main())` to run `main`)_ @@ -648,103 +648,107 @@ In this example, an AI asks the user a question, the user provides an answer, th Instead of running the entire graph in a single process invocation, we run the graph by running the process repeatedly, optionally providing an answer to the question as a command line argument. ??? example "`ai_q_and_a_graph.py` — `question_graph` definition" - ```python {title="ai_q_and_a_graph.py" noqa="I001" py="3.10"} - from __future__ import annotations as _annotations - - from dataclasses import dataclass, field - - from pydantic import BaseModel - from pydantic_graph import ( - BaseNode, - End, - Graph, - GraphRunContext, - ) +```python {title="ai_q_and_a_graph.py" noqa="I001" py="3.10"} +from __future__ import annotations as _annotations - from pydantic_ai import Agent, format_as_xml - from pydantic_ai.messages import ModelMessage +from typing import Annotated +from pydantic_graph import Edge +from dataclasses import dataclass, field +from pydantic import BaseModel +from pydantic_graph import ( + BaseNode, + End, + Graph, + GraphRunContext, +) +from pydantic_ai import Agent, format_as_xml +from pydantic_ai.messages import ModelMessage - ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True) +ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True) - @dataclass - class QuestionState: - question: str | None = None - ask_agent_messages: list[ModelMessage] = field(default_factory=list) - evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) +@dataclass +class QuestionState: + question: str | None = None + ask_agent_messages: list[ModelMessage] = field(default_factory=list) + evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) - @dataclass - class Ask(BaseNode[QuestionState]): - async def run(self, ctx: GraphRunContext[QuestionState]) -> Answer: - result = await ask_agent.run( - 'Ask a simple question with a single correct answer.', - message_history=ctx.state.ask_agent_messages, - ) - ctx.state.ask_agent_messages += result.new_messages() - ctx.state.question = result.output - return Answer(result.output) +@dataclass +class Ask(BaseNode[QuestionState]): + """Generate question using GPT-4o.""" + docstring_notes = True + async def run( + self, ctx: GraphRunContext[QuestionState] + ) -> Annotated[Answer, Edge(label='Ask the question')]: + result = await ask_agent.run( + 'Ask a simple question with a single correct answer.', + 
message_history=ctx.state.ask_agent_messages, + ) + ctx.state.ask_agent_messages += result.new_messages() + ctx.state.question = result.output + return Answer(result.output) - @dataclass - class Answer(BaseNode[QuestionState]): - question: str +@dataclass +class Answer(BaseNode[QuestionState]): + question: str - async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: - answer = input(f'{self.question}: ') - return Evaluate(answer) + async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: + answer = input(f'{self.question}: ') + return Evaluate(answer) - class EvaluationResult(BaseModel, use_attribute_docstrings=True): - correct: bool - """Whether the answer is correct.""" - comment: str - """Comment on the answer, reprimand the user if the answer is wrong.""" +class EvaluationResult(BaseModel, use_attribute_docstrings=True): + correct: bool + """Whether the answer is correct.""" + comment: str + """Comment on the answer, reprimand the user if the answer is wrong.""" - evaluate_agent = Agent( - 'openai:gpt-4o', - output_type=EvaluationResult, - system_prompt='Given a question and answer, evaluate if the answer is correct.', - ) +evaluate_agent = Agent( + 'openai:gpt-4o', + output_type=EvaluationResult, + system_prompt='Given a question and answer, evaluate if the answer is correct.', +) - @dataclass - class Evaluate(BaseNode[QuestionState, None, str]): - answer: str +@dataclass +class Evaluate(BaseNode[QuestionState, None, str]): + answer: str - async def run( - self, - ctx: GraphRunContext[QuestionState], - ) -> End[str] | Reprimand: - assert ctx.state.question is not None - result = await evaluate_agent.run( - format_as_xml({'question': ctx.state.question, 'answer': self.answer}), - message_history=ctx.state.evaluate_agent_messages, - ) - ctx.state.evaluate_agent_messages += result.new_messages() - if result.output.correct: - return End(result.output.comment) - else: - return Reprimand(result.output.comment) + async def run( + self, + ctx: GraphRunContext[QuestionState], + ) -> Annotated[End[str], Edge(label='success')] | Reprimand: + assert ctx.state.question is not None + result = await evaluate_agent.run( + format_as_xml({'question': ctx.state.question, 'answer': self.answer}), + message_history=ctx.state.evaluate_agent_messages, + ) + ctx.state.evaluate_agent_messages += result.new_messages() + if result.output.correct: + return End(result.output.comment) + else: + return Reprimand(result.output.comment) - @dataclass - class Reprimand(BaseNode[QuestionState]): - comment: str +@dataclass +class Reprimand(BaseNode[QuestionState]): + comment: str - async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: - print(f'Comment: {self.comment}') - ctx.state.question = None - return Ask() + async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: + print(f'Comment: {self.comment}') + ctx.state.question = None + return Ask() - question_graph = Graph( - nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState - ) - ``` +question_graph = Graph( + nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState +) +``` - _(This example is complete, it can be run "as is" with Python 3.10+)_ +_(This example is complete, it can be run "as is" with Python 3.10+)_ ```python {title="ai_q_and_a_run.py" noqa="I001" py="3.10" requires="ai_q_and_a_graph.py"} import sys @@ -801,7 +805,7 @@ For a complete example of this graph, see the [question graph example](examples/ ## Dependency Injection -As with PydanticAI, `pydantic-graph` supports dependency 
injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphRunContext.deps`][pydantic_graph.nodes.GraphRunContext.deps] field. +As with Pydantic AI, `pydantic-graph` supports dependency injection via a generic parameter on [`Graph`][pydantic_graph.graph.Graph] and [`BaseNode`][pydantic_graph.nodes.BaseNode], and the [`GraphRunContext.deps`][pydantic_graph.nodes.GraphRunContext.deps] field. As an example of dependency injection, let's modify the `DivisibleBy5` example [above](#graph) to use a [`ProcessPoolExecutor`][concurrent.futures.ProcessPoolExecutor] to run the compute load in a separate process (this is a contrived example, `ProcessPoolExecutor` wouldn't actually improve performance in this example): @@ -881,30 +885,37 @@ Pydantic Graph can generate [mermaid](https://mermaid.js.org/) [`stateDiagram-v2 These diagrams can be generated with: -* [`Graph.mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] to generate the mermaid code for a graph -* [`Graph.mermaid_image`][pydantic_graph.graph.Graph.mermaid_image] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) -* [`Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file +- [`Graph.mermaid_code`][pydantic_graph.graph.Graph.mermaid_code] to generate the mermaid code for a graph +- [`Graph.mermaid_image`][pydantic_graph.graph.Graph.mermaid_image] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) +- [`Graph.mermaid_save`][pydantic_graph.graph.Graph.mermaid_save] to generate an image of the graph using [mermaid.ink](https://mermaid.ink/) and save it to a file Beyond the diagrams shown above, you can also customize mermaid diagrams with the following options: -* [`Edge`][pydantic_graph.nodes.Edge] allows you to apply a label to an edge -* [`BaseNode.docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] and [`BaseNode.get_note`][pydantic_graph.nodes.BaseNode.get_note] allows you to add notes to nodes -* The [`highlighted_nodes`][pydantic_graph.graph.Graph.mermaid_code] parameter allows you to highlight specific node(s) in the diagram +- [`Edge`][pydantic_graph.nodes.Edge] allows you to apply a label to an edge +- [`BaseNode.docstring_notes`][pydantic_graph.nodes.BaseNode.docstring_notes] and [`BaseNode.get_note`][pydantic_graph.nodes.BaseNode.get_note] allows you to add notes to nodes +- The [`highlighted_nodes`][pydantic_graph.graph.Graph.mermaid_code] parameter allows you to highlight specific node(s) in the diagram Putting that together, we can edit the last [`ai_q_and_a_graph.py`](#example-human-in-the-loop) example to: -* add labels to some edges -* add a note to the `Ask` node -* highlight the `Answer` node -* save the diagram as a `PNG` image to file +- add labels to some edges +- add a note to the `Ask` node +- highlight the `Answer` node +- save the diagram as a `PNG` image to file ```python {title="ai_q_and_a_graph_extra.py" test="skip" lint="skip" hl_lines="2 4 10-11 14 26 31"} -... from typing import Annotated from pydantic_graph import BaseNode, End, Graph, GraphRunContext, Edge -... 
+ask_agent = Agent('openai:gpt-4o', output_type=str, instrument=True) + + +@dataclass +class QuestionState: + question: str | None = None + ask_agent_messages: list[ModelMessage] = field(default_factory=list) + evaluate_agent_messages: list[ModelMessage] = field(default_factory=list) + @dataclass class Ask(BaseNode[QuestionState]): @@ -913,23 +924,71 @@ class Ask(BaseNode[QuestionState]): async def run( self, ctx: GraphRunContext[QuestionState] ) -> Annotated[Answer, Edge(label='Ask the question')]: - ... + result = await ask_agent.run( + 'Ask a simple question with a single correct answer.', + message_history=ctx.state.ask_agent_messages, + ) + ctx.state.ask_agent_messages += result.new_messages() + ctx.state.question = result.output + return Answer(result.output) -... @dataclass -class Evaluate(BaseNode[QuestionState]): +class Answer(BaseNode[QuestionState]): + question: str + + async def run(self, ctx: GraphRunContext[QuestionState]) -> Evaluate: + answer = input(f'{self.question}: ') + return Evaluate(answer) + + +class EvaluationResult(BaseModel, use_attribute_docstrings=True): + correct: bool + """Whether the answer is correct.""" + comment: str + """Comment on the answer, reprimand the user if the answer is wrong.""" + + +evaluate_agent = Agent( + 'openai:gpt-4o', + output_type=EvaluationResult, + system_prompt='Given a question and answer, evaluate if the answer is correct.', +) + + +@dataclass +class Evaluate(BaseNode[QuestionState, None, str]): answer: str async def run( - self, - ctx: GraphRunContext[QuestionState], + self, + ctx: GraphRunContext[QuestionState], ) -> Annotated[End[str], Edge(label='success')] | Reprimand: - ... + assert ctx.state.question is not None + result = await evaluate_agent.run( + format_as_xml({'question': ctx.state.question, 'answer': self.answer}), + message_history=ctx.state.evaluate_agent_messages, + ) + ctx.state.evaluate_agent_messages += result.new_messages() + if result.output.correct: + return End(result.output.comment) + else: + return Reprimand(result.output.comment) + + +@dataclass +class Reprimand(BaseNode[QuestionState]): + comment: str -... + async def run(self, ctx: GraphRunContext[QuestionState]) -> Ask: + print(f'Comment: {self.comment}') + ctx.state.question = None + return Ask() -question_graph.mermaid_save('image.png', highlighted_nodes=[Answer]) + +question_graph = Graph( + nodes=(Ask, Answer, Evaluate, Reprimand), state_type=QuestionState +) ``` _(This example is not complete and cannot be run directly)_ @@ -965,6 +1024,7 @@ You can specify the direction of the state diagram using one of the following va - `'BT'`: Bottom to top, the diagram flows vertically from bottom to top. Here is an example of how to do this using 'Left to Right' (LR) instead of the default 'Top to Bottom' (TB): + ```py {title="vending_machine_diagram.py" py="3.10" requires="vending_machine.py"} from vending_machine import InsertCoin, vending_machine_graph diff --git a/docs/help.md b/docs/help.md index 4e6fd3e6d6..ff389ea22e 100644 --- a/docs/help.md +++ b/docs/help.md @@ -1,16 +1,16 @@ # Getting Help -If you need help getting started with PydanticAI or with advanced usage, the following sources may be useful. +If you need help getting started with Pydantic AI or with advanced usage, the following sources may be useful. ## :simple-slack: Slack -Join the `#pydantic-ai` channel in the [Pydantic Slack][slack] to ask questions, get help, and chat about PydanticAI. There's also channels for Pydantic, Logfire, and FastUI. 
+Join the `#pydantic-ai` channel in the [Pydantic Slack][slack] to ask questions, get help, and chat about Pydantic AI. There's also channels for Pydantic, Logfire, and FastUI. If you're on a [Logfire][logfire] Pro plan, you can also get a dedicated private slack collab channel with us. ## :simple-github: GitHub Issues -The [PydanticAI GitHub Issues][github-issues] are a great place to ask questions and give us feedback. +The [Pydantic AI GitHub Issues][github-issues] are a great place to ask questions and give us feedback. [slack]: https://logfire.pydantic.dev/docs/join-slack/ [github-issues]: https://github.com/pydantic/pydantic-ai/issues diff --git a/docs/img/logo-white.svg b/docs/img/logo-white.svg index e38e6723da..8b223f9c3a 100644 --- a/docs/img/logo-white.svg +++ b/docs/img/logo-white.svg @@ -1,17 +1,3 @@ - - - - - - - - - - - - - - - - + + diff --git a/docs/img/pydantic-ai-dark.svg b/docs/img/pydantic-ai-dark.svg index f24d822d94..374f66850b 100644 --- a/docs/img/pydantic-ai-dark.svg +++ b/docs/img/pydantic-ai-dark.svg @@ -1,34 +1,4 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + diff --git a/docs/img/pydantic-ai-light.svg b/docs/img/pydantic-ai-light.svg index e94b9f5068..adf7c5d0bd 100644 --- a/docs/img/pydantic-ai-light.svg +++ b/docs/img/pydantic-ai-light.svg @@ -1,34 +1,4 @@ - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - + + + diff --git a/docs/index.md b/docs/index.md index 3290e7924f..2137deeedc 100644 --- a/docs/index.md +++ b/docs/index.md @@ -2,45 +2,45 @@ --8<-- "docs/.partials/index-header.html" -FastAPI revolutionized web development by offering an innovative and ergonomic design, built on the foundation of [Pydantic](https://docs.pydantic.dev). +FastAPI revolutionized web development by offering an innovative and ergonomic design, built on the foundation of [Pydantic Validation](https://docs.pydantic.dev). -Similarly, virtually every agent framework and LLM library in Python uses Pydantic, yet when we began to use LLMs in [Pydantic Logfire](https://pydantic.dev/logfire), we couldn't find anything that gave us the same feeling. +Similarly, virtually every agent framework and LLM library in Python uses Pydantic Validation, yet when we began to use LLMs in [Pydantic Logfire](https://pydantic.dev/logfire), we couldn't find anything that gave us the same feeling. -We built PydanticAI with one simple aim: to bring that FastAPI feeling to GenAI app development. +We built Pydantic AI with one simple aim: to bring that FastAPI feeling to GenAI app development. -## Why use PydanticAI +## Why use Pydantic AI -* __Built by the Pydantic Team__: -Built by the team behind [Pydantic](https://docs.pydantic.dev/latest/) (the validation layer of the OpenAI SDK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more). +- **Built by the Pydantic Team**: + Built by the team behind [Pydantic Validation](https://docs.pydantic.dev/latest/) (the validation layer of the OpenAI SDK, the Anthropic SDK, LangChain, LlamaIndex, AutoGPT, Transformers, CrewAI, Instructor and many more). -* __Model-agnostic__: -Supports OpenAI, Anthropic, Gemini, Deepseek, Ollama, Groq, Cohere, and Mistral, and there is a simple interface to implement support for [other models](models/index.md). +- **Model-agnostic**: + Supports OpenAI, Anthropic, Gemini, Deepseek, Ollama, Groq, Cohere, and Mistral, and there is a simple interface to implement support for [other models](models/index.md). 
-* __Pydantic Logfire Integration__: -Seamlessly [integrates](logfire.md) with [Pydantic Logfire](https://pydantic.dev/logfire) for real-time debugging, performance monitoring, and behavior tracking of your LLM-powered applications. +- **Pydantic Logfire Integration**: + Seamlessly [integrates](logfire.md) with [Pydantic Logfire](https://pydantic.dev/logfire) for real-time debugging, performance monitoring, and behavior tracking of your LLM-powered applications. -* __Type-safe__: -Designed to make [type checking](agents.md#static-type-checking) as powerful and informative as possible for you. +- **Type-safe**: + Designed to make [type checking](agents.md#static-type-checking) as powerful and informative as possible for you. -* __Python-centric Design__: -Leverages Python's familiar control flow and agent composition to build your AI-driven projects, making it easy to apply standard Python best practices you'd use in any other (non-AI) project. +- **Python-centric Design**: + Leverages Python's familiar control flow and agent composition to build your AI-driven projects, making it easy to apply standard Python best practices you'd use in any other (non-AI) project. -* __Structured Responses__: -Harnesses the power of [Pydantic](https://docs.pydantic.dev/latest/) to [validate and structure](output.md#structured-output) model outputs, ensuring responses are consistent across runs. +- **Structured Responses**: + Harnesses the power of [Pydantic Validation](https://docs.pydantic.dev/latest/) to [validate and structure](output.md#structured-output) model outputs, ensuring responses are consistent across runs. -* __Dependency Injection System__: -Offers an optional [dependency injection](dependencies.md) system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [tools](tools.md) and [output validators](output.md#output-validator-functions). -This is useful for testing and eval-driven iterative development. +- **Dependency Injection System**: + Offers an optional [dependency injection](dependencies.md) system to provide data and services to your agent's [system prompts](agents.md#system-prompts), [tools](tools.md) and [output validators](output.md#output-validator-functions). + This is useful for testing and eval-driven iterative development. -* __Streamed Responses__: -Provides the ability to [stream](output.md#streamed-results) LLM responses continuously, with immediate validation, ensuring real time access to validated outputs. +- **Streamed Responses**: + Provides the ability to [stream](output.md#streamed-results) LLM responses continuously, with immediate validation, ensuring real time access to validated outputs. -* __Graph Support__: -[Pydantic Graph](graph.md) provides a powerful way to define graphs using typing hints, this is useful in complex applications where standard control flow can degrade to spaghetti code. +- **Graph Support**: + [Pydantic Graph](graph.md) provides a powerful way to define graphs using typing hints, this is useful in complex applications where standard control flow can degrade to spaghetti code. 
## Hello World Example -Here's a minimal example of PydanticAI: +Here's a minimal example of Pydantic AI: ```python {title="hello_world.py"} from pydantic_ai import Agent @@ -63,13 +63,13 @@ The first known use of "hello, world" was in a 1974 textbook about the C program _(This example is complete, it can be run "as is")_ -The exchange should be very short: PydanticAI will send the system prompt and the user query to the LLM, the model will return a text response. +The exchange should be very short: Pydantic AI will send the system prompt and the user query to the LLM, the model will return a text response. Not very interesting yet, but we can easily add "tools", dynamic system prompts, and structured responses to build more powerful agents. ## Tools & Dependency Injection Example -Here is a concise example using PydanticAI to build a support agent for a bank: +Here is a concise example using Pydantic AI to build a support agent for a bank: ```python {title="bank_support.py"} from dataclasses import dataclass @@ -140,7 +140,7 @@ async def main(): 1. This [agent](agents.md) will act as first-tier support in a bank. Agents are generic in the type of dependencies they accept and the type of output they return. In this case, the support agent has type `#!python Agent[SupportDependencies, SupportOutput]`. 2. Here we configure the agent to use [OpenAI's GPT-4o model](api/models/openai.md), you can also set the model when running the agent. -3. The `SupportDependencies` dataclass is used to pass data, connections, and logic into the model that will be needed when running [system prompt](agents.md#system-prompts) and [tool](tools.md) functions. PydanticAI's system of dependency injection provides a [type-safe](agents.md#static-type-checking) way to customise the behavior of your agents, and can be especially useful when running [unit tests](testing.md) and evals. +3. The `SupportDependencies` dataclass is used to pass data, connections, and logic into the model that will be needed when running [system prompt](agents.md#system-prompts) and [tool](tools.md) functions. Pydantic AI's system of dependency injection provides a [type-safe](agents.md#static-type-checking) way to customise the behavior of your agents, and can be especially useful when running [unit tests](testing.md) and evals. 4. Static [system prompts](agents.md#system-prompts) can be registered with the [`system_prompt` keyword argument][pydantic_ai.Agent.__init__] to the agent. 5. Dynamic [system prompts](agents.md#system-prompts) can be registered with the [`@agent.system_prompt`][pydantic_ai.Agent.system_prompt] decorator, and can make use of dependency injection. Dependencies are carried via the [`RunContext`][pydantic_ai.tools.RunContext] argument, which is parameterized with the `deps_type` from above. If the type annotation here is wrong, static type checkers will catch it. 6. [`tool`](tools.md) let you register functions which the LLM may call while responding to a user. Again, dependencies are carried via [`RunContext`][pydantic_ai.tools.RunContext], any other arguments become the tool schema passed to the LLM. Pydantic is used to validate these arguments, and errors are passed back to the LLM so it can retry. @@ -197,7 +197,7 @@ See [Monitoring and Performance](logfire.md) to learn more. ## llms.txt -The PydanticAI documentation is available in the [llms.txt](https://llmstxt.org/) format. +The Pydantic AI documentation is available in the [llms.txt](https://llmstxt.org/) format. 
This format is defined in Markdown and suited for large language models. Two formats are available: @@ -208,15 +208,14 @@ Two formats are available: - [llms-full.txt](https://ai.pydantic.dev/llms-full.txt): Similar to the `llms.txt` file, but every link content is included. Note that this file may be too large for some LLMs. -As of today, these files *cannot* be natively leveraged by LLM frameworks or IDEs. Alternatively, +As of today, these files _cannot_ be natively leveraged by LLM frameworks or IDEs. Alternatively, an [MCP server](https://modelcontextprotocol.io/) can be implemented to properly parse the `llms.txt` file. - ## Next Steps -To try PydanticAI yourself, follow the instructions [in the examples](examples/index.md). +To try Pydantic AI yourself, follow the instructions [in the examples](examples/index.md). -Read the [docs](agents.md) to learn more about building applications with PydanticAI. +Read the [docs](agents.md) to learn more about building applications with Pydantic AI. -Read the [API Reference](api/agent.md) to understand PydanticAI's interface. +Read the [API Reference](api/agent.md) to understand Pydantic AI's interface. diff --git a/docs/input.md b/docs/input.md index 0d6278dae5..700d7f05e6 100644 --- a/docs/input.md +++ b/docs/input.md @@ -2,7 +2,6 @@ Some LLMs are now capable of understanding audio, video, image and document content. - ## Image Input !!! info @@ -105,7 +104,7 @@ print(result.output) ## User-side download vs. direct file URL -As a general rule, when you provide a URL using any of `ImageUrl`, `AudioUrl`, `VideoUrl` or `DocumentUrl`, PydanticAI downloads the file content and then sends it as part of the API request. +As a general rule, when you provide a URL using any of `ImageUrl`, `AudioUrl`, `VideoUrl` or `DocumentUrl`, Pydantic AI downloads the file content and then sends it as part of the API request. The situation is different for certain models: @@ -113,12 +112,12 @@ The situation is different for certain models: - [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] and [`GoogleModel`][pydantic_ai.models.google.GoogleModel] on Vertex AI: any URL provided using `ImageUrl`, `AudioUrl`, `VideoUrl`, or `DocumentUrl` is sent as-is in the API request and no data is downloaded beforehand. - See the [Gemini API docs for Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#filedata) to learn more about supported URLs, formats and limitations: + See the [Gemini API docs for Vertex AI](https://cloud.google.com/vertex-ai/generative-ai/docs/model-reference/inference#filedata) to learn more about supported URLs, formats and limitations: - - Cloud Storage bucket URIs (with protocol `gs://`) - - Public HTTP(S) URLs - - Public YouTube video URL (maximum one URL per request) + - Cloud Storage bucket URIs (with protocol `gs://`) + - Public HTTP(S) URLs + - Public YouTube video URL (maximum one URL per request) - However, because of crawling restrictions, it may happen that Gemini can't access certain URLs. In that case, you can instruct PydanticAI to download the file content and send that instead of the URL by setting the boolean flag `force_download` to `True`. This attribute is available on all objects that inherit from [`FileUrl`][pydantic_ai.messages.FileUrl]. + However, because of crawling restrictions, it may happen that Gemini can't access certain URLs. In that case, you can instruct Pydantic AI to download the file content and send that instead of the URL by setting the boolean flag `force_download` to `True`. 
This attribute is available on all objects that inherit from [`FileUrl`][pydantic_ai.messages.FileUrl]. - [`GeminiModel`][pydantic_ai.models.gemini.GeminiModel] and [`GoogleModel`][pydantic_ai.models.google.GoogleModel] on GLA: YouTube video URLs are sent directly in the request to the model. diff --git a/docs/install.md b/docs/install.md index 9b4b473559..d1a6909c84 100644 --- a/docs/install.md +++ b/docs/install.md @@ -1,6 +1,6 @@ # Installation -PydanticAI is available on PyPI as [`pydantic-ai`](https://pypi.org/project/pydantic-ai/) so installation is as simple as: +Pydantic AI is available on PyPI as [`pydantic-ai`](https://pypi.org/project/pydantic-ai/) so installation is as simple as: ```bash pip/uv-add pydantic-ai @@ -9,13 +9,13 @@ pip/uv-add pydantic-ai (Requires Python 3.9+) This installs the `pydantic_ai` package, core dependencies, and libraries required to use all the models -included in PydanticAI. If you want to use a specific model, you can install the ["slim"](#slim-install) version of PydanticAI. +included in Pydantic AI. If you want to use a specific model, you can install the ["slim"](#slim-install) version of Pydantic AI. ## Use with Pydantic Logfire -PydanticAI has an excellent (but completely optional) integration with [Pydantic Logfire](https://pydantic.dev/logfire) to help you view and understand agent runs. +Pydantic AI has an excellent (but completely optional) integration with [Pydantic Logfire](https://pydantic.dev/logfire) to help you view and understand agent runs. -To use Logfire with PydanticAI, install `pydantic-ai` or `pydantic-ai-slim` with the `logfire` optional group: +To use Logfire with Pydantic AI, install `pydantic-ai` or `pydantic-ai-slim` with the `logfire` optional group: ```bash pip/uv-add "pydantic-ai[logfire]" diff --git a/docs/logfire.md b/docs/logfire.md index f0f4a76926..209ac19f76 100644 --- a/docs/logfire.md +++ b/docs/logfire.md @@ -1,4 +1,4 @@ -# Debugging and Monitoring +# Pydantic Logfire Debugging and Monitoring Applications that use LLMs have some challenges that are well known and understood: LLMs are **slow**, **unreliable** and **expensive**. @@ -15,13 +15,13 @@ LLM Observability tools that just let you understand how your model is performin ## Pydantic Logfire -[Pydantic Logfire](https://pydantic.dev/logfire) is an observability platform developed by the team who created and maintain Pydantic and PydanticAI. Logfire aims to let you understand your entire application: Gen AI, classic predictive AI, HTTP traffic, database queries and everything else a modern application needs, all using OpenTelemetry. +[Pydantic Logfire](https://pydantic.dev/logfire) is an observability platform developed by the team who created and maintain Pydantic Validation and Pydantic AI. Logfire aims to let you understand your entire application: Gen AI, classic predictive AI, HTTP traffic, database queries and everything else a modern application needs, all using OpenTelemetry. !!! tip "Pydantic Logfire is a commercial product" Logfire is a commercially supported, hosted platform with an extremely generous and perpetual [free tier](https://pydantic.dev/pricing/). You can sign up and start using Logfire in a couple of minutes. Logfire can also be self-hosted on the enterprise tier. -PydanticAI has built-in (but optional) support for Logfire. That means if the `logfire` package is installed and configured and agent instrumentation is enabled then detailed information about agent runs is sent to Logfire. 
Otherwise there's virtually no overhead and nothing is sent. +Pydantic AI has built-in (but optional) support for Logfire. That means if the `logfire` package is installed and configured and agent instrumentation is enabled then detailed information about agent runs is sent to Logfire. Otherwise there's virtually no overhead and nothing is sent. Here's an example showing details of running the [Weather Agent](examples/weather-agent.md) in Logfire: @@ -53,7 +53,7 @@ py-cli logfire projects new This will write to a `.logfire` directory in the current working directory, which the Logfire SDK will use for configuration at run time. -With that, you can start using Logfire to instrument PydanticAI code: +With that, you can start using Logfire to instrument Pydantic AI code: ```python {title="instrument_pydantic_ai.py" hl_lines="1 5 6"} import logfire @@ -72,7 +72,7 @@ The first known use of "hello, world" was in a 1974 textbook about the C program ``` 1. [`logfire.configure()`][logfire.configure] configures the SDK, by default it will find the write token from the `.logfire` directory, but you can also pass a token directly. -2. [`logfire.instrument_pydantic_ai()`][logfire.Logfire.instrument_pydantic_ai] enables instrumentation of PydanticAI. +2. [`logfire.instrument_pydantic_ai()`][logfire.Logfire.instrument_pydantic_ai] enables instrumentation of Pydantic AI. 3. Since we've enabled instrumentation, a trace will be generated for each run, with spans emitted for models calls and tool function execution _(This example is complete, it can be run "as is")_ @@ -81,31 +81,30 @@ Which will display in Logfire thus: ![Logfire Simple Agent Run](img/logfire-simple-agent.png) -The [logfire documentation](https://logfire.pydantic.dev/docs/) has more details on how to use Logfire, +The [Logfire documentation](https://logfire.pydantic.dev/docs/) has more details on how to use Logfire, including how to instrument other libraries like [HTTPX](https://logfire.pydantic.dev/docs/integrations/http-clients/httpx/) and [FastAPI](https://logfire.pydantic.dev/docs/integrations/web-frameworks/fastapi/). Since Logfire is built on [OpenTelemetry](https://opentelemetry.io/), you can use the Logfire Python SDK to send data to any OpenTelemetry collector, see [below](#using-opentelemetry). ### Debugging -To demonstrate how Logfire can let you visualise the flow of a PydanticAI run, here's the view you get from Logfire while running the [chat app examples](examples/chat-app.md): +To demonstrate how Logfire can let you visualise the flow of a Pydantic AI run, here's the view you get from Logfire while running the [chat app examples](examples/chat-app.md): {{ video('a764aff5840534dc77eba7d028707bfa', 25) }} ### Monitoring Performance -We can also query data with SQL in Logfire to monitor the performance of an application. Here's a real world example of using Logfire to monitor PydanticAI runs inside Logfire itself: +We can also query data with SQL in Logfire to monitor the performance of an application. Here's a real world example of using Logfire to monitor Pydantic AI runs inside Logfire itself: -![Logfire monitoring PydanticAI](img/logfire-monitoring-pydanticai.png) +![Logfire monitoring Pydantic AI](img/logfire-monitoring-pydanticai.png) ### Monitoring HTTP Requests -!!! tip ""F**k you, show me the prompt."" +!!! 
tip "\"F**k you, show me the prompt.\"" As per Hamel Husain's influential 2024 blog post ["Fuck You, Show Me The Prompt."](https://hamel.dev/blog/posts/prompt/) (bear with the capitalization, the point is valid), it's often useful to be able to view the raw HTTP requests and responses made to model providers. -To observe raw HTTP requests made to model providers, you can use `logfire`'s [HTTPX instrumentation](https://logfire.pydantic.dev/docs/integrations/http-clients/httpx/) since all provider SDKs use the [HTTPX](https://www.python-httpx.org/) library internally. - + To observe raw HTTP requests made to model providers, you can use Logfire's [HTTPX instrumentation](https://logfire.pydantic.dev/docs/integrations/http-clients/httpx/) since all provider SDKs use the [HTTPX](https://www.python-httpx.org/) library internally. === "With HTTP instrumentation" @@ -147,17 +146,17 @@ To observe raw HTTP requests made to model providers, you can use `logfire`'s [H ## Using OpenTelemetry -PydanticAI's instrumentation uses [OpenTelemetry](https://opentelemetry.io/) (OTel), which Logfire is based on. +Pydantic AI's instrumentation uses [OpenTelemetry](https://opentelemetry.io/) (OTel), which Logfire is based on. -This means you can debug and monitor PydanticAI with any OpenTelemetry backend. +This means you can debug and monitor Pydantic AI with any OpenTelemetry backend. -PydanticAI follows the [OpenTelemetry Semantic Conventions for Generative AI systems](https://opentelemetry.io/docs/specs/semconv/gen-ai/), so while we think you'll have the best experience using the Logfire platform :wink:, you should be able to use any OTel service with GenAI support. +Pydantic AI follows the [OpenTelemetry Semantic Conventions for Generative AI systems](https://opentelemetry.io/docs/specs/semconv/gen-ai/), so while we think you'll have the best experience using the Logfire platform :wink:, you should be able to use any OTel service with GenAI support. ### Logfire with an alternative OTel backend You can use the Logfire SDK completely freely and send the data to any OpenTelemetry backend. -Here's an example of configuring the Logfire library to send data to the excellent [otel-tui](https://github.com/ymtdzzz/otel-tui) — an open source terminal based OTel backend and viewer (no association with Pydantic). +Here's an example of configuring the Logfire library to send data to the excellent [otel-tui](https://github.com/ymtdzzz/otel-tui) — an open source terminal based OTel backend and viewer (no association with Pydantic Validation). Run `otel-tui` with docker (see [the otel-tui readme](https://github.com/ymtdzzz/otel-tui) for more instructions): @@ -201,7 +200,7 @@ For more information on using the Logfire SDK to send data to alternative backen ### OTel without Logfire -You can also emit OpenTelemetry data from PydanticAI without using Logfire at all. +You can also emit OpenTelemetry data from Pydantic AI without using Logfire at all. To do this, you'll need to install and configure the OpenTelemetry packages you need. To run the following examples, use @@ -261,7 +260,7 @@ The following providers have dedicated documentation on Pydantic AI: ### Configuring data format -PydanticAI follows the [OpenTelemetry Semantic Conventions for Generative AI systems](https://opentelemetry.io/docs/specs/semconv/gen-ai/), with one caveat. The semantic conventions specify that messages should be captured as individual events (logs) that are children of the request span. 
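Once the collector is listening, the Logfire SDK only needs to be pointed at it. A minimal sketch, assuming otel-tui's default OTLP/HTTP port of 4318 and an illustrative model name:

```python
import os

import logfire

from pydantic_ai import Agent

# Send OTel data to the local collector instead of the Logfire platform
os.environ['OTEL_EXPORTER_OTLP_ENDPOINT'] = 'http://localhost:4318'
logfire.configure(send_to_logfire=False)
logfire.instrument_pydantic_ai()

agent = Agent('openai:gpt-4o', instructions='Be concise.')
result = agent.run_sync('What is OpenTelemetry?')
print(result.output)
```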
By default, PydanticAI instead collects these events into a JSON array which is set as a single large attribute called `events` on the request span. To change this, use `event_mode='logs'`: +Pydantic AI follows the [OpenTelemetry Semantic Conventions for Generative AI systems](https://opentelemetry.io/docs/specs/semconv/gen-ai/), with one caveat. The semantic conventions specify that messages should be captured as individual events (logs) that are children of the request span. By default, Pydantic AI instead collects these events into a JSON array which is set as a single large attribute called `events` on the request span. To change this, use `event_mode='logs'`: ```python {title="instrumentation_settings_event_mode.py"} import logfire @@ -327,9 +326,9 @@ Agent.instrument_all(instrumentation_settings) ### Excluding prompts and completions -For privacy and security reasons, you may want to monitor your agent's behavior and performance without exposing sensitive user data or proprietary prompts in your observability platform. PydanticAI allows you to exclude the actual content from instrumentation events while preserving the structural information needed for debugging and monitoring. +For privacy and security reasons, you may want to monitor your agent's behavior and performance without exposing sensitive user data or proprietary prompts in your observability platform. Pydantic AI allows you to exclude the actual content from instrumentation events while preserving the structural information needed for debugging and monitoring. -When `include_content=False` is set, PydanticAI will exclude sensitive content from OpenTelemetry events, including user prompts and model completions, tool call arguments and responses, and any other message content. +When `include_content=False` is set, Pydantic AI will exclude sensitive content from OpenTelemetry events, including user prompts and model completions, tool call arguments and responses, and any other message content. ```python {title="excluding_sensitive_content.py"} from pydantic_ai.agent import Agent diff --git a/docs/mcp/client.md b/docs/mcp/client.md index 15ef46f2e2..40fb6dbff7 100644 --- a/docs/mcp/client.md +++ b/docs/mcp/client.md @@ -1,6 +1,6 @@ # Client -PydanticAI can act as an [MCP client](https://modelcontextprotocol.io/quickstart/client), connecting to MCP servers +Pydantic AI can act as an [MCP client](https://modelcontextprotocol.io/quickstart/client), connecting to MCP servers to use their tools. ## Install @@ -16,7 +16,7 @@ pip/uv-add "pydantic-ai-slim[mcp]" ## Usage -PydanticAI comes with three ways to connect to MCP servers: +Pydantic AI comes with three ways to connect to MCP servers: - [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] which connects to an MCP server using the [Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport - [`MCPServerSSE`][pydantic_ai.mcp.MCPServerSSE] which connects to an MCP server using the [HTTP SSE](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) transport @@ -34,9 +34,7 @@ You can use the [`async with agent`][pydantic_ai.Agent.__aenter__] context manag [Streamable HTTP](https://modelcontextprotocol.io/introduction#streamable-http) transport to a server. !!! note - [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be - running and accepting HTTP connections before running the agent. Running the server is not - managed by Pydantic AI.
+ [`MCPServerStreamableHTTP`][pydantic_ai.mcp.MCPServerStreamableHTTP] requires an MCP server to be running and accepting HTTP connections before running the agent. Running the server is not managed by Pydantic AI. Before creating the Streamable HTTP client, we need to run a server that supports the Streamable HTTP transport. @@ -80,7 +78,7 @@ _(This example is complete, it can be run "as is" with Python 3.10+ — you'll n - The model is receiving the prompt "how many days between 2000-01-01 and 2025-03-18?" - The model decides "Oh, I've got this `run_python_code` tool, that will be a good way to answer this question", and writes some python code to calculate the answer. - The model returns a tool call -- PydanticAI sends the tool call to the MCP server using the SSE transport +- Pydantic AI sends the tool call to the MCP server using the SSE transport - The model is called again with the return value of running the code - The model returns the final answer diff --git a/docs/mcp/index.md b/docs/mcp/index.md index 01dfe353e3..ffaa9a7857 100644 --- a/docs/mcp/index.md +++ b/docs/mcp/index.md @@ -1,14 +1,14 @@ # Model Context Protocol (MCP) -PydanticAI supports [Model Context Protocol (MCP)](https://modelcontextprotocol.io) in three ways: +Pydantic AI supports [Model Context Protocol (MCP)](https://modelcontextprotocol.io) in three ways: 1. [Agents](../agents.md) act as an MCP Client, connecting to MCP servers to use their tools, [learn more …](client.md) 2. Agents can be used within MCP servers, [learn more …](server.md) -3. As part of PydanticAI, we're building a number of MCP servers, [see below](#mcp-servers) +3. As part of Pydantic AI, we're building a number of MCP servers, [see below](#mcp-servers) ## What is MCP? -The Model Context Protocol is a standardized protocol that allow AI applications (including programmatic agents like PydanticAI, coding agents like [cursor](https://www.cursor.com/), and desktop applications like [Claude Desktop](https://claude.ai/download)) to connect to external tools and services using a common interface. +The Model Context Protocol is a standardized protocol that allow AI applications (including programmatic agents like Pydantic AI, coding agents like [cursor](https://www.cursor.com/), and desktop applications like [Claude Desktop](https://claude.ai/download)) to connect to external tools and services using a common interface. As with other protocols, the dream of MCP is that a wide range of applications can speak to each other without the need for specific integrations. 
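To make the client-side flow above concrete, here is a minimal sketch of attaching a Streamable HTTP server to an agent; the URL and model name are placeholders and assume a server is already running locally:

```python
from pydantic_ai import Agent
from pydantic_ai.mcp import MCPServerStreamableHTTP

# Assumes an MCP server is already listening at this URL; running it is not managed by Pydantic AI
server = MCPServerStreamableHTTP('http://localhost:8000/mcp')
agent = Agent('openai:gpt-4o', toolsets=[server])


async def main():
    async with agent:  # open the connection to the MCP server for the duration of the block
        result = await agent.run('How many days between 2000-01-01 and 2025-03-18?')
    print(result.output)
```

As with the other async examples, you would add `asyncio.run(main())` to actually execute it.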
@@ -16,14 +16,14 @@ There is a great list of MCP servers at [github.com/modelcontextprotocol/servers Some examples of what this means: -* PydanticAI could use a web search service implemented as an MCP server to implement a deep research agent -* Cursor could connect to the [Pydantic Logfire](https://github.com/pydantic/logfire-mcp) MCP server to search logs, traces and metrics to gain context while fixing a bug -* PydanticAI, or any other MCP client could connect to our [Run Python](run-python.md) MCP server to run arbitrary Python code in a sandboxed environment +- Pydantic AI could use a web search service implemented as an MCP server to implement a deep research agent +- Cursor could connect to the [Pydantic Logfire](https://github.com/pydantic/logfire-mcp) MCP server to search logs, traces and metrics to gain context while fixing a bug +- Pydantic AI, or any other MCP client could connect to our [Run Python](run-python.md) MCP server to run arbitrary Python code in a sandboxed environment ## MCP Servers -To add functionality to PydanticAI while making it as widely usable as possible, we're implementing some functionality as MCP servers. +To add functionality to Pydantic AI while making it as widely usable as possible, we're implementing some functionality as MCP servers. -So far, we've only implemented one MCP server as part of PydanticAI: +So far, we've only implemented one MCP server as part of Pydantic AI: -* [Run Python](run-python.md): A sandboxed Python interpreter that can run arbitrary code, with a focus on security and safety. +- [Run Python](run-python.md): A sandboxed Python interpreter that can run arbitrary code, with a focus on security and safety. diff --git a/docs/mcp/run-python.md b/docs/mcp/run-python.md index 1eb0dd773d..f99a159827 100644 --- a/docs/mcp/run-python.md +++ b/docs/mcp/run-python.md @@ -41,11 +41,11 @@ where: standard library. This is also useful to check the server is running correctly. -Usage of `jsr:@pydantic/mcp-run-python` with PydanticAI is described in the [client](client.md#mcp-stdio-server) documentation. +Usage of `jsr:@pydantic/mcp-run-python` with Pydantic AI is described in the [client](client.md#mcp-stdio-server) documentation. ## Direct Usage -As well as using this server with PydanticAI, it can be connected to other MCP clients. For clarity, in this example we connect directly using the [Python MCP client](https://github.com/modelcontextprotocol/python-sdk). +As well as using this server with Pydantic AI, it can be connected to other MCP clients. For clarity, in this example we connect directly using the [Python MCP client](https://github.com/modelcontextprotocol/python-sdk). ```python {title="mcp_run_python.py" py="3.10"} from mcp import ClientSession, StdioServerParameters diff --git a/docs/mcp/server.md b/docs/mcp/server.md index 9c0fda72f4..6ded02342c 100644 --- a/docs/mcp/server.md +++ b/docs/mcp/server.md @@ -1,17 +1,17 @@ # Server -PydanticAI agents can also be used within MCP Servers. +Pydantic AI agents can also be used within MCP Servers.
## MCP Server -Here's a simple example of a [Python MCP server](https://github.com/modelcontextprotocol/python-sdk) using PydanticAI within a tool call: +Here's a simple example of a [Python MCP server](https://github.com/modelcontextprotocol/python-sdk) using Pydantic AI within a tool call: ```py {title="mcp_server.py" py="3.10"} from mcp.server.fastmcp import FastMCP from pydantic_ai import Agent -server = FastMCP('PydanticAI Server') +server = FastMCP('Pydantic AI Server') server_agent = Agent( 'anthropic:claude-3-5-haiku-latest', system_prompt='always reply in rhyme' ) @@ -76,7 +76,7 @@ from mcp.server.fastmcp import Context, FastMCP from pydantic_ai import Agent from pydantic_ai.models.mcp_sampling import MCPSamplingModel -server = FastMCP('PydanticAI Server with sampling') +server = FastMCP('Pydantic AI Server with sampling') server_agent = Agent(system_prompt='always reply in rhyme') diff --git a/docs/message-history.md b/docs/message-history.md index 179c4c291c..3e8cedbd03 100644 --- a/docs/message-history.md +++ b/docs/message-history.md @@ -1,6 +1,6 @@ # Messages and chat history -PydanticAI provides access to messages exchanged during an agent run. These messages can be used both to continue a coherent conversation, and to understand how an agent performed. +Pydantic AI provides access to messages exchanged during an agent run. These messages can be used both to continue a coherent conversation, and to understand how an agent performed. ### Accessing Messages from Results @@ -10,8 +10,8 @@ Both [`RunResult`][pydantic_ai.agent.AgentRunResult] (returned by [`Agent.run`][pydantic_ai.Agent.run], [`Agent.run_sync`][pydantic_ai.Agent.run_sync]) and [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] (returned by [`Agent.run_stream`][pydantic_ai.Agent.run_stream]) have the following methods: -* [`all_messages()`][pydantic_ai.agent.AgentRunResult.all_messages]: returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.agent.AgentRunResult.all_messages_json]. -* [`new_messages()`][pydantic_ai.agent.AgentRunResult.new_messages]: returns only the messages from the current run. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.agent.AgentRunResult.new_messages_json]. +- [`all_messages()`][pydantic_ai.agent.AgentRunResult.all_messages]: returns all messages, including messages from prior runs. There's also a variant that returns JSON bytes, [`all_messages_json()`][pydantic_ai.agent.AgentRunResult.all_messages_json]. +- [`new_messages()`][pydantic_ai.agent.AgentRunResult.new_messages]: returns only the messages from the current run. There's also a variant that returns JSON bytes, [`new_messages_json()`][pydantic_ai.agent.AgentRunResult.new_messages_json]. !!! info "StreamedRunResult and complete messages" On [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult], the messages returned from these methods will only include the final result message once the stream has finished. 
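To make the difference between the two accessors concrete, here is a small sketch (the model name is just an example) in which a second run reuses the first run's history:

```python
from pydantic_ai import Agent

agent = Agent('openai:gpt-4o', instructions='Be a helpful assistant.')

result1 = agent.run_sync('Tell me a joke.')
result2 = agent.run_sync('Explain it.', message_history=result1.new_messages())

# Only the request/response pair from the second run:
print(len(result2.new_messages()))
# The full conversation, including the first run:
print(len(result2.all_messages()))
```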
@@ -65,6 +65,7 @@ print(result.all_messages()) ] """ ``` + _(This example is complete, it can be run "as is")_ Example of accessing methods on a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult] : @@ -132,11 +133,12 @@ async def main(): ] """ ``` + _(This example is complete, it can be run "as is" — you'll need to add `asyncio.run(main())` to run `main`)_ ### Using Messages as Input for Further Agent Runs -The primary use of message histories in PydanticAI is to maintain context across multiple agent runs. +The primary use of message histories in Pydantic AI is to maintain context across multiple agent runs. To use existing messages in a run, pass them to the `message_history` parameter of [`Agent.run`][pydantic_ai.Agent.run], [`Agent.run_sync`][pydantic_ai.Agent.run_sync] or @@ -203,6 +205,7 @@ print(result2.all_messages()) ] """ ``` + _(This example is complete, it can be run "as is")_ ## Storing and loading messages (to JSON) @@ -328,7 +331,7 @@ Sometimes you may want to modify the message history before it's sent to the mod reasons (filtering out sensitive information), to save costs on tokens, to give less context to the LLM, or custom processing logic. -PydanticAI provides a `history_processors` parameter on `Agent` that allows you to intercept and modify +Pydantic AI provides a `history_processors` parameter on `Agent` that allows you to intercept and modify the message history before each model request. ### Usage diff --git a/docs/models/gemini.md b/docs/models/gemini.md index 850fea2af6..6e8e02215a 100644 --- a/docs/models/gemini.md +++ b/docs/models/gemini.md @@ -10,7 +10,7 @@ Check it out [here](../api/models/google.md). -PydanticAI supports Google's Gemini models through two different APIs: +Pydantic AI supports Google's Gemini models through two different APIs: - Generative Language API (`generativelanguage.googleapis.com`) - Vertex AI API (`*-aiplatform.googleapis.com`) @@ -115,7 +115,7 @@ This interface has a number of advantages over `generativelanguage.googleapis.co 1. The VertexAI API comes with more enterprise readiness guarantees. 2. You can [purchase provisioned throughput](https://cloud.google.com/vertex-ai/generative-ai/docs/provisioned-throughput#purchase-provisioned-throughput) with VertexAI to guarantee capacity. -3. If you're running PydanticAI inside GCP, you don't need to set up authentication, it should "just work". +3. If you're running Pydantic AI inside GCP, you don't need to set up authentication, it should "just work". 4. You can decide which region to use, which might be important from a regulatory perspective, and might improve latency. The big disadvantage is that for local development you may need to create and configure a "service account", which can be challenging to get right. @@ -124,7 +124,7 @@ Whichever way you authenticate, you'll need to have VertexAI enabled in your GCP ### Application default credentials -Luckily if you're running PydanticAI inside GCP, or you have the [`gcloud` CLI](https://cloud.google.com/sdk/gcloud) installed and configured, you should be able to use `VertexAIModel` without any additional setup. +Luckily if you're running Pydantic AI inside GCP, or you have the [`gcloud` CLI](https://cloud.google.com/sdk/gcloud) installed and configured, you should be able to use `VertexAIModel` without any additional setup. To use `VertexAIModel`, with [application default credentials](https://cloud.google.com/docs/authentication/application-default-credentials) configured (e.g. 
with `gcloud`), you can simply use: diff --git a/docs/models/index.md b/docs/models/index.md index c384d30a7d..8527b95819 100644 --- a/docs/models/index.md +++ b/docs/models/index.md @@ -1,6 +1,6 @@ # Model Providers -PydanticAI is model-agnostic and has built-in support for multiple model providers: +Pydantic AI is model-agnostic and has built-in support for multiple model providers: * [OpenAI](openai.md) * [Anthropic](anthropic.md) @@ -13,20 +13,20 @@ PydanticAI is model-agnostic and has built-in support for multiple model provide ## OpenAI-compatible Providers -In addition, many providers are compatible with the OpenAI API, and can be used with `OpenAIModel` in PydanticAI: +In addition, many providers are compatible with the OpenAI API, and can be used with `OpenAIModel` in Pydantic AI: -* [DeepSeek](openai.md#deepseek) -* [Grok (xAI)](openai.md#grok-xai) -* [Ollama](openai.md#ollama) -* [OpenRouter](openai.md#openrouter) -* [Perplexity](openai.md#perplexity) -* [Fireworks AI](openai.md#fireworks-ai) -* [Together AI](openai.md#together-ai) -* [Azure AI Foundry](openai.md#azure-ai-foundry) -* [Heroku](openai.md#heroku-ai) -* [GitHub Models](openai.md#github-models) +- [DeepSeek](openai.md#deepseek) +- [Grok (xAI)](openai.md#grok-xai) +- [Ollama](openai.md#ollama) +- [OpenRouter](openai.md#openrouter) +- [Perplexity](openai.md#perplexity) +- [Fireworks AI](openai.md#fireworks-ai) +- [Together AI](openai.md#together-ai) +- [Azure AI Foundry](openai.md#azure-ai-foundry) +- [Heroku](openai.md#heroku-ai) +- [GitHub Models](openai.md#github-models) -PydanticAI also comes with [`TestModel`](../api/models/test.md) and [`FunctionModel`](../api/models/function.md) +Pydantic AI also comes with [`TestModel`](../api/models/test.md) and [`FunctionModel`](../api/models/function.md) for testing and development. To use each model provider, you need to configure your local environment and make sure you have the right @@ -34,29 +34,29 @@ packages installed. ## Models and Providers -PydanticAI uses a few key terms to describe how it interacts with different LLMs: - -* **Model**: This refers to the PydanticAI class used to make requests following a specific LLM API - (generally by wrapping a vendor-provided SDK, like the `openai` python SDK). These classes implement a - vendor-SDK-agnostic API, ensuring a single PydanticAI agent is portable to different LLM vendors without - any other code changes just by swapping out the Model it uses. Model classes are named - roughly in the format `Model`, for example, we have `OpenAIModel`, `AnthropicModel`, `GeminiModel`, - etc. When using a Model class, you specify the actual LLM model name (e.g., `gpt-4o`, - `claude-3-5-sonnet-latest`, `gemini-1.5-flash`) as a parameter. -* **Provider**: This refers to provider-specific classes which handle the authentication and connections - to an LLM vendor. Passing a non-default _Provider_ as a parameter to a Model is how you can ensure - that your agent will make requests to a specific endpoint, or make use of a specific approach to - authentication (e.g., you can use Vertex-specific auth with the `GeminiModel` by way of the `VertexProvider`). - In particular, this is how you can make use of an AI gateway, or an LLM vendor that offers API compatibility - with the vendor SDK used by an existing Model (such as `OpenAIModel`). -* **Profile**: This refers to a description of how requests to a specific model or family of models need to be - constructed to get the best results, independent of the model and provider classes used. 
- For example, different models have different restrictions on the JSON schemas that can be used for tools, - and the same schema transformer needs to be used for Gemini models whether you're using `GoogleModel` - with model name `gemini-2.5-pro-preview`, or `OpenAIModel` with `OpenRouterProvider` and model name `google/gemini-2.5-pro-preview`. +Pydantic AI uses a few key terms to describe how it interacts with different LLMs: + +- **Model**: This refers to the Pydantic AI class used to make requests following a specific LLM API + (generally by wrapping a vendor-provided SDK, like the `openai` python SDK). These classes implement a + vendor-SDK-agnostic API, ensuring a single Pydantic AI agent is portable to different LLM vendors without + any other code changes just by swapping out the Model it uses. Model classes are named + roughly in the format `<VendorSdk>Model`, for example, we have `OpenAIModel`, `AnthropicModel`, `GeminiModel`, + etc. When using a Model class, you specify the actual LLM model name (e.g., `gpt-4o`, + `claude-3-5-sonnet-latest`, `gemini-1.5-flash`) as a parameter. +- **Provider**: This refers to provider-specific classes which handle the authentication and connections + to an LLM vendor. Passing a non-default _Provider_ as a parameter to a Model is how you can ensure + that your agent will make requests to a specific endpoint, or make use of a specific approach to + authentication (e.g., you can use Vertex-specific auth with the `GeminiModel` by way of the `VertexProvider`). + In particular, this is how you can make use of an AI gateway, or an LLM vendor that offers API compatibility + with the vendor SDK used by an existing Model (such as `OpenAIModel`). +- **Profile**: This refers to a description of how requests to a specific model or family of models need to be + constructed to get the best results, independent of the model and provider classes used. + For example, different models have different restrictions on the JSON schemas that can be used for tools, + and the same schema transformer needs to be used for Gemini models whether you're using `GoogleModel` + with model name `gemini-2.5-pro-preview`, or `OpenAIModel` with `OpenRouterProvider` and model name `google/gemini-2.5-pro-preview`. When you instantiate an [`Agent`][pydantic_ai.Agent] with just a name formatted as `<provider>:<model>`, e.g. `openai:gpt-4o` or `openrouter:google/gemini-2.5-pro-preview`, -PydanticAI will automatically select the appropriate model class, provider, and profile. +Pydantic AI will automatically select the appropriate model class, provider, and profile. If you want to use a different provider or profile, you can instantiate a model class directly and pass in `provider` and/or `profile` arguments. ## Custom Models @@ -66,21 +66,23 @@ For streaming, you'll also need to implement the [`StreamedResponse`][pydantic_a The best place to start is to review the source code for existing implementations, e.g. [`OpenAIModel`](https://github.com/pydantic/pydantic-ai/blob/main/pydantic_ai_slim/pydantic_ai/models/openai.py). -For details on when we'll accept contributions adding new models to PydanticAI, see the [contributing guidelines](../contributing.md#new-model-rules). +For details on when we'll accept contributions adding new models to Pydantic AI, see the [contributing guidelines](../contributing.md#new-model-rules). If a model API is compatible with the OpenAI API, you do not need a custom model class and can provide your own [custom provider](openai.md#openai-compatible-models) instead.
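As a concrete illustration of the Model and Provider split described above, here is a sketch of instantiating a model class directly with a non-default provider; the gateway URL and API key are hypothetical placeholders for whatever OpenAI-compatible endpoint you're targeting:

```python
from pydantic_ai import Agent
from pydantic_ai.models.openai import OpenAIModel
from pydantic_ai.providers.openai import OpenAIProvider

# A hypothetical OpenAI-compatible gateway; swap in your own endpoint and key
provider = OpenAIProvider(base_url='https://gateway.example.com/v1', api_key='your-api-key')
model = OpenAIModel('gpt-4o', provider=provider)
agent = Agent(model)

result = agent.run_sync('Say hello.')
print(result.output)
```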
+ ## Fallback Model You can use [`FallbackModel`][pydantic_ai.models.fallback.FallbackModel] to attempt multiple models -in sequence until one successfully returns a result. Under the hood, PydanticAI automatically switches +in sequence until one successfully returns a result. Under the hood, Pydantic AI automatically switches from one model to the next if the current model returns a 4xx or 5xx status code. In the following example, the agent first makes a request to the OpenAI model (which fails due to an invalid API key), and then falls back to the Anthropic model. + ```python {title="fallback_model.py" test="skip"} from pydantic_ai import Agent from pydantic_ai.models.anthropic import AnthropicModel diff --git a/docs/models/openai.md b/docs/models/openai.md index 3c2fa9153f..bd68c95c0b 100644 --- a/docs/models/openai.md +++ b/docs/models/openai.md @@ -100,7 +100,7 @@ agent = Agent(model) ## OpenAI Responses API -PydanticAI also supports OpenAI's [Responses API](https://platform.openai.com/docs/api-reference/responses) through the `OpenAIResponsesModel` class. +Pydantic AI also supports OpenAI's [Responses API](https://platform.openai.com/docs/api-reference/responses) through the `OpenAIResponsesModel` class. ```python from pydantic_ai import Agent @@ -144,7 +144,7 @@ You can learn more about the differences between the Responses API and Chat Comp ## OpenAI-compatible Models -Many providers and models are compatible with the OpenAI API, and can be used with `OpenAIModel` in PydanticAI. +Many providers and models are compatible with the OpenAI API, and can be used with `OpenAIModel` in Pydantic AI. Before getting started, check the [installation and configuration](#install) instructions above. To use another OpenAI-compatible API, you can make use of the `base_url` and `api_key` arguments from `OpenAIProvider`: @@ -171,7 +171,7 @@ When a provider has its own provider class, you can use the `Agent(":< Sometimes, the provider or model you're using will have slightly different requirements than OpenAI's API or models, like having different restrictions on JSON schemas for tool definitions, or not supporting tool definitions to be marked as strict. -When using an alternative provider class provided by PydanticAI, an appropriate model profile is typically selected automatically based on the model name. +When using an alternative provider class provided by Pydantic AI, an appropriate model profile is typically selected automatically based on the model name. If the model you're using is not working correctly out of the box, you can tweak various aspects of how model requests are constructed by providing your own [`ModelProfile`][pydantic_ai.profiles.ModelProfile] (for behaviors shared among all model classes) or [`OpenAIModelProfile`][pydantic_ai.profiles.openai.OpenAIModelProfile] (for behaviors specific to `OpenAIModel`): ```py diff --git a/docs/multi-agent-applications.md b/docs/multi-agent-applications.md index 7a2ce36f37..a97cc98132 100644 --- a/docs/multi-agent-applications.md +++ b/docs/multi-agent-applications.md @@ -1,6 +1,6 @@ # Multi-agent Applications -There are roughly four levels of complexity when building applications with PydanticAI: +There are roughly four levels of complexity when building applications with Pydantic AI: 1. Single agent workflows — what most of the `pydantic_ai` documentation covers 2. [Agent delegation](#agent-delegation) — agents using another agent via tools @@ -327,6 +327,6 @@ See the [graph](graph.md) documentation on when and how to use graphs. 
## Examples -The following examples demonstrate how to use dependencies in PydanticAI: +The following examples demonstrate how to use dependencies in Pydantic AI: - [Flight booking](examples/flight-booking.md) diff --git a/docs/output.md b/docs/output.md index 2391a5efde..9e664a97d7 100644 --- a/docs/output.md +++ b/docs/output.md @@ -45,7 +45,7 @@ Structured outputs (like tools) use Pydantic to build the JSON schema used for t !!! note "Type checking considerations" The Agent class is generic in its output type, and this type is carried through to `AgentRunResult.output` and `StreamedRunResult.output` so that your IDE or static type checker can warn you when your code doesn't properly take into account all the possible values those outputs could have. - Static type checkers like pyright and mypy will do their best the infer the agent's output type from the `output_type` you've specified, but they're not always able to do so correctly when you provide functions or multiple types in a union or list, even though PydanticAI will behave correctly. When this happens, your type checker will complain even when you're confident you've passed a valid `output_type`, and you'll need to help the type checker by explicitly specifying the generic parameters on the `Agent` constructor. This is shown in the second example below and the output functions example further down. + Static type checkers like pyright and mypy will do their best to infer the agent's output type from the `output_type` you've specified, but they're not always able to do so correctly when you provide functions or multiple types in a union or list, even though Pydantic AI will behave correctly. When this happens, your type checker will complain even when you're confident you've passed a valid `output_type`, and you'll need to help the type checker by explicitly specifying the generic parameters on the `Agent` constructor. This is shown in the second example below and the output functions example further down. Specifically, there are three valid uses of `output_type` where you'll need to do this: @@ -421,7 +421,7 @@ result = agent.run_sync("Create a person") ### Output validators {#output-validator-functions} -Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. PydanticAI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator. +Some validation is inconvenient or impossible to do in Pydantic validators, in particular when the validation requires IO and is asynchronous. Pydantic AI provides a way to add validation functions via the [`agent.output_validator`][pydantic_ai.Agent.output_validator] decorator. If you want to implement separate validation logic for different output types, it's recommended to use [output functions](#output-functions) instead, to save you from having to do `isinstance` checks inside the output validator. If you want the model to output plain text, do your own processing or validation, and then have the agent's final output be the result of your function, it's recommended to use an [output function](#output-functions) with the [`TextOutput` marker class](#text-output). @@ -480,7 +480,7 @@ _(This example is complete, it can be run "as is")_ There are two main challenges with streamed results: 1.
Validating structured responses before they're complete, this is achieved by "partial validation" which was recently added to Pydantic in [pydantic/pydantic#10748](https://github.com/pydantic/pydantic/pull/10748). -2. When receiving a response, we don't know if it's the final response without starting to stream it and peeking at the content. PydanticAI streams just enough of the response to sniff out if it's a tool call or an output, then streams the whole thing and calls tools, or returns the stream as a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult]. +2. When receiving a response, we don't know if it's the final response without starting to stream it and peeking at the content. Pydantic AI streams just enough of the response to sniff out if it's a tool call or an output, then streams the whole thing and calls tools, or returns the stream as a [`StreamedRunResult`][pydantic_ai.result.StreamedRunResult]. ### Streaming Text @@ -629,7 +629,7 @@ _(This example is complete, it can be run "as is" — you'll need to add `asynci ## Examples -The following examples demonstrate how to use streamed responses in PydanticAI: +The following examples demonstrate how to use streamed responses in Pydantic AI: - [Stream markdown](examples/stream-markdown.md) - [Stream Whales](examples/stream-whales.md) diff --git a/docs/testing.md b/docs/testing.md index b40bb1dc9c..49b3eba3ca 100644 --- a/docs/testing.md +++ b/docs/testing.md @@ -1,17 +1,17 @@ # Unit testing -Writing unit tests for PydanticAI code is just like unit tests for any other Python code. +Writing unit tests for Pydantic AI code is just like unit tests for any other Python code. Because for the most part they're nothing new, we have pretty well established tools and patterns for writing and running these kinds of tests. 
Unless you're really sure you know better, you'll probably want to follow roughly this strategy: -* Use [`pytest`](https://docs.pytest.org/en/stable/) as your test harness -* If you find yourself typing out long assertions, use [inline-snapshot](https://15r10nk.github.io/inline-snapshot/latest/) -* Similarly, [dirty-equals](https://dirty-equals.helpmanual.io/latest/) can be useful for comparing large data structures -* Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the usage, latency and variability of real LLM calls -* Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace an agent's model, dependencies, or toolsets inside your application logic -* Set [`ALLOW_MODEL_REQUESTS=False`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] globally to block any requests from being made to non-test models accidentally +- Use [`pytest`](https://docs.pytest.org/en/stable/) as your test harness +- If you find yourself typing out long assertions, use [inline-snapshot](https://15r10nk.github.io/inline-snapshot/latest/) +- Similarly, [dirty-equals](https://dirty-equals.helpmanual.io/latest/) can be useful for comparing large data structures +- Use [`TestModel`][pydantic_ai.models.test.TestModel] or [`FunctionModel`][pydantic_ai.models.function.FunctionModel] in place of your actual model to avoid the usage, latency and variability of real LLM calls +- Use [`Agent.override`][pydantic_ai.agent.Agent.override] to replace an agent's model, dependencies, or toolsets inside your application logic +- Set [`ALLOW_MODEL_REQUESTS=False`][pydantic_ai.models.ALLOW_MODEL_REQUESTS] globally to block any requests from being made to non-test models accidentally ### Unit testing with `TestModel` diff --git a/docs/thinking.md b/docs/thinking.md index 8bdd0b898e..6d88ac5d1b 100644 --- a/docs/thinking.md +++ b/docs/thinking.md @@ -6,7 +6,7 @@ providing its final answer. This capability is typically disabled by default and depends on the specific model being used. See the sections below for how to enable thinking for each provider. -Internally, if the model doesn't provide thinking objects, PydanticAI will convert thinking blocks +Internally, if the model doesn't provide thinking objects, Pydantic AI will convert thinking blocks (`"<think>...</think>"`) in provider-specific text parts to `ThinkingPart`s. We have also made the decision not to send `ThinkingPart`s back to the provider in multi-turn conversations - this helps save costs for users. In the future, we plan to add a setting to customize this behavior. @@ -14,7 +14,7 @@ this helps save costs for users. In the future, we plan to add a setting to cust ## OpenAI When using the [`OpenAIModel`][pydantic_ai.models.openai.OpenAIModel], thinking objects are not created -by default. However, the text content may contain `<think>` tags. When this happens, PydanticAI will +by default. However, the text content may contain `<think>` tags. When this happens, Pydantic AI will convert them to [`ThinkingPart`][pydantic_ai.messages.ThinkingPart] objects.
In contrast, the [`OpenAIResponsesModel`][pydantic_ai.models.openai.OpenAIResponsesModel] does diff --git a/docs/tools.md b/docs/tools.md index 134a8f96ea..6744b3cd13 100644 --- a/docs/tools.md +++ b/docs/tools.md @@ -8,9 +8,9 @@ If you want a model to be able to call a function as its final action, without t There are a number of ways to register tools with an agent: -* via the [`@agent.tool`][pydantic_ai.Agent.tool] decorator — for tools that need access to the agent [context][pydantic_ai.tools.RunContext] -* via the [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext] -* via the [`tools`][pydantic_ai.Agent.__init__] keyword argument to `Agent` which can take either plain functions, or instances of [`Tool`][pydantic_ai.tools.Tool] +- via the [`@agent.tool`][pydantic_ai.Agent.tool] decorator — for tools that need access to the agent [context][pydantic_ai.tools.RunContext] +- via the [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator — for tools that do not need access to the agent [context][pydantic_ai.tools.RunContext] +- via the [`tools`][pydantic_ai.Agent.__init__] keyword argument to `Agent` which can take either plain functions, or instances of [`Tool`][pydantic_ai.tools.Tool] For more advanced use cases, the [toolsets](toolsets.md) feature lets you manage collections of tools (built by you or provided by an [MCP server](mcp/client.md) or other [third party](#third-party-tools)) and register them with an agent in one go via the [`toolsets`][pydantic_ai.Agent.__init__] keyword argument to `Agent`. @@ -294,6 +294,7 @@ result = agent.run_sync('What is the main content of the document?') print(result.output) #> The document contains just the text "Dummy PDF file." ``` + _(This example is complete, it can be run "as is")_ Some models (e.g. Gemini) natively support semi-structured return values, while some expect text (OpenAI) but seem to be just as good at extracting meaning from the data. If a Python object is returned and the model expects a string, the value will be serialized to JSON. @@ -362,9 +363,9 @@ This separation allows you to provide rich context to the model while maintainin Function parameters are extracted from the function signature, and all parameters except `RunContext` are used to build the schema for that tool call. -Even better, PydanticAI extracts the docstring from functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema. +Even better, Pydantic AI extracts the docstring from functions and (thanks to [griffe](https://mkdocstrings.github.io/griffe/)) extracts parameter descriptions from the docstring and adds them to the schema. -[Griffe supports](https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings) extracting parameter descriptions from `google`, `numpy`, and `sphinx` style docstrings. PydanticAI will infer the format to use based on the docstring, but you can explicitly set it using [`docstring_format`][pydantic_ai.tools.DocstringFormat]. You can also enforce parameter requirements by setting `require_parameter_descriptions=True`. This will raise a [`UserError`][pydantic_ai.exceptions.UserError] if a parameter description is missing. +[Griffe supports](https://mkdocstrings.github.io/griffe/reference/docstrings/#docstrings) extracting parameter descriptions from `google`, `numpy`, and `sphinx` style docstrings.
Pydantic AI will infer the format to use based on the docstring, but you can explicitly set it using [`docstring_format`][pydantic_ai.tools.DocstringFormat]. You can also enforce parameter requirements by setting `require_parameter_descriptions=True`. This will raise a [`UserError`][pydantic_ai.exceptions.UserError] if a parameter description is missing. To demonstrate a tool's schema, here we use [`FunctionModel`][pydantic_ai.models.function.FunctionModel] to print the schema a model would receive: @@ -472,7 +473,7 @@ _(This example is complete, it can be run "as is")_ ### Custom Tool Schema -If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of *args or **kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the [`Tool.from_schema`][pydantic_ai.Tool.from_schema] function. With this you provide the name, description and JSON schema for the function directly: +If you have a function that lacks appropriate documentation (i.e. poorly named, no type information, poor docstring, use of \*args or \*\*kwargs and suchlike) then you can still turn it into a tool that can be effectively used by the agent with the [`Tool.from_schema`][pydantic_ai.Tool.from_schema] function. With this you provide the name, description and JSON schema for the function directly: ```python from pydantic_ai import Agent, Tool @@ -505,7 +506,6 @@ print(result.output) #> {"sum":0} ``` - Please note that validation of the tool arguments will not be performed, and this will pass all arguments as keyword arguments. ## Dynamic Tools {#tool-prepare} @@ -515,9 +515,9 @@ customize the definition of the tool passed to the model, or omit the tool compl A `prepare` method can be registered via the `prepare` kwarg to any of the tool registration mechanisms: -* [`@agent.tool`][pydantic_ai.Agent.tool] decorator -* [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator -* [`Tool`][pydantic_ai.tools.Tool] dataclass +- [`@agent.tool`][pydantic_ai.Agent.tool] decorator +- [`@agent.tool_plain`][pydantic_ai.Agent.tool_plain] decorator +- [`Tool`][pydantic_ai.tools.Tool] dataclass The `prepare` method should be of type [`ToolPrepareFunc`][pydantic_ai.tools.ToolPrepareFunc], a function which takes [`RunContext`][pydantic_ai.tools.RunContext] and a pre-built [`ToolDefinition`][pydantic_ai.tools.ToolDefinition], and should either return that `ToolDefinition` with or without modifying it, return a new `ToolDefinition`, or return `None` to indicate this tool should not be registered for that step. @@ -617,7 +617,7 @@ The `prepare_tools` function should be of type [`ToolsPrepareFunc`][pydantic_ai. !!! note The list of tool definitions passed to `prepare_tools` includes both regular function tools and tools from any [toolsets](toolsets.md) registered to the agent, but not [output tools](output.md#tool-output). - To modify output tools, you can set a `prepare_output_tools` function instead. +To modify output tools, you can set a `prepare_output_tools` function instead. Here's an example that makes all tools strict if the model is an OpenAI model: @@ -704,7 +704,7 @@ You can use `prepare_tools` to: If both per-tool `prepare` and agent-wide `prepare_tools` are used, the per-tool `prepare` is applied first to each tool, and then `prepare_tools` is called with the resulting list of tool definitions.
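As a sketch of the per-tool hook described above (the dependency type and the admin check are invented for illustration), a `prepare` function can omit a tool on steps where it shouldn't be offered to the model:

```python
from dataclasses import dataclass
from typing import Union

from pydantic_ai import Agent, RunContext
from pydantic_ai.tools import ToolDefinition


@dataclass
class Deps:
    is_admin: bool  # hypothetical flag used to gate the tool


agent = Agent('openai:gpt-4o', deps_type=Deps)


async def only_for_admins(ctx: RunContext[Deps], tool_def: ToolDefinition) -> Union[ToolDefinition, None]:
    # Return None to omit the tool from this step, or the definition to keep it
    return tool_def if ctx.deps.is_admin else None


@agent.tool(prepare=only_for_admins)
def delete_user(ctx: RunContext[Deps], user_id: int) -> str:
    """Delete a user by id."""
    return f'User {user_id} deleted'
```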
- ## Tool Execution and Retries {#tool-retries} When a tool is executed, its arguments (provided by the LLM) are first validated against the function's signature using Pydantic. If validation fails (e.g., due to incorrect types or missing required arguments), a `ValidationError` is raised, and the framework automatically generates a [`RetryPromptPart`][pydantic_ai.messages.RetryPromptPart] containing the validation details. This prompt is sent back to the LLM, informing it of the error and allowing it to correct the parameters and retry the tool call. @@ -722,6 +721,7 @@ def my_flaky_tool(query: str) -> str: # ... process query ... return 'Success!' ``` + Raising `ModelRetry` also generates a `RetryPromptPart` containing the exception message, which is sent back to the LLM to guide its next attempt. Both `ValidationError` and `ModelRetry` respect the `retries` setting configured on the `Tool` or `Agent`. ## Third-Party Tools diff --git a/docs/troubleshooting.md b/docs/troubleshooting.md index 7f73d0fb43..ae9e350baf 100644 --- a/docs/troubleshooting.md +++ b/docs/troubleshooting.md @@ -1,17 +1,19 @@ # Troubleshooting -Below are suggestions on how to fix some common errors you might encounter while using PydanticAI. If the issue you're experiencing is not listed below or addressed in the documentation, please feel free to ask in the [Pydantic Slack](help.md) or create an issue on [GitHub](https://github.com/pydantic/pydantic-ai/issues). +Below are suggestions on how to fix some common errors you might encounter while using Pydantic AI. If the issue you're experiencing is not listed below or addressed in the documentation, please feel free to ask in the [Pydantic Slack](help.md) or create an issue on [GitHub](https://github.com/pydantic/pydantic-ai/issues). ## Jupyter Notebook Errors ### `RuntimeError: This event loop is already running` -This error is caused by conflicts between the event loops in Jupyter notebook and PydanticAI's. One way to manage these conflicts is by using [`nest-asyncio`](https://pypi.org/project/nest-asyncio/). Namely, before you execute any agent runs, do the following: +This error is caused by conflicts between the event loops in Jupyter notebook and Pydantic AI's. One way to manage these conflicts is by using [`nest-asyncio`](https://pypi.org/project/nest-asyncio/). Namely, before you execute any agent runs, do the following: + ```python {test="skip"} import nest_asyncio nest_asyncio.apply() ``` + Note: This fix also applies to Google Colab and [Marimo](https://github.com/marimo-team/marimo). ## API Key Configuration diff --git a/examples/README.md b/examples/README.md index dbc7e53c86..f9818c9846 100644 --- a/examples/README.md +++ b/examples/README.md @@ -1,4 +1,4 @@ -# PydanticAI Examples +# Pydantic AI Examples [![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain) [![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai) @@ -6,6 +6,6 @@ [![versions](https://img.shields.io/pypi/pyversions/pydantic-ai.svg)](https://github.com/pydantic/pydantic-ai) [![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) -Examples of how to use PydanticAI and what it can do. +Examples of how to use Pydantic AI and what it can do. 
For full documentation of these examples and how to run them, see [ai.pydantic.dev/examples/](https://ai.pydantic.dev/examples/). diff --git a/examples/pydantic_ai_examples/bank_support.py b/examples/pydantic_ai_examples/bank_support.py index d7fc74a4a6..ed73e8f484 100644 --- a/examples/pydantic_ai_examples/bank_support.py +++ b/examples/pydantic_ai_examples/bank_support.py @@ -1,4 +1,4 @@ -"""Small but complete example of using PydanticAI to build a support agent for a bank. +"""Small but complete example of using Pydantic AI to build a support agent for a bank. Run with: diff --git a/examples/pydantic_ai_examples/pydantic_model.py b/examples/pydantic_ai_examples/pydantic_model.py index 2ad754a32b..7a1defbb33 100644 --- a/examples/pydantic_ai_examples/pydantic_model.py +++ b/examples/pydantic_ai_examples/pydantic_model.py @@ -1,4 +1,4 @@ -"""Simple example of using PydanticAI to construct a Pydantic model from a text input. +"""Simple example of using Pydantic AI to construct a Pydantic model from a text input. Run with: diff --git a/examples/pydantic_ai_examples/roulette_wheel.py b/examples/pydantic_ai_examples/roulette_wheel.py index 7df3229d23..a1465581b0 100644 --- a/examples/pydantic_ai_examples/roulette_wheel.py +++ b/examples/pydantic_ai_examples/roulette_wheel.py @@ -1,4 +1,4 @@ -"""Example demonstrating how to use PydanticAI to create a simple roulette game. +"""Example demonstrating how to use Pydantic AI to create a simple roulette game. Run with: uv run -m pydantic_ai_examples.roulette_wheel diff --git a/examples/pydantic_ai_examples/sql_gen.py b/examples/pydantic_ai_examples/sql_gen.py index fdf8c5ff3d..b0010c11c4 100644 --- a/examples/pydantic_ai_examples/sql_gen.py +++ b/examples/pydantic_ai_examples/sql_gen.py @@ -1,4 +1,4 @@ -"""Example demonstrating how to use PydanticAI to generate SQL queries based on user input. +"""Example demonstrating how to use Pydantic AI to generate SQL queries based on user input. Run postgres with: diff --git a/examples/pydantic_ai_examples/weather_agent.py b/examples/pydantic_ai_examples/weather_agent.py index e1794342e5..f02cf854f6 100644 --- a/examples/pydantic_ai_examples/weather_agent.py +++ b/examples/pydantic_ai_examples/weather_agent.py @@ -1,4 +1,4 @@ -"""Example of PydanticAI with multiple tools which the LLM needs to call in turn to answer a question. +"""Example of Pydantic AI with multiple tools which the LLM needs to call in turn to answer a question. In this case the idea is a "weather" agent — the user can ask for the weather in multiple cities, the agent will use the `get_lat_lng` tool to get the latitude and longitude of the locations, then use diff --git a/examples/pyproject.toml b/examples/pyproject.toml index 770c9ba3b2..29720813b9 100644 --- a/examples/pyproject.toml +++ b/examples/pyproject.toml @@ -13,7 +13,7 @@ bump = true [project] name = "pydantic-ai-examples" dynamic = ["version", "dependencies"] -description = "Examples of how to use PydanticAI and what it can do." +description = "Examples of how to use Pydantic AI and what it can do." authors = [ { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, diff --git a/mcp-run-python/README.md b/mcp-run-python/README.md index edd84ddb88..4492f1eb76 100644 --- a/mcp-run-python/README.md +++ b/mcp-run-python/README.md @@ -30,7 +30,7 @@ where: - `warmup` will run a minimal Python script to download and cache the Python standard library. This is also useful to check the server is running correctly. 
-Here's an example of using `@pydantic/mcp-run-python` with PydanticAI: +Here's an example of using `@pydantic/mcp-run-python` with Pydantic AI: ```python from pydantic_ai import Agent diff --git a/mkdocs.yml b/mkdocs.yml index fc6cd27999..796d6601f5 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -1,4 +1,4 @@ -site_name: PydanticAI +site_name: Pydantic AI site_description: Agent Framework / shim to use Pydantic with LLMs strict: true site_url: https://ai.pydantic.dev @@ -114,29 +114,29 @@ extra: generator: false theme: - name: "material" + name: 'material' custom_dir: docs/.overrides palette: - - media: "(prefers-color-scheme)" + - media: '(prefers-color-scheme)' primary: pink accent: pink toggle: icon: material/brightness-auto - name: "Switch to light mode" - - media: "(prefers-color-scheme: light)" + name: 'Switch to light mode' + - media: '(prefers-color-scheme: light)' scheme: default primary: pink accent: pink toggle: icon: material/brightness-7 - name: "Switch to dark mode" - - media: "(prefers-color-scheme: dark)" + name: 'Switch to dark mode' + - media: '(prefers-color-scheme: dark)' scheme: slate primary: pink accent: pink toggle: icon: material/brightness-4 - name: "Switch to system preference" + name: 'Switch to system preference' features: - search.suggest - search.highlight @@ -149,8 +149,8 @@ theme: - navigation.sections - navigation.tracking - toc.follow - logo: "img/logo-white.svg" - favicon: "favicon.ico" + logo: 'img/logo-white.svg' + favicon: 'favicon.ico' # https://www.mkdocs.org/user-guide/configuration/#validation validation: @@ -160,13 +160,13 @@ validation: anchors: warn extra_css: - - "extra/tweaks.css" + - 'extra/tweaks.css' # used for analytics extra_javascript: - - "/flarelytics/client.js" - - "https://cdn.jsdelivr.net/npm/algoliasearch@5.20.0/dist/lite/builds/browser.umd.js" - - "https://cdn.jsdelivr.net/npm/instantsearch.js@4.77.3/dist/instantsearch.production.min.js" - - "/javascripts/algolia-search.js" + - '/flarelytics/client.js' + - 'https://cdn.jsdelivr.net/npm/algoliasearch@5.20.0/dist/lite/builds/browser.umd.js' + - 'https://cdn.jsdelivr.net/npm/instantsearch.js@4.77.3/dist/instantsearch.production.min.js' + - '/javascripts/algolia-search.js' markdown_extensions: - tables @@ -236,7 +236,7 @@ plugins: enabled: !ENV [CI, false] full_output: llms-full.txt markdown_description: |- - PydanticAI is a Python agent framework designed to make it less painful to build production grade + Pydantic AI is a Python agent framework designed to make it less painful to build production grade applications with Generative AI. 
sections: Concepts documentation: @@ -264,5 +264,5 @@ plugins: - examples/*.md hooks: - - "docs/.hooks/main.py" - - "docs/.hooks/algolia.py" + - 'docs/.hooks/main.py' + - 'docs/.hooks/algolia.py' diff --git a/pydantic_ai_slim/README.md b/pydantic_ai_slim/README.md index cdc9442471..bf2aefa82c 100644 --- a/pydantic_ai_slim/README.md +++ b/pydantic_ai_slim/README.md @@ -1,4 +1,4 @@ -# PydanticAI Slim +# Pydantic AI Slim [![CI](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml/badge.svg?event=push)](https://github.com/pydantic/pydantic-ai/actions/workflows/ci.yml?query=branch%3Amain) [![Coverage](https://coverage-badge.samuelcolvin.workers.dev/pydantic/pydantic-ai.svg)](https://coverage-badge.samuelcolvin.workers.dev/redirect/pydantic/pydantic-ai) @@ -6,6 +6,6 @@ [![versions](https://img.shields.io/pypi/pyversions/pydantic-ai-slim.svg)](https://github.com/pydantic/pydantic-ai) [![license](https://img.shields.io/github/license/pydantic/pydantic-ai.svg?v)](https://github.com/pydantic/pydantic-ai/blob/main/LICENSE) -PydanticAI core logic with minimal required dependencies. +Pydantic AI core logic with minimal required dependencies. For more information on how to use this package see [ai.pydantic.dev/install](https://ai.pydantic.dev/install/). diff --git a/pydantic_ai_slim/pydantic_ai/_cli.py b/pydantic_ai_slim/pydantic_ai/_cli.py index c4e3d63fca..ae4f6ff6fe 100644 --- a/pydantic_ai_slim/pydantic_ai/_cli.py +++ b/pydantic_ai_slim/pydantic_ai/_cli.py @@ -38,7 +38,7 @@ from rich.text import Text except ImportError as _import_error: raise ImportError( - 'Please install `rich`, `prompt-toolkit` and `argcomplete` to use the PydanticAI CLI, ' + 'Please install `rich`, `prompt-toolkit` and `argcomplete` to use the Pydantic AI CLI, ' 'you can use the `cli` optional group — `pip install "pydantic-ai-slim[cli]"`' ) from _import_error @@ -47,7 +47,7 @@ PYDANTIC_AI_HOME = Path.home() / '.pydantic-ai' -"""The home directory for PydanticAI CLI. +"""The home directory for Pydantic AI CLI. This folder is used to store the prompt history and configuration. """ @@ -108,7 +108,7 @@ def cli( # noqa: C901 parser = argparse.ArgumentParser( prog=prog_name, description=f"""\ -PydanticAI CLI v{__version__}\n\n +Pydantic AI CLI v{__version__}\n\n Special prompts: * `/exit` - exit the interactive mode (ctrl-c and ctrl-d also work) @@ -153,7 +153,7 @@ def cli( # noqa: C901 args = parser.parse_args(args_list) console = Console() - name_version = f'[green]{prog_name} - PydanticAI CLI v{__version__}[/green]' + name_version = f'[green]{prog_name} - Pydantic AI CLI v{__version__}[/green]' if args.version: console.print(name_version, highlight=False) return 0 diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index f4961da513..459ec56822 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -2,10 +2,10 @@ The manager tracks which parts (in particular, text and tool calls) correspond to which vendor-specific identifiers (e.g., `index`, `tool_call_id`, etc., as appropriate for a given model), -and produces PydanticAI-format events as appropriate for consumers of the streaming APIs. +and produces Pydantic AI-format events as appropriate for consumers of the streaming APIs. The "vendor-specific identifiers" to use depend on the semantics of the responses of the responses from the vendor, -and are tightly coupled to the specific model being used, and the PydanticAI Model subclass implementation. 
+and are tightly coupled to the specific model being used, and the Pydantic AI Model subclass implementation. This `ModelResponsePartsManager` is used in each of the subclasses of `StreamedResponse` as a way to consolidate event-emitting logic. diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index e61f0e9a11..a43771d87b 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -521,7 +521,7 @@ class RetryPromptPart: tool_call_id: str = field(default_factory=_generate_tool_call_id) """The tool call identifier, this is used by some models including OpenAI. - In case the tool call id is not provided by the model, PydanticAI will generate a random one. + In case the tool call id is not provided by the model, Pydantic AI will generate a random one. """ timestamp: datetime = field(default_factory=_now_utc) @@ -562,12 +562,12 @@ def otel_event(self, settings: InstrumentationSettings) -> Event: ModelRequestPart = Annotated[ Union[SystemPromptPart, UserPromptPart, ToolReturnPart, RetryPromptPart], pydantic.Discriminator('part_kind') ] -"""A message part sent by PydanticAI to a model.""" +"""A message part sent by Pydantic AI to a model.""" @dataclass(repr=False) class ModelRequest: - """A request generated by PydanticAI and sent to a model, e.g. a message from the PydanticAI app to the model.""" + """A request generated by Pydantic AI and sent to a model, e.g. a message from the Pydantic AI app to the model.""" parts: list[ModelRequestPart] """The parts of the user message.""" @@ -645,7 +645,7 @@ class ToolCallPart: tool_call_id: str = field(default_factory=_generate_tool_call_id) """The tool call identifier, this is used by some models including OpenAI. - In case the tool call id is not provided by the model, PydanticAI will generate a random one. + In case the tool call id is not provided by the model, Pydantic AI will generate a random one. """ part_kind: Literal['tool-call'] = 'tool-call' @@ -693,7 +693,7 @@ def has_content(self) -> bool: @dataclass(repr=False) class ModelResponse: - """A response from a model, e.g. a message from the model to the PydanticAI app.""" + """A response from a model, e.g. a message from the model to the Pydantic AI app.""" parts: list[ModelResponsePart] """The parts of the model message.""" diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index 6aac2931af..a627415689 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -266,7 +266,7 @@ def _process_response(self, response: BetaMessage) -> ModelResponse: items.append(TextPart(content=item.text)) elif isinstance(item, BetaRedactedThinkingBlock): # pragma: no cover warnings.warn( - 'PydanticAI currently does not handle redacted thinking blocks. ' + 'Pydantic AI currently does not handle redacted thinking blocks. 
' 'If you have a suggestion on how we should handle them, please open an issue.', UserWarning, ) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/google.py b/pydantic_ai_slim/pydantic_ai/profiles/google.py index 95d2969f49..9178d7dd43 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/google.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/google.py @@ -43,7 +43,7 @@ def transform(self, schema: JsonSchema) -> JsonSchema: f' Full schema: {self.schema}\n\n' f'Source of additionalProperties within the full schema: {original_schema}\n\n' 'If this came from a field with a type like `dict[str, MyType]`, that field will always be empty.\n\n' - "If Google's APIs are updated to support this properly, please create an issue on the PydanticAI GitHub" + "If Google's APIs are updated to support this properly, please create an issue on the Pydantic AI GitHub" ' and we will fix this behavior.', UserWarning, ) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index 834b37fc89..bba0f241a3 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -13,7 +13,7 @@ class Usage: """LLM usage associated with a request or run. - Responsibility for calculating usage is on the model; PydanticAI simply sums the usage information across requests. + Responsibility for calculating usage is on the model; Pydantic AI simply sums the usage information across requests. You'll need to look up the documentation of the model you're using to convert usage to monetary costs. """ diff --git a/pydantic_evals/README.md b/pydantic_evals/README.md index eadfde663b..c7528ee500 100644 --- a/pydantic_evals/README.md +++ b/pydantic_evals/README.md @@ -9,18 +9,18 @@ This is a library for evaluating non-deterministic (or "stochastic") functions in Python. It provides a simple, Pythonic interface for defining and running stochastic functions, and analyzing the results of running those functions. -While this library is developed as part of [PydanticAI](https://ai.pydantic.dev), it only uses PydanticAI for a small +While this library is developed as part of [Pydantic AI](https://ai.pydantic.dev), it only uses Pydantic AI for a small subset of generative functionality internally, and it is designed to be used with arbitrary "stochastic function" -implementations. In particular, it can be used with other (non-PydanticAI) AI libraries, agent frameworks, etc. +implementations. In particular, it can be used with other (non-Pydantic AI) AI libraries, agent frameworks, etc. -As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific +As with Pydantic AI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. Full documentation is available at [ai.pydantic.dev/evals](https://ai.pydantic.dev/evals). 
## Example -While you'd typically use Pydantic Evals with more complex functions (such as PydanticAI agents or graphs), here's a +While you'd typically use Pydantic Evals with more complex functions (such as Pydantic AI agents or graphs), here's a quick example that evaluates a simple function against a test case using both custom and built-in evaluators: ```python @@ -68,7 +68,7 @@ report.print(include_input=True, include_output=True) """ ``` -Using the library with more complex functions, such as PydanticAI agents, is similar — all you need to do is define a +Using the library with more complex functions, such as Pydantic AI agents, is similar — all you need to do is define a task function wrapping the function you want to evaluate, with a signature that matches the inputs and outputs of your test cases. diff --git a/pydantic_evals/pydantic_evals/generation.py b/pydantic_evals/pydantic_evals/generation.py index 37212cf21e..c1e68a6ea8 100644 --- a/pydantic_evals/pydantic_evals/generation.py +++ b/pydantic_evals/pydantic_evals/generation.py @@ -47,7 +47,7 @@ async def generate_dataset( path: Optional path to save the generated dataset. If provided, the dataset will be saved to this location. dataset_type: The type of dataset to generate, with the desired input, output, and metadata types. custom_evaluator_types: Optional sequence of custom evaluator classes to include in the schema. - model: The PydanticAI model to use for generation. Defaults to 'gpt-4o'. + model: The Pydantic AI model to use for generation. Defaults to 'gpt-4o'. n_examples: Number of examples to generate. Defaults to 3. extra_instructions: Optional additional instructions to provide to the LLM. @@ -59,7 +59,7 @@ async def generate_dataset( """ output_schema = dataset_type.model_json_schema_with_evaluators(custom_evaluator_types) - # TODO(DavidM): Update this once we add better response_format and/or ResultTool support to PydanticAI + # TODO(DavidM): Update this once we add better response_format and/or ResultTool support to Pydantic AI agent = Agent( model, system_prompt=( diff --git a/pydantic_graph/README.md b/pydantic_graph/README.md index 99f4051807..b942b36654 100644 --- a/pydantic_graph/README.md +++ b/pydantic_graph/README.md @@ -8,10 +8,10 @@ Graph and finite state machine library. -This library is developed as part of [PydanticAI](https://ai.pydantic.dev), however it has no dependency -on `pydantic-ai` or related packages and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using PydanticAI or even building with GenAI. +This library is developed as part of [Pydantic AI](https://ai.pydantic.dev), however it has no dependency +on `pydantic-ai` or related packages and can be considered as a pure graph-based state machine library. You may find it useful whether or not you're using Pydantic AI or even building with GenAI. -As with PydanticAI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. +As with Pydantic AI, this library prioritizes type safety and use of common Python syntax over esoteric, domain-specific use of Python syntax. `pydantic-graph` allows you to define graphs using standard Python syntax. In particular, edges are defined using the return type hint of nodes. 
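As a concrete illustration of that last point, here is a minimal sketch of a two-node graph in the style of the pydantic-graph documentation; it is not part of this patch, and the node names and the divisible-by-five logic are chosen purely for illustration:

```python
from __future__ import annotations

from dataclasses import dataclass

from pydantic_graph import BaseNode, End, Graph, GraphRunContext


@dataclass
class DivisibleBy5(BaseNode[None, None, int]):
    foo: int

    async def run(self, ctx: GraphRunContext) -> Increment | End[int]:
        # Returning `End` finishes the run; returning another node adds an edge to it.
        if self.foo % 5 == 0:
            return End(self.foo)
        else:
            return Increment(self.foo)


@dataclass
class Increment(BaseNode):
    foo: int

    async def run(self, ctx: GraphRunContext) -> DivisibleBy5:
        return DivisibleBy5(self.foo + 1)


fives_graph = Graph(nodes=[DivisibleBy5, Increment])
result = fives_graph.run_sync(DivisibleBy5(4))
print(result.output)
#> 5
```

No edges are declared anywhere: `Graph` derives them from the `run` return type hints, so `DivisibleBy5` can lead to `Increment` or end the run, and `Increment` always leads back to `DivisibleBy5`.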
diff --git a/tests/cassettes/test_mcp/test_tool_returning_text_resource.yaml b/tests/cassettes/test_mcp/test_tool_returning_text_resource.yaml index b1f6508caa..407bd95dbb 100644 --- a/tests/cassettes/test_mcp/test_tool_returning_text_resource.yaml +++ b/tests/cassettes/test_mcp/test_tool_returning_text_resource.yaml @@ -1,341 +1,345 @@ interactions: -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '1882' - content-type: - - application/json - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: Get me the product name - role: user - model: gpt-4o - n: 1 - stream: false - tool_choice: auto - tools: - - function: - description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n - \ Temperature in Fahrenheit\n " - name: celsius_to_fahrenheit - parameters: - properties: - celsius: - type: number - required: - - celsius - type: object - type: function - - function: - description: "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather - forecast for.\n\n Returns:\n The weather forecast for the location.\n " - name: get_weather_forecast - parameters: - properties: - location: - type: string - required: - - location - type: object - type: function - - function: - description: '' - name: get_image_resource - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_product_name - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_image - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_dict - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_error - parameters: - properties: - value: - type: boolean - type: object - type: function - - function: - description: '' - name: get_none - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_multiple_items - parameters: - properties: {} - type: object - type: function - - function: - description: "Get the current log level.\n\n Returns:\n The current log level.\n " - name: get_log_level - parameters: - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '1068' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '3650' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: tool_calls - index: 0 - logprobs: null - message: - annotations: [] - content: null - refusal: null - role: assistant - tool_calls: + - request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '1882' + content-type: + - application/json + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: Get me the product name + role: user + model: gpt-4o + n: 1 + stream: false + tool_choice: auto + tools: - function: - arguments: '{}' + description: + "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ 
Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: + "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' name: get_product_name - id: call_LaiWltzI39sdquflqeuF0EyE + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object + type: function + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '1068' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '3650' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: tool_calls + index: 0 + logprobs: null + message: + annotations: [] + content: null + refusal: null + role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_product_name + id: call_LaiWltzI39sdquflqeuF0EyE + type: function + created: 1745961790 + id: chatcmpl-BRmhyweJVYonarb7s9ckIMSHf2vHo + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 200 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 212 + status: + code: 200 + message: OK + - request: + headers: + accept: + - application/json + accept-encoding: + - gzip, deflate + connection: + - keep-alive + content-length: + - '2117' + content-type: + - application/json + cookie: + - __cf_bm=LCwj6B2rTuTfMe.JFAULcM1w5d9_bQkgyyDVrYXlFWQ-1745961790-1.0.1.1-rLSFIG9L0nbQaHsDaUAe231glaNUGZIodlFyJvNpkdF95kQD8prfC.uNV9.d2ymwvSDsmdB57U6u9ShNfBes9Ev8kn6eYDTHyGzxCeAhZ_o; + _cfuvid=eK9nRUfAL4vFjm9wuH.RIQX41iZHZ8h1LCjqR.nSQzA-1745961790721-0.0.1.1-604800000 + host: + - api.openai.com + method: POST + parsed_body: + messages: + - content: Get me the product name + role: user + - role: assistant + tool_calls: + - function: + arguments: '{}' + name: get_product_name + id: call_LaiWltzI39sdquflqeuF0EyE + type: function + - 
content: Pydantic AI + role: tool + tool_call_id: call_LaiWltzI39sdquflqeuF0EyE + model: gpt-4o + n: 1 + stream: false + tool_choice: auto + tools: + - function: + description: + "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n + \ Temperature in Fahrenheit\n " + name: celsius_to_fahrenheit + parameters: + properties: + celsius: + type: number + required: + - celsius + type: object + type: function + - function: + description: + "Get the weather forecast for a location.\n\n Args:\n location: The location to get the weather + forecast for.\n\n Returns:\n The weather forecast for the location.\n " + name: get_weather_forecast + parameters: + properties: + location: + type: string + required: + - location + type: object + type: function + - function: + description: '' + name: get_image_resource + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_product_name + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_image + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_dict + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_error + parameters: + properties: + value: + type: boolean + type: object + type: function + - function: + description: '' + name: get_none + parameters: + properties: {} + type: object + type: function + - function: + description: '' + name: get_multiple_items + parameters: + properties: {} + type: object + type: function + - function: + description: "Get the current log level.\n\n Returns:\n The current log level.\n " + name: get_log_level + parameters: + properties: {} + type: object type: function - created: 1745961790 - id: chatcmpl-BRmhyweJVYonarb7s9ckIMSHf2vHo - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_f5bdcc3276 - usage: - completion_tokens: 12 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 200 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 212 - status: - code: 200 - message: OK -- request: - headers: - accept: - - application/json - accept-encoding: - - gzip, deflate - connection: - - keep-alive - content-length: - - '2117' - content-type: - - application/json - cookie: - - __cf_bm=LCwj6B2rTuTfMe.JFAULcM1w5d9_bQkgyyDVrYXlFWQ-1745961790-1.0.1.1-rLSFIG9L0nbQaHsDaUAe231glaNUGZIodlFyJvNpkdF95kQD8prfC.uNV9.d2ymwvSDsmdB57U6u9ShNfBes9Ev8kn6eYDTHyGzxCeAhZ_o; - _cfuvid=eK9nRUfAL4vFjm9wuH.RIQX41iZHZ8h1LCjqR.nSQzA-1745961790721-0.0.1.1-604800000 - host: - - api.openai.com - method: POST - parsed_body: - messages: - - content: Get me the product name - role: user - - role: assistant - tool_calls: - - function: - arguments: '{}' - name: get_product_name - id: call_LaiWltzI39sdquflqeuF0EyE - type: function - - content: PydanticAI - role: tool - tool_call_id: call_LaiWltzI39sdquflqeuF0EyE - model: gpt-4o - n: 1 - stream: false - tool_choice: auto - tools: - - function: - description: "Convert Celsius to Fahrenheit.\n\n Args:\n celsius: Temperature in Celsius\n\n Returns:\n - \ Temperature in Fahrenheit\n " - name: celsius_to_fahrenheit - parameters: - properties: - celsius: - type: number - required: - - celsius - type: object - type: function - - function: - description: "Get the weather forecast for a location.\n\n Args:\n 
location: The location to get the weather - forecast for.\n\n Returns:\n The weather forecast for the location.\n " - name: get_weather_forecast - parameters: - properties: - location: - type: string - required: - - location - type: object - type: function - - function: - description: '' - name: get_image_resource - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_product_name - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_image - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_dict - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_error - parameters: - properties: - value: - type: boolean - type: object - type: function - - function: - description: '' - name: get_none - parameters: - properties: {} - type: object - type: function - - function: - description: '' - name: get_multiple_items - parameters: - properties: {} - type: object - type: function - - function: - description: "Get the current log level.\n\n Returns:\n The current log level.\n " - name: get_log_level - parameters: - properties: {} - type: object - type: function - uri: https://api.openai.com/v1/chat/completions - response: - headers: - access-control-expose-headers: - - X-Request-ID - alt-svc: - - h3=":443"; ma=86400 - connection: - - keep-alive - content-length: - - '839' - content-type: - - application/json - openai-organization: - - pydantic-28gund - openai-processing-ms: - - '631' - openai-version: - - '2020-10-01' - strict-transport-security: - - max-age=31536000; includeSubDomains; preload - transfer-encoding: - - chunked - parsed_body: - choices: - - finish_reason: stop - index: 0 - logprobs: null - message: - annotations: [] - content: The product name is "PydanticAI". - refusal: null - role: assistant - created: 1745961791 - id: chatcmpl-BRmhzqXFObpYwSzREMpJvX9kbDikR - model: gpt-4o-2024-08-06 - object: chat.completion - service_tier: default - system_fingerprint: fp_f5bdcc3276 - usage: - completion_tokens: 12 - completion_tokens_details: - accepted_prediction_tokens: 0 - audio_tokens: 0 - reasoning_tokens: 0 - rejected_prediction_tokens: 0 - prompt_tokens: 224 - prompt_tokens_details: - audio_tokens: 0 - cached_tokens: 0 - total_tokens: 236 - status: - code: 200 - message: OK + uri: https://api.openai.com/v1/chat/completions + response: + headers: + access-control-expose-headers: + - X-Request-ID + alt-svc: + - h3=":443"; ma=86400 + connection: + - keep-alive + content-length: + - '839' + content-type: + - application/json + openai-organization: + - pydantic-28gund + openai-processing-ms: + - '631' + openai-version: + - '2020-10-01' + strict-transport-security: + - max-age=31536000; includeSubDomains; preload + transfer-encoding: + - chunked + parsed_body: + choices: + - finish_reason: stop + index: 0 + logprobs: null + message: + annotations: [] + content: The product name is "Pydantic AI". 
+ refusal: null + role: assistant + created: 1745961791 + id: chatcmpl-BRmhzqXFObpYwSzREMpJvX9kbDikR + model: gpt-4o-2024-08-06 + object: chat.completion + service_tier: default + system_fingerprint: fp_f5bdcc3276 + usage: + completion_tokens: 12 + completion_tokens_details: + accepted_prediction_tokens: 0 + audio_tokens: 0 + reasoning_tokens: 0 + rejected_prediction_tokens: 0 + prompt_tokens: 224 + prompt_tokens_details: + audio_tokens: 0 + cached_tokens: 0 + total_tokens: 236 + status: + code: 200 + message: OK version: 1 diff --git a/tests/example_modules/mcp_server.py b/tests/example_modules/mcp_server.py index 67f78cf54a..2c611048a9 100644 --- a/tests/example_modules/mcp_server.py +++ b/tests/example_modules/mcp_server.py @@ -4,7 +4,7 @@ from mcp.server.session import ServerSessionT from mcp.shared.context import LifespanContextT, RequestT -mcp = FastMCP('PydanticAI MCP Server') +mcp = FastMCP('Pydantic AI MCP Server') @mcp.tool() diff --git a/tests/mcp_server.py b/tests/mcp_server.py index 6eb619eded..084cd66e17 100644 --- a/tests/mcp_server.py +++ b/tests/mcp_server.py @@ -8,7 +8,7 @@ from mcp.types import BlobResourceContents, EmbeddedResource, SamplingMessage, TextContent, TextResourceContents from pydantic import AnyUrl -mcp = FastMCP('PydanticAI MCP Server') +mcp = FastMCP('Pydantic AI MCP Server') log_level = 'unset' @@ -70,7 +70,7 @@ async def get_product_name() -> EmbeddedResource: type='resource', resource=TextResourceContents( uri='resource://product_name.txt', # type: ignore - text='PydanticAI', + text='Pydantic AI', ), ) diff --git a/tests/models/cassettes/test_download_item/test_download_item_no_content_type.yaml b/tests/models/cassettes/test_download_item/test_download_item_no_content_type.yaml index ef3decd441..cb95ab3927 100644 --- a/tests/models/cassettes/test_download_item/test_download_item_no_content_type.yaml +++ b/tests/models/cassettes/test_download_item/test_download_item_no_content_type.yaml @@ -1,9 +1,9 @@ interactions: - request: - body: "" + body: '' headers: accept: - - "*/*" + - '*/*' accept-encoding: - gzip, deflate connection: @@ -17,17 +17,17 @@ interactions: string: | # Getting Help - If you need help getting started with PydanticAI or with advanced usage, the following sources may be useful. + If you need help getting started with Pydantic AI or with advanced usage, the following sources may be useful. ## :simple-slack: Slack - Join the `#pydantic-ai` channel in the [Pydantic Slack][slack] to ask questions, get help, and chat about PydanticAI. There's also channels for Pydantic, Logfire, and FastUI. + Join the `#pydantic-ai` channel in the [Pydantic Slack][slack] to ask questions, get help, and chat about Pydantic AI. There's also channels for Pydantic, Logfire, and FastUI. If you're on a [Logfire][logfire] Pro plan, you can also get a dedicated private slack collab channel with us. ## :simple-github: GitHub Issues - The [PydanticAI GitHub Issues][github-issues] are a great place to ask questions and give us feedback. + The [Pydantic AI GitHub Issues][github-issues] are a great place to ask questions and give us feedback. 
[slack]: https://logfire.pydantic.dev/docs/join-slack/ [github-issues]: https://github.com/pydantic/pydantic-ai/issues @@ -36,13 +36,13 @@ interactions: accept-ranges: - bytes access-control-allow-origin: - - "*" + - '*' cache-control: - max-age=300 connection: - keep-alive content-length: - - "737" + - '737' content-security-policy: - default-src 'none'; style-src 'unsafe-inline'; sandbox cross-origin-resource-policy: @@ -52,7 +52,7 @@ interactions: expires: - Sun, 01 Jun 2025 23:17:48 GMT source-age: - - "0" + - '0' strict-transport-security: - max-age=31536000 vary: diff --git a/tests/test_cli.py b/tests/test_cli.py index 8efc0da005..dd0ee175a9 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -30,7 +30,7 @@ def test_cli_version(capfd: CaptureFixture[str]): assert cli(['--version']) == 0 - assert capfd.readouterr().out.startswith('pai - PydanticAI CLI') + assert capfd.readouterr().out.startswith('pai - Pydantic AI CLI') def test_invalid_model(capfd: CaptureFixture[str]): @@ -131,7 +131,7 @@ def test_agent_flag_bad_module_variable_path(capfd: CaptureFixture[str], mocker: def test_list_models(capfd: CaptureFixture[str]): assert cli(['--list-models']) == 0 output = capfd.readouterr().out.splitlines() - assert output[:3] == snapshot([IsStr(regex='pai - PydanticAI CLI .*'), '', 'Available models:']) + assert output[:3] == snapshot([IsStr(regex='pai - Pydantic AI CLI .*'), '', 'Available models:']) providers = ( 'openai', diff --git a/tests/test_mcp.py b/tests/test_mcp.py index e2c5bd0989..94528b40ad 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -371,7 +371,7 @@ async def test_tool_returning_str(allow_model_requests: None, agent: Agent): async def test_tool_returning_text_resource(allow_model_requests: None, agent: Agent): async with agent: result = await agent.run('Get me the product name') - assert result.output == snapshot('The product name is "PydanticAI".') + assert result.output == snapshot('The product name is "Pydantic AI".') assert result.all_messages() == snapshot( [ ModelRequest( @@ -411,14 +411,14 @@ async def test_tool_returning_text_resource(allow_model_requests: None, agent: A parts=[ ToolReturnPart( tool_name='get_product_name', - content='PydanticAI', + content='Pydantic AI', tool_call_id='call_LaiWltzI39sdquflqeuF0EyE', timestamp=IsDatetime(), ) ] ), ModelResponse( - parts=[TextPart(content='The product name is "PydanticAI".')], + parts=[TextPart(content='The product name is "Pydantic AI".')], usage=Usage( requests=1, request_tokens=224, From e6396cc6eb5b7a1c3b67d3d9a1670bc2db774acd Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Thu, 17 Jul 2025 07:34:12 -0700 Subject: [PATCH 29/89] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20method?= =?UTF-8?q?=20`AgentRunResult.=5Fset=5Foutput=5Ftool=5Freturn`=20by=201879?= =?UTF-8?q?8%=20(#2196)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> --- pydantic_ai_slim/pydantic_ai/agent.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index 2b0aeb597e..d0275bcbbf 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -2218,12 +2218,18 @@ def _set_output_tool_return(self, return_content: str) -> list[_messages.ModelMe """ if not self._output_tool_name: raise ValueError('Cannot set output tool return content when the return type is 
`str`.') - messages = deepcopy(self._state.message_history) + + messages = self._state.message_history last_message = messages[-1] - for part in last_message.parts: + for idx, part in enumerate(last_message.parts): if isinstance(part, _messages.ToolReturnPart) and part.tool_name == self._output_tool_name: - part.content = return_content - return messages + # Only do deepcopy when we have to modify + copied_messages = list(messages) + copied_last = deepcopy(last_message) + copied_last.parts[idx].content = return_content # type: ignore[misc] + copied_messages[-1] = copied_last + return copied_messages + raise LookupError(f'No tool call found with tool name {self._output_tool_name!r}.') @overload From 8f63ba3c69db487c9c28ad647cdbe6aff9afdb65 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Thu, 17 Jul 2025 07:43:41 -0700 Subject: [PATCH 30/89] =?UTF-8?q?=E2=9A=A1=EF=B8=8F=20Speed=20up=20method?= =?UTF-8?q?=20`Usage.opentelemetry=5Fattributes`=20by=2085%=20(#2198)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: codeflash-ai[bot] <148906541+codeflash-ai[bot]@users.noreply.github.com> --- pydantic_ai_slim/pydantic_ai/usage.py | 20 +++++++++++++------- tests/test_usage_limits.py | 9 +++++++++ 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/usage.py b/pydantic_ai_slim/pydantic_ai/usage.py index bba0f241a3..c3f4c1885b 100644 --- a/pydantic_ai_slim/pydantic_ai/usage.py +++ b/pydantic_ai_slim/pydantic_ai/usage.py @@ -57,13 +57,19 @@ def __add__(self, other: Usage) -> Usage: def opentelemetry_attributes(self) -> dict[str, int]: """Get the token limits as OpenTelemetry attributes.""" - result = { - 'gen_ai.usage.input_tokens': self.request_tokens, - 'gen_ai.usage.output_tokens': self.response_tokens, - } - for key, value in (self.details or {}).items(): - result[f'gen_ai.usage.details.{key}'] = value # pragma: no cover - return {k: v for k, v in result.items() if v} + result: dict[str, int] = {} + if self.request_tokens: + result['gen_ai.usage.input_tokens'] = self.request_tokens + if self.response_tokens: + result['gen_ai.usage.output_tokens'] = self.response_tokens + details = self.details + if details: + prefix = 'gen_ai.usage.details.' 
+ for key, value in details.items(): + # Skipping check for value since spec implies all detail values are relevant + if value: + result[prefix + key] = value + return result def has_values(self) -> bool: """Whether any values are set and non-zero.""" diff --git a/tests/test_usage_limits.py b/tests/test_usage_limits.py index 6356664c0d..7fb9bba485 100644 --- a/tests/test_usage_limits.py +++ b/tests/test_usage_limits.py @@ -176,6 +176,15 @@ async def delegate_to_other_agent2(ctx: RunContext[None], sentence: str) -> int: # confirm the usage from result2 is the sum of the usage from result1 assert result2.usage() == functools.reduce(operator.add, run_1_usages) + result1_usage = result1.usage() + result1_usage.details = {'custom1': 10, 'custom2': 20, 'custom3': 0} + assert result1_usage.opentelemetry_attributes() == { + 'gen_ai.usage.input_tokens': 103, + 'gen_ai.usage.output_tokens': 13, + 'gen_ai.usage.details.custom1': 10, + 'gen_ai.usage.details.custom2': 20, + } + async def test_multi_agent_usage_sync(): """As in `test_multi_agent_usage_async`, with a sync tool.""" From 61f526033c5499ecb8b106cc4785e319fdbce3bd Mon Sep 17 00:00:00 2001 From: alm Date: Thu, 17 Jul 2025 17:48:08 +0300 Subject: [PATCH 31/89] Nicer errors under the capture_run_messages context (#2219) --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 17 ++++++++++++----- tests/test_agent.py | 12 ++++++++++++ 2 files changed, 24 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index f6b4a51c3f..dbe8d13cb4 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -815,14 +815,21 @@ def capture_run_messages() -> Iterator[list[_messages.ModelMessage]]: If you call `run`, `run_sync`, or `run_stream` more than once within a single `capture_run_messages` context, `messages` will represent the messages exchanged during the first call only. 
""" + token = None + messages: list[_messages.ModelMessage] = [] + + # Try to reuse existing message context if available try: - yield _messages_ctx_var.get().messages + messages = _messages_ctx_var.get().messages except LookupError: - messages: list[_messages.ModelMessage] = [] + # No existing context, create a new one token = _messages_ctx_var.set(_RunMessages(messages)) - try: - yield messages - finally: + + try: + yield messages + finally: + # Clean up context if we created it + if token is not None: _messages_ctx_var.reset(token) diff --git a/tests/test_agent.py b/tests/test_agent.py index c1893fdcb4..df92ccd500 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2673,6 +2673,18 @@ def test_double_capture_run_messages() -> None: ) +def test_capture_run_messages_with_user_exception_does_not_contain_internal_errors() -> None: + """Test that user exceptions within capture_run_messages context have clean stack traces.""" + agent = Agent('test') + + try: + with capture_run_messages(): + agent.run_sync('Hello') + raise ZeroDivisionError('division by zero') + except Exception as e: + assert e.__context__ is None + + def test_dynamic_false_no_reevaluate(): """When dynamic is false (default), the system prompt is not reevaluated i.e: SystemPromptPart( From c568ee91227007ac8c527b1a39b95f80a3a5e91b Mon Sep 17 00:00:00 2001 From: Zach Deane-Mayer <581590+zachmayer@users.noreply.github.com> Date: Thu, 17 Jul 2025 11:22:25 -0400 Subject: [PATCH 32/89] Add OpenAI o1-pro, o3-pro, o3-deep-research, computer-use models (#2234) --- .../pydantic_ai/models/__init__.py | 19 +++++++++++++++++++ pydantic_ai_slim/pydantic_ai/models/openai.py | 4 ++-- pydantic_ai_slim/pyproject.toml | 2 +- uv.lock | 8 ++++---- 4 files changed, 26 insertions(+), 7 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 11ec50f85b..665394af69 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -176,6 +176,7 @@ 'gpt-4o-audio-preview', 'gpt-4o-audio-preview-2024-10-01', 'gpt-4o-audio-preview-2024-12-17', + 'gpt-4o-audio-preview-2025-06-03', 'gpt-4o-mini', 'gpt-4o-mini-2024-07-18', 'gpt-4o-mini-audio-preview', @@ -229,11 +230,18 @@ 'o1-mini-2024-09-12', 'o1-preview', 'o1-preview-2024-09-12', + 'o1-pro', + 'o1-pro-2025-03-19', 'o3', 'o3-2025-04-16', + 'o3-deep-research', + 'o3-deep-research-2025-06-26', 'o3-mini', 'o3-mini-2025-01-31', + 'o3-pro', + 'o3-pro-2025-06-10', 'openai:chatgpt-4o-latest', + 'openai:codex-mini-latest', 'openai:gpt-3.5-turbo', 'openai:gpt-3.5-turbo-0125', 'openai:gpt-3.5-turbo-0301', @@ -266,6 +274,7 @@ 'openai:gpt-4o-audio-preview', 'openai:gpt-4o-audio-preview-2024-10-01', 'openai:gpt-4o-audio-preview-2024-12-17', + 'openai:gpt-4o-audio-preview-2025-06-03', 'openai:gpt-4o-mini', 'openai:gpt-4o-mini-2024-07-18', 'openai:gpt-4o-mini-audio-preview', @@ -280,12 +289,22 @@ 'openai:o1-mini-2024-09-12', 'openai:o1-preview', 'openai:o1-preview-2024-09-12', + 'openai:o1-pro', + 'openai:o1-pro-2025-03-19', 'openai:o3', 'openai:o3-2025-04-16', + 'openai:o3-deep-research', + 'openai:o3-deep-research-2025-06-26', 'openai:o3-mini', 'openai:o3-mini-2025-01-31', 'openai:o4-mini', 'openai:o4-mini-2025-04-16', + 'openai:o4-mini-deep-research', + 'openai:o4-mini-deep-research-2025-06-26', + 'openai:o3-pro', + 'openai:o3-pro-2025-06-10', + 'openai:computer-use-preview', + 'openai:computer-use-preview-2025-03-11', 'test', ], ) diff --git 
a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 92d79c6340..a87f9ef3bd 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -50,7 +50,7 @@ try: from openai import NOT_GIVEN, APIStatusError, AsyncOpenAI, AsyncStream, NotGiven - from openai.types import ChatModel, chat, responses + from openai.types import AllModels, chat, responses from openai.types.chat import ( ChatCompletionChunk, ChatCompletionContentPartImageParam, @@ -80,7 +80,7 @@ 'OpenAIModelName', ) -OpenAIModelName = Union[str, ChatModel] +OpenAIModelName = Union[str, AllModels] """ Possible OpenAI model names. diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 4b62e40d98..279b0e8f87 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -62,7 +62,7 @@ dependencies = [ # WARNING if you add optional groups, please update docs/install.md logfire = ["logfire>=3.11.0"] # Models -openai = ["openai>=1.76.0"] +openai = ["openai>=1.92.0"] cohere = ["cohere>=5.13.11; platform_system != 'Emscripten'"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"] google = ["google-genai>=1.24.0"] diff --git a/uv.lock b/uv.lock index bea12cf7af..28e85eaa67 100644 --- a/uv.lock +++ b/uv.lock @@ -2327,7 +2327,7 @@ wheels = [ [[package]] name = "openai" -version = "1.76.0" +version = "1.97.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -2339,9 +2339,9 @@ dependencies = [ { name = "tqdm" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/84/51/817969ec969b73d8ddad085670ecd8a45ef1af1811d8c3b8a177ca4d1309/openai-1.76.0.tar.gz", hash = "sha256:fd2bfaf4608f48102d6b74f9e11c5ecaa058b60dad9c36e409c12477dfd91fb2", size = 434660, upload-time = "2025-04-23T16:33:53.266Z" } +sdist = { url = "https://files.pythonhosted.org/packages/e0/c6/b8d66e4f3b95493a8957065b24533333c927dc23817abe397f13fe589c6e/openai-1.97.0.tar.gz", hash = "sha256:0be349569ccaa4fb54f97bb808423fd29ccaeb1246ee1be762e0c81a47bae0aa", size = 493850, upload-time = "2025-07-16T16:37:35.196Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/59/aa/84e02ab500ca871eb8f62784426963a1c7c17a72fea3c7f268af4bbaafa5/openai-1.76.0-py3-none-any.whl", hash = "sha256:a712b50e78cf78e6d7b2a8f69c4978243517c2c36999756673e07a14ce37dc0a", size = 661201, upload-time = "2025-04-23T16:33:51.12Z" }, + { url = "https://files.pythonhosted.org/packages/8a/91/1f1cf577f745e956b276a8b1d3d76fa7a6ee0c2b05db3b001b900f2c71db/openai-1.97.0-py3-none-any.whl", hash = "sha256:a1c24d96f4609f3f7f51c9e1c2606d97cc6e334833438659cfd687e9c972c610", size = 764953, upload-time = "2025-07-16T16:37:33.135Z" }, ] [[package]] @@ -3192,7 +3192,7 @@ requires-dist = [ { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "mcp", marker = "python_full_version >= '3.10' and extra == 'mcp'", specifier = ">=1.9.4" }, { name = "mistralai", marker = "extra == 'mistral'", specifier = ">=1.2.5" }, - { name = "openai", marker = "extra == 'openai'", specifier = ">=1.76.0" }, + { name = "openai", marker = "extra == 'openai'", specifier = ">=1.92.0" }, { name = "opentelemetry-api", specifier = ">=1.28.0" }, { name = "prompt-toolkit", marker = "extra == 'cli'", specifier = ">=3" }, { name = "pydantic", specifier = ">=2.10" }, From 0b3bf866172f724d9919b7e11e1e240de8111b8d Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Thu, 17 Jul 2025 19:11:54 +0100 
Subject: [PATCH 33/89] chore(mistral): disable model_fields deprecation warning (#2224) Co-authored-by: Marcelo Trylesinski --- pyproject.toml | 2 ++ 1 file changed, 2 insertions(+) diff --git a/pyproject.toml b/pyproject.toml index b2b82aa867..c3d4ecc60e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -200,6 +200,8 @@ filterwarnings = [ "error", # Issue with python-multipart - we don't want to bump the minimum version of starlette. "ignore::PendingDeprecationWarning:starlette", + # mistralai accesses model_fields on the instance, which is deprecated in Pydantic 2.11. + "ignore:Accessing the 'model_fields' attribute", # boto3 "ignore::DeprecationWarning:botocore.*", "ignore::RuntimeWarning:pydantic_ai.mcp", From 490c3b42906cccf863687aa94bd95223a13d7445 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 17 Jul 2025 17:37:09 -0600 Subject: [PATCH 34/89] Toolsets public interface and docs tweaks (#2241) --- docs/toolsets.md | 25 +++++++++++-------- pydantic_ai_slim/pydantic_ai/output.py | 5 +++- pydantic_ai_slim/pydantic_ai/tools.py | 4 +-- .../pydantic_ai/toolsets/abstract.py | 16 +++--------- .../pydantic_ai/toolsets/combined.py | 2 +- .../pydantic_ai/toolsets/deferred.py | 2 +- .../pydantic_ai/toolsets/wrapper.py | 4 +-- 7 files changed, 27 insertions(+), 31 deletions(-) diff --git a/docs/toolsets.md b/docs/toolsets.md index fa7073798b..066d54ef90 100644 --- a/docs/toolsets.md +++ b/docs/toolsets.md @@ -334,9 +334,7 @@ print(test_model.last_model_request_parameters.function_tools) [`WrapperToolset`][pydantic_ai.toolsets.WrapperToolset] wraps another toolset and delegates all responsibility to it. -To easily chain different modifications, you can also call [`wrap()`][pydantic_ai.toolsets.AbstractToolset.wrap] on any toolset instead of directly constructing an instance of (a subclass of) `WrapperToolset`. - -`WrapperToolset` is a no-op by default, but enables some useful abilities: +It is is a no-op by default, but enables some useful abilities: #### Changing Tool Execution @@ -367,7 +365,7 @@ class LoggingToolset(WrapperToolset): return result -logging_toolset = prepared_toolset.wrap(LoggingToolset) +logging_toolset = LoggingToolset(prepared_toolset) agent = Agent(TestModel(), toolsets=[logging_toolset]) # (1)! result = agent.run_sync('Call all the tools') @@ -438,14 +436,17 @@ If you want to reuse a network connection or session across tool listings and ca ### Deferred Toolset -A deferred tool is one that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via a protocol like [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools). +A deferred tool is one whose result will be produced outside of the Pydantic AI agent run in which it was called, because it depends on an upstream service (or user) or could take longer to generate than it's reasonable to keep the agent process running. + +Deferred tools enable various use cases: -!!! note - This is not typically something you need to bother with, unless you are implementing support for such a protocol between an upstream tool provider and Pydantic AI. +- Support client-side tools implemented by a web or app frontend +- Implement a Human-in-the-Loop flow where the user needs to explicitly provide an "answer" before the run can continue +- Pass slow tasks off to a background worker or external service that will send a (webhook) notification when the result is ready and the agent run can be continued. 
-When the model calls a deferred tool, the agent run ends with a [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] object containing the deferred tool call names and arguments, which is expected to be returned to the upstream tool provider. This upstream service is then expected to generate a response for each tool call and start a new Pydantic AI agent run with the message history and new [`ToolReturnPart`s][pydantic_ai.messages.ToolReturnPart] corresponding to each deferred call, after which the run will continue. +When the model calls a deferred tool, the agent run ends with a [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] object containing the deferred tool call names and arguments, which are expected to be returned to the service that will (eventually) produce the result(s). Once all the results are ready, a new Pydantic AI agent run can then be started with the original run's message history plus new [`ToolReturnPart`s][pydantic_ai.messages.ToolReturnPart] (or [`RetryPromptPart`s][pydantic_ai.messages.RetryPromptPart] in case of failure) corresponding to each deferred call, after which the run will continue. -To enable an agent to call deferred tools, you create a [`DeferredToolset`][pydantic_ai.toolsets.DeferredToolset], pass it a list of [`ToolDefinition`s][pydantic_ai.tools.ToolDefinition], and provide it to the agent using one of the methods described above. Additionally, you need to add `DeferredToolCalls` to the `Agent`'s [output types](output.md#structured-output) so that the agent run's output type is correctly inferred. Finally, you should handle the possible `DeferredToolCalls` result by returning it to the upstream tool provider. +To enable an agent to call deferred tools, you create a [`DeferredToolset`][pydantic_ai.toolsets.DeferredToolset], pass it a list of [`ToolDefinition`s][pydantic_ai.tools.ToolDefinition], and provide it to the agent using one of the methods described above. Additionally, you need to add `DeferredToolCalls` to the `Agent`'s [`output_type`](output.md#structured-output) so that the possible types of the agent run output are correctly inferred. Finally, you should handle the possible `DeferredToolCalls` output by passing it to the service that will produce the results. If your agent can also be used in a context where no deferred tools are available, you will not want to include `DeferredToolCalls` in the `output_type` passed to the `Agent` constructor as you'd have to deal with that type everywhere you use the agent. Instead, you can pass the `toolsets` and `output_type` keyword arguments when you run the agent using [`agent.run()`][pydantic_ai.Agent.run], [`agent.run_sync()`][pydantic_ai.Agent.run_sync], [`agent.run_stream()`][pydantic_ai.Agent.run_stream], or [`agent.iter()`][pydantic_ai.Agent.iter]. Note that while `toolsets` provided at this stage are additional to the toolsets provided to the constructor, the `output_type` overrides the one specified at construction time (for type inference reasons), so you'll need to include the original output types explicitly. @@ -482,7 +483,7 @@ print(repr(result.output)) #> PersonalizedGreeting(greeting='Hello, David!', language_code='en-US') ``` -Next, let's define an function for a hypothetical "run agent" API endpoint that can be called by the frontend and takes a list of messages to send to the model plus a dict of frontend tool names and descriptions. 
This is where `DeferredToolset` and `DeferredToolCalls` come in: +Next, let's define a function that represents a hypothetical "run agent" API endpoint that can be called by the frontend and takes a list of messages to send to the model plus a list of frontend tool definitions. This is where `DeferredToolset` and `DeferredToolCalls` come in: ```python {title="deferred_toolset_api.py" requires="deferred_toolset_agent.py"} from deferred_toolset_agent import agent, PersonalizedGreeting @@ -526,8 +527,10 @@ frontend_tool_definitions = [ description="Get the user's preferred language from their browser", ) ] + def get_preferred_language(default_language: str) -> str: return 'es-MX' # (1)! + frontend_tool_functions = {'get_preferred_language': get_preferred_language} messages: list[ModelMessage] = [ @@ -578,7 +581,7 @@ PersonalizedGreeting(greeting='Hola, David! Espero que tengas un gran día!', la """ ``` -1. Imagine that this returns [`navigator.language`](https://developer.mozilla.org/en-US/docs/Web/API/Navigator/language) +1. Imagine that this returns the frontend [`navigator.language`](https://developer.mozilla.org/en-US/docs/Web/API/Navigator/language). _(This example is complete, it can be run "as is")_ diff --git a/pydantic_ai_slim/pydantic_ai/output.py b/pydantic_ai_slim/pydantic_ai/output.py index 9c41b535db..d61eb61748 100644 --- a/pydantic_ai_slim/pydantic_ai/output.py +++ b/pydantic_ai_slim/pydantic_ai/output.py @@ -361,7 +361,10 @@ def __get_pydantic_json_schema__( @dataclass class DeferredToolCalls: - """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools.""" + """Container for calls of deferred tools. This can be used as an agent's `output_type` and will be used as the output of the agent run if the model called any deferred tools. + + See [deferred toolset docs](../toolsets.md#deferred-toolset) for more information. + """ tool_calls: list[ToolCallPart] tool_defs: dict[str, ToolDefinition] diff --git a/pydantic_ai_slim/pydantic_ai/tools.py b/pydantic_ai_slim/pydantic_ai/tools.py index 4243c02971..14b5d12f3a 100644 --- a/pydantic_ai_slim/pydantic_ai/tools.py +++ b/pydantic_ai_slim/pydantic_ai/tools.py @@ -365,9 +365,9 @@ class ToolDefinition: kind: ToolKind = field(default='function') """The kind of tool: - - `'function'`: a tool that can be executed by Pydantic AI and has its result returned to the model + - `'function'`: a tool that will be executed by Pydantic AI during an agent run and has its result returned to the model - `'output'`: a tool that passes through an output value that ends the run - - `'deferred'`: a tool that will be executed not by Pydantic AI, but by the upstream service that called the agent, such as a web application that supports frontend-defined tools provided to Pydantic AI via e.g. [AG-UI](https://docs.ag-ui.com/concepts/tools#frontend-defined-tools). + - `'deferred'`: a tool whose result will be produced outside of the Pydantic AI agent run in which it was called, because it depends on an upstream service (or user) or could take longer to generate than it's reasonable to keep the agent process running. When the model calls a deferred tool, the agent run ends with a `DeferredToolCalls` object and a new run is expected to be started at a later point with the message history and new `ToolReturnPart`s corresponding to each deferred call. 
""" diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py index 0f19eec3bc..455336418f 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/abstract.py @@ -2,7 +2,7 @@ from abc import ABC, abstractmethod from dataclasses import dataclass -from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol, TypeVar +from typing import TYPE_CHECKING, Any, Callable, Generic, Literal, Protocol from pydantic_core import SchemaValidator from typing_extensions import Self @@ -15,9 +15,6 @@ from .prefixed import PrefixedToolset from .prepared import PreparedToolset from .renamed import RenamedToolset - from .wrapper import WrapperToolset - -WrapperT = TypeVar('WrapperT', bound='WrapperToolset[Any]') class SchemaValidatorProt(Protocol): @@ -115,9 +112,9 @@ async def call_tool( """ raise NotImplementedError() - def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: """Run a visitor function on all concrete toolsets that are not wrappers (i.e. they implement their own tool listing and calling).""" - return visitor(self) + visitor(self) def filtered( self, filter_func: Callable[[RunContext[AgentDepsT], ToolDefinition], bool] @@ -156,10 +153,3 @@ def renamed(self, name_map: dict[str, str]) -> RenamedToolset[AgentDepsT]: from .renamed import RenamedToolset return RenamedToolset(self, name_map) - - def wrap(self, wrapper_cls: type[WrapperT], *args: Any, **kwargs: Any) -> WrapperT: - """Returns an instance of the provided wrapper class wrapping this toolset, with all arguments passed to the wrapper class constructor. - - See [toolset docs](../toolsets.md#wrapping-a-toolset) for more information. - """ - return wrapper_cls(self, *args, **kwargs) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py index a083477196..4b1511fae1 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py @@ -83,6 +83,6 @@ async def call_tool( assert isinstance(tool, _CombinedToolsetTool) return await tool.source_toolset.call_tool(name, tool_args, ctx, tool.source_tool) - def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: for toolset in self.toolsets: toolset.apply(visitor) diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py index 29964e9333..3ad2e976ba 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/deferred.py @@ -14,7 +14,7 @@ @dataclass class DeferredToolset(AbstractToolset[AgentDepsT]): - """A toolset that holds deferred tools that will be called by the upstream service that called the agent. + """A toolset that holds deferred tools whose results will be produced outside of the Pydantic AI agent run in which they were called. See [toolset docs](../toolsets.md#deferred-toolset), [`ToolDefinition.kind`][pydantic_ai.tools.ToolDefinition.kind], and [`DeferredToolCalls`][pydantic_ai.output.DeferredToolCalls] for more information. 
""" diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py index 1dddd96a51..8440f1c466 100644 --- a/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py +++ b/pydantic_ai_slim/pydantic_ai/toolsets/wrapper.py @@ -33,5 +33,5 @@ async def call_tool( ) -> Any: return await self.wrapped.call_tool(name, tool_args, ctx, tool) - def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], Any]) -> Any: - return self.wrapped.apply(visitor) + def apply(self, visitor: Callable[[AbstractToolset[AgentDepsT]], None]) -> None: + self.wrapped.apply(visitor) From babc23b30a94cbfa7d6ee9cebba25309b0cf8360 Mon Sep 17 00:00:00 2001 From: Forge <64839751+GDaamn@users.noreply.github.com> Date: Fri, 18 Jul 2025 16:09:59 +0200 Subject: [PATCH 35/89] add identifier field to BinaryContent class (#2231) Co-authored-by: Douwe Maan --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 2 +- pydantic_ai_slim/pydantic_ai/messages.py | 8 +++ tests/test_agent.py | 56 +++++++++++++++++++- 3 files changed, 63 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index dbe8d13cb4..312a8a2fca 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -759,7 +759,7 @@ def process_content(content: Any) -> Any: ) elif isinstance(content, _messages.MultiModalContentTypes): if isinstance(content, _messages.BinaryContent): - identifier = multi_modal_content_identifier(content.data) + identifier = content.identifier or multi_modal_content_identifier(content.data) else: identifier = multi_modal_content_identifier(content.url) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index a43771d87b..731ccb5ca6 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -282,6 +282,14 @@ class BinaryContent: media_type: AudioMediaType | ImageMediaType | DocumentMediaType | str """The media type of the binary data.""" + identifier: str | None = None + """Identifier for the binary content, such as a URL or unique ID. + + This identifier can be provided to the model in a message to allow it to refer to this file in a tool call argument, and the tool can look up the file in question by iterating over the message history and finding the matching `BinaryContent`. + + This identifier is only automatically passed to the model when the `BinaryContent` is returned by a tool. If you're passing the `BinaryContent` as a user message, it's up to you to include a separate text part with the identifier, e.g. "This is file :" preceding the `BinaryContent`. + """ + vendor_metadata: dict[str, Any] | None = None """Vendor-specific metadata for the file. 
diff --git a/tests/test_agent.py b/tests/test_agent.py index df92ccd500..a5a9285d52 100644 --- a/tests/test_agent.py +++ b/tests/test_agent.py @@ -2938,7 +2938,13 @@ def test_binary_content_all_messages_json(): { 'content': [ 'Hello', - {'data': 'SGVsbG8=', 'media_type': 'text/plain', 'vendor_metadata': None, 'kind': 'binary'}, + { + 'data': 'SGVsbG8=', + 'media_type': 'text/plain', + 'vendor_metadata': None, + 'kind': 'binary', + 'identifier': None, + }, ], 'timestamp': IsStr(), 'part_kind': 'user-prompt', @@ -2973,7 +2979,7 @@ def test_binary_content_all_messages_json(): def test_tool_return_part_binary_content_serialization(): """Test that ToolReturnPart can properly serialize BinaryContent.""" png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82' - binary_content = BinaryContent(png_data, media_type='image/png') + binary_content = BinaryContent(png_data, media_type='image/png', identifier='image_id_1') tool_return = ToolReturnPart(tool_name='test_tool', content=binary_content, tool_call_id='test_call_123') @@ -2982,10 +2988,12 @@ def test_tool_return_part_binary_content_serialization(): assert '"kind":"binary"' in response_str assert '"media_type":"image/png"' in response_str assert '"data":"' in response_str + assert '"identifier":"image_id_1"' in response_str response_obj = tool_return.model_response_object() assert response_obj['return_value']['kind'] == 'binary' assert response_obj['return_value']['media_type'] == 'image/png' + assert response_obj['return_value']['identifier'] == 'image_id_1' assert 'data' in response_obj['return_value'] @@ -3011,6 +3019,50 @@ def get_image() -> BinaryContent: assert result.output == 'Image received' +def test_tool_returning_binary_content_with_identifier(): + """Test that a tool returning BinaryContent directly works correctly.""" + + def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse(parts=[ToolCallPart('get_image', {})]) + else: + return ModelResponse(parts=[TextPart('Image received')]) + + agent = Agent(FunctionModel(llm)) + + @agent.tool_plain + def get_image() -> BinaryContent: + """Return a simple image.""" + png_data = b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82' + return BinaryContent(png_data, media_type='image/png', identifier='image_id_1') + + # This should work without the serialization error + result = agent.run_sync('Get an image') + assert result.all_messages()[2] == snapshot( + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_image', + content='See file image_id_1', + tool_call_id=IsStr(), + timestamp=IsNow(tz=timezone.utc), + ), + UserPromptPart( + content=[ + 'This is file image_id_1:', + BinaryContent( + data=b'\x89PNG\r\n\x1a\n\x00\x00\x00\rIHDR\x00\x00\x00\x01\x00\x00\x00\x01\x08\x02\x00\x00\x00\x90wS\xde\x00\x00\x00\x0cIDATx\x9cc```\x00\x00\x00\x04\x00\x01\xf6\x178\x00\x00\x00\x00IEND\xaeB`\x82', + media_type='image/png', + identifier='image_id_1', + ), + ], + timestamp=IsNow(tz=timezone.utc), + ), + ] + ) + ) + + def test_instructions_raise_error_when_system_prompt_is_set(): agent = Agent('test', instructions='An instructions!') From 3c43c2dc5d3911fe925fc45e2b76a332ad592e24 Mon Sep 17 00:00:00 2001 From: Zach Deane-Mayer 
<581590+zachmayer@users.noreply.github.com> Date: Fri, 18 Jul 2025 10:30:35 -0400 Subject: [PATCH 36/89] Add grok-4 and groq kimi-k2 models (#2235) --- pydantic_ai_slim/pydantic_ai/models/__init__.py | 9 +++++++++ pydantic_ai_slim/pydantic_ai/models/groq.py | 1 + .../pydantic_ai/profiles/moonshotai.py | 8 ++++++++ pydantic_ai_slim/pydantic_ai/providers/grok.py | 14 +++++++++++++- pydantic_ai_slim/pydantic_ai/providers/groq.py | 2 ++ .../test_model_names/test_known_model_names.yaml | 3 ++- tests/models/test_model.py | 16 ++++++++++++++++ tests/models/test_model_names.py | 3 +++ tests/providers/test_groq.py | 7 +++++++ tests/test_cli.py | 1 + 10 files changed, 62 insertions(+), 2 deletions(-) create mode 100644 pydantic_ai_slim/pydantic_ai/profiles/moonshotai.py diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 665394af69..6193d1d41f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -185,6 +185,14 @@ 'gpt-4o-mini-search-preview-2025-03-11', 'gpt-4o-search-preview', 'gpt-4o-search-preview-2025-03-11', + 'grok:grok-4', + 'grok:grok-4-0709', + 'grok:grok-3', + 'grok:grok-3-mini', + 'grok:grok-3-fast', + 'grok:grok-3-mini-fast', + 'grok:grok-2-vision-1212', + 'grok:grok-2-image-1212', 'groq:distil-whisper-large-v3-en', 'groq:gemma2-9b-it', 'groq:llama-3.3-70b-versatile', @@ -192,6 +200,7 @@ 'groq:llama-guard-3-8b', 'groq:llama3-70b-8192', 'groq:llama3-8b-8192', + 'groq:moonshotai/kimi-k2-instruct', 'groq:whisper-large-v3', 'groq:whisper-large-v3-turbo', 'groq:playai-tts', diff --git a/pydantic_ai_slim/pydantic_ai/models/groq.py b/pydantic_ai_slim/pydantic_ai/models/groq.py index bfdb1d3792..92376b44de 100644 --- a/pydantic_ai_slim/pydantic_ai/models/groq.py +++ b/pydantic_ai_slim/pydantic_ai/models/groq.py @@ -79,6 +79,7 @@ 'llama-3.2-3b-preview', 'llama-3.2-11b-vision-preview', 'llama-3.2-90b-vision-preview', + 'moonshotai/kimi-k2-instruct', ] """Preview Groq models from .""" diff --git a/pydantic_ai_slim/pydantic_ai/profiles/moonshotai.py b/pydantic_ai_slim/pydantic_ai/profiles/moonshotai.py new file mode 100644 index 0000000000..006b42bb6b --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/profiles/moonshotai.py @@ -0,0 +1,8 @@ +from __future__ import annotations as _annotations + +from . 
import ModelProfile + + +def moonshotai_model_profile(model_name: str) -> ModelProfile | None: + """Get the model profile for a MoonshotAI model.""" + return None diff --git a/pydantic_ai_slim/pydantic_ai/providers/grok.py b/pydantic_ai_slim/pydantic_ai/providers/grok.py index c232a29779..379b49795d 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/grok.py +++ b/pydantic_ai_slim/pydantic_ai/providers/grok.py @@ -1,7 +1,7 @@ from __future__ import annotations as _annotations import os -from typing import overload +from typing import Literal, overload from httpx import AsyncClient as AsyncHTTPClient from openai import AsyncOpenAI @@ -21,6 +21,18 @@ 'you can use the `openai` optional group — `pip install "pydantic-ai-slim[openai]"`' ) from _import_error +# https://docs.x.ai/docs/models +GrokModelName = Literal[ + 'grok-4', + 'grok-4-0709', + 'grok-3', + 'grok-3-mini', + 'grok-3-fast', + 'grok-3-mini-fast', + 'grok-2-vision-1212', + 'grok-2-image-1212', +] + class GrokProvider(Provider[AsyncOpenAI]): """Provider for Grok API.""" diff --git a/pydantic_ai_slim/pydantic_ai/providers/groq.py b/pydantic_ai_slim/pydantic_ai/providers/groq.py index 34968b2566..9deada0b09 100644 --- a/pydantic_ai_slim/pydantic_ai/providers/groq.py +++ b/pydantic_ai_slim/pydantic_ai/providers/groq.py @@ -12,6 +12,7 @@ from pydantic_ai.profiles.google import google_model_profile from pydantic_ai.profiles.meta import meta_model_profile from pydantic_ai.profiles.mistral import mistral_model_profile +from pydantic_ai.profiles.moonshotai import moonshotai_model_profile from pydantic_ai.profiles.qwen import qwen_model_profile from pydantic_ai.providers import Provider @@ -47,6 +48,7 @@ def model_profile(self, model_name: str) -> ModelProfile | None: 'qwen': qwen_model_profile, 'deepseek': deepseek_model_profile, 'mistral': mistral_model_profile, + 'moonshotai/': moonshotai_model_profile, } for prefix, profile_func in prefix_to_profile.items(): diff --git a/tests/models/cassettes/test_model_names/test_known_model_names.yaml b/tests/models/cassettes/test_model_names/test_known_model_names.yaml index 84f28587d2..4fc94b1460 100644 --- a/tests/models/cassettes/test_model_names/test_known_model_names.yaml +++ b/tests/models/cassettes/test_model_names/test_known_model_names.yaml @@ -15,7 +15,7 @@ interactions: response: headers: content-length: - - '545' + - '550' content-security-policy: - default-src 'none'; frame-ancestors 'none' content-type: @@ -46,6 +46,7 @@ interactions: - text-to-text - model_id: claude-4-sonnet regions: + - eu - us type: - text-to-text diff --git a/tests/models/test_model.py b/tests/models/test_model.py index 987e4ae728..32cbf5d3d2 100644 --- a/tests/models/test_model.py +++ b/tests/models/test_model.py @@ -72,6 +72,22 @@ 'github', 'OpenAIModel', ), + ( + 'GROK_API_KEY', + 'grok:grok-3', + 'grok-3', + 'grok', + 'grok', + 'OpenAIModel', + ), + ( + 'GROK_API_KEY', + 'grok-4', # Note that the provider and model name are both "grok", so the plain string grok with no prefix works because its also the provider name + 'grok-4', + 'grok', + 'grok', + 'OpenAIModel', + ), ] diff --git a/tests/models/test_model_names.py b/tests/models/test_model_names.py index db6f22cd8d..01166879d7 100644 --- a/tests/models/test_model_names.py +++ b/tests/models/test_model_names.py @@ -19,6 +19,7 @@ from pydantic_ai.models.huggingface import HuggingFaceModelName from pydantic_ai.models.mistral import MistralModelName from pydantic_ai.models.openai import OpenAIModelName + from pydantic_ai.providers.grok import GrokModelName 
pytestmark = [ pytest.mark.skipif(not imports_successful(), reason='some model package was not installed'), @@ -48,6 +49,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: google_names = [f'google-gla:{n}' for n in get_model_names(GeminiModelName)] + [ f'google-vertex:{n}' for n in get_model_names(GeminiModelName) ] + grok_names = [f'grok:{n}' for n in get_model_names(GrokModelName)] groq_names = [f'groq:{n}' for n in get_model_names(GroqModelName)] mistral_names = [f'mistral:{n}' for n in get_model_names(MistralModelName)] openai_names = [f'openai:{n}' for n in get_model_names(OpenAIModelName)] + [ @@ -63,6 +65,7 @@ def get_model_names(model_name_type: Any) -> Iterator[str]: anthropic_names + cohere_names + google_names + + grok_names + groq_names + mistral_names + openai_names diff --git a/tests/providers/test_groq.py b/tests/providers/test_groq.py index cdb6717f47..0f059f96de 100644 --- a/tests/providers/test_groq.py +++ b/tests/providers/test_groq.py @@ -12,6 +12,7 @@ from pydantic_ai.profiles.google import GoogleJsonSchemaTransformer, google_model_profile from pydantic_ai.profiles.meta import meta_model_profile from pydantic_ai.profiles.mistral import mistral_model_profile +from pydantic_ai.profiles.moonshotai import moonshotai_model_profile from pydantic_ai.profiles.qwen import qwen_model_profile from ..conftest import TestEnv, try_import @@ -74,6 +75,7 @@ def test_groq_provider_model_profile(mocker: MockerFixture): google_model_profile_mock = mocker.patch(f'{ns}.google_model_profile', wraps=google_model_profile) mistral_model_profile_mock = mocker.patch(f'{ns}.mistral_model_profile', wraps=mistral_model_profile) qwen_model_profile_mock = mocker.patch(f'{ns}.qwen_model_profile', wraps=qwen_model_profile) + moonshotai_model_profile_mock = mocker.patch(f'{ns}.moonshotai_model_profile', wraps=moonshotai_model_profile) meta_profile = provider.model_profile('meta-llama/Llama-Guard-4-12B') meta_model_profile_mock.assert_called_with('llama-guard-4-12b') @@ -103,5 +105,10 @@ def test_groq_provider_model_profile(mocker: MockerFixture): assert qwen_profile is not None assert qwen_profile.json_schema_transformer == InlineDefsJsonSchemaTransformer + # MoonshotAI model should remove the "moonshotai/" prefix before passing to profile + moonshotai_profile = provider.model_profile('moonshotai/kimi-k2-instruct') + moonshotai_model_profile_mock.assert_called_with('kimi-k2-instruct') + assert moonshotai_profile is None + unknown_profile = provider.model_profile('unknown-model') assert unknown_profile is None diff --git a/tests/test_cli.py b/tests/test_cli.py index dd0ee175a9..fa2566f3f0 100644 --- a/tests/test_cli.py +++ b/tests/test_cli.py @@ -144,6 +144,7 @@ def test_list_models(capfd: CaptureFixture[str]): 'cohere', 'deepseek', 'heroku', + 'grok', 'huggingface', ) models = {line.strip().split(' ')[0] for line in output[3:]} From 431ec33d11717e587dab01470dc151934c118fda Mon Sep 17 00:00:00 2001 From: Steven Hartland Date: Fri, 18 Jul 2025 18:11:25 +0100 Subject: [PATCH 37/89] Support AG-UI protocol for frontend-agent communication (#2223) Co-authored-by: Douwe Maan --- docs/ag-ui.md | 187 ++++ docs/api/ag_ui.md | 3 + docs/examples/ag-ui.md | 204 ++++ docs/install.md | 1 + docs/tools.md | 2 +- .../pydantic_ai_examples/ag_ui/__init__.py | 41 + .../pydantic_ai_examples/ag_ui/__main__.py | 9 + .../ag_ui/api/__init__.py | 19 + .../ag_ui/api/agentic_chat.py | 25 + .../ag_ui/api/agentic_generative_ui.py | 119 +++ .../ag_ui/api/human_in_the_loop.py | 26 + 
.../ag_ui/api/predictive_state_updates.py | 77 ++ .../ag_ui/api/shared_state.py | 137 +++ .../ag_ui/api/tool_based_generative_ui.py | 11 + examples/pydantic_ai_examples/py.typed | 0 examples/pyproject.toml | 2 +- mkdocs.yml | 4 + pydantic_ai_slim/pydantic_ai/ag_ui.py | 675 +++++++++++++ pydantic_ai_slim/pydantic_ai/agent.py | 105 +- .../pydantic_ai/models/function.py | 74 +- .../pydantic_ai/models/mistral.py | 2 +- pydantic_ai_slim/pyproject.toml | 2 + pyproject.toml | 2 +- tests/conftest.py | 6 + tests/test_ag_ui.py | 939 ++++++++++++++++++ uv.lock | 28 +- 26 files changed, 2665 insertions(+), 35 deletions(-) create mode 100644 docs/ag-ui.md create mode 100644 docs/api/ag_ui.md create mode 100644 docs/examples/ag-ui.md create mode 100644 examples/pydantic_ai_examples/ag_ui/__init__.py create mode 100644 examples/pydantic_ai_examples/ag_ui/__main__.py create mode 100644 examples/pydantic_ai_examples/ag_ui/api/__init__.py create mode 100644 examples/pydantic_ai_examples/ag_ui/api/agentic_chat.py create mode 100644 examples/pydantic_ai_examples/ag_ui/api/agentic_generative_ui.py create mode 100644 examples/pydantic_ai_examples/ag_ui/api/human_in_the_loop.py create mode 100644 examples/pydantic_ai_examples/ag_ui/api/predictive_state_updates.py create mode 100644 examples/pydantic_ai_examples/ag_ui/api/shared_state.py create mode 100644 examples/pydantic_ai_examples/ag_ui/api/tool_based_generative_ui.py create mode 100644 examples/pydantic_ai_examples/py.typed create mode 100644 pydantic_ai_slim/pydantic_ai/ag_ui.py create mode 100644 tests/test_ag_ui.py diff --git a/docs/ag-ui.md b/docs/ag-ui.md new file mode 100644 index 0000000000..918a03a40a --- /dev/null +++ b/docs/ag-ui.md @@ -0,0 +1,187 @@ +# Agent User Interaction (AG-UI) Protocol + +The [Agent User Interaction (AG-UI) Protocol](https://docs.ag-ui.com/introduction) is an open standard introduced by the +[CopilotKit](https://webflow.copilotkit.ai/blog/introducing-ag-ui-the-protocol-where-agents-meet-users) +team that standardises how frontend applications communicate with AI agents, with support for streaming, frontend tools, shared state, and custom events. + +Any Pydantic AI agent can be exposed as an AG-UI server using the [`Agent.to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] convenience method. + +!!! note + The AG-UI integration was originally built by the team at [Rocket Science](https://www.rocketscience.gg/) and contributed in collaboration with the Pydantic AI and CopilotKit teams. Thanks Rocket Science! 
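As a quick orientation before the installation and quick start sections that follow, here is a minimal editorial sketch (not part of the patch) of mounting the ASGI app returned by `to_ag_ui()` inside a larger application, next to ordinary routes, mirroring what the example package added by this patch does with FastAPI. The model name, instructions, mount path, and health route are assumptions.

```python
from fastapi import FastAPI

from pydantic_ai import Agent

agent = Agent('openai:gpt-4o-mini', instructions='Be helpful!')

app = FastAPI(title='My AG-UI server')
app.mount('/chat', agent.to_ag_ui(), 'Chat Agent')


@app.get('/health')
def health() -> dict[str, str]:
    """An ordinary route living next to the mounted AG-UI endpoint."""
    return {'status': 'ok'}
```

Run it with any ASGI server and point an AG-UI frontend at the `/chat` mount.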
+ +## Installation + +The only dependencies are: + +- [ag-ui-protocol](https://docs.ag-ui.com/introduction): to provide the AG-UI types and encoder +- [starlette](https://www.starlette.io): to expose the AG-UI server as an [ASGI application](https://asgi.readthedocs.io/en/latest/) + +You can install Pydantic AI with the `ag-ui` extra to ensure you have all the +required AG-UI dependencies: + +```bash +pip/uv-add 'pydantic-ai-slim[ag-ui]' +``` + +To run the examples you'll also need: + +- [uvicorn](https://www.uvicorn.org/) or another ASGI compatible server + +```bash +pip/uv-add uvicorn +``` + +## Quick start + +To expose a Pydantic AI agent as an AG-UI server, you can use the [`Agent.to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] method: + +```py {title="agent_to_ag_ui.py" py="3.10" hl_lines="4"} +from pydantic_ai import Agent + +agent = Agent('openai:gpt-4.1', instructions='Be fun!') +app = agent.to_ag_ui() +``` + +Since `app` is an ASGI application, it can be used with any ASGI server: + +```shell +uvicorn agent_to_ag_ui:app --host 0.0.0.0 --port 9000 +``` + +This will expose the agent as an AG-UI server, and your frontend can start sending requests to it. + +The `to_ag_ui()` method accepts the same arguments as the [`Agent.iter()`][pydantic_ai.agent.Agent.iter] method as well as arguments that let you configure the [Starlette](https://www.starlette.io)-based ASGI app. + +## Design + +The Pydantic AI AG-UI integration supports all features of the spec: + +- [Events](https://docs.ag-ui.com/concepts/events) +- [Messages](https://docs.ag-ui.com/concepts/messages) +- [State Management](https://docs.ag-ui.com/concepts/state) +- [Tools](https://docs.ag-ui.com/concepts/tools) + +The app receives messages in the form of a +[`RunAgentInput`](https://docs.ag-ui.com/sdk/js/core/types#runagentinput) +which describes the details of a request being passed to the agent including +messages and state. These are then converted to Pydantic AI types and passed to the +agent which then process the request. + +Events from the agent, including tool calls, are converted to AG-UI events and +streamed back to the caller as Server-Sent Events (SSE). + +A user request may require multiple round trips between client UI and Pydantic AI +server, depending on the tools and events needed. + +## Features + +### State management + +The adapter provides full support for +[AG-UI state management](https://docs.ag-ui.com/concepts/state), which enables +real-time synchronization between agents and frontend applications. + +In the example below we have document state which is shared between the UI and +server using the [`StateDeps`][pydantic_ai.ag_ui.StateDeps] which implements the +[`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol that can be used to automatically +decode state contained in [`RunAgentInput.state`](https://docs.ag-ui.com/sdk/js/core/types#runagentinput) +when processing requests. 
+ +```python {title="ag_ui_state.py" py="3.10"} +from pydantic import BaseModel + +from pydantic_ai import Agent +from pydantic_ai.ag_ui import StateDeps + + +class DocumentState(BaseModel): + """State for the document being written.""" + + document: str = '' + + +agent = Agent( + 'openai:gpt-4.1', + instructions='Be fun!', + deps_type=StateDeps[DocumentState], +) +app = agent.to_ag_ui(deps=StateDeps(DocumentState())) +``` + +Since `app` is an ASGI application, it can be used with any ASGI server: + +```bash +uvicorn ag_ui_state:app --host 0.0.0.0 --port 9000 +``` + +### Tools + +AG-UI frontend tools are seamlessly provided to the Pydantic AI agent, enabling rich +user experiences with frontend user interfaces. + +### Events + +Pydantic AI tools can send +[AG-UI events](https://docs.ag-ui.com/concepts/events) simply by defining a tool +which returns a (subclass of) +[`BaseEvent`](https://docs.ag-ui.com/sdk/python/core/events#baseevent), which allows +for custom events and state updates. + +```python {title="ag_ui_tool_events.py" py="3.10"} +from ag_ui.core import CustomEvent, EventType, StateSnapshotEvent +from pydantic import BaseModel + +from pydantic_ai import Agent, RunContext +from pydantic_ai.ag_ui import StateDeps + + +class DocumentState(BaseModel): + """State for the document being written.""" + + document: str = '' + + +agent = Agent( + 'openai:gpt-4.1', + instructions='Be fun!', + deps_type=StateDeps[DocumentState], +) +app = agent.to_ag_ui(deps=StateDeps(DocumentState())) + + +@agent.tool +def update_state(ctx: RunContext[StateDeps[DocumentState]]) -> StateSnapshotEvent: + return StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + snapshot=ctx.deps.state, + ) + + +@agent.tool_plain +def custom_events() -> list[CustomEvent]: + return [ + CustomEvent( + type=EventType.CUSTOM, + name='count', + value=1, + ), + CustomEvent( + type=EventType.CUSTOM, + name='count', + value=2, + ), + ] +``` + +Since `app` is an ASGI application, it can be used with any ASGI server: + +```bash +uvicorn ag_ui_tool_events:app --host 0.0.0.0 --port 9000 +``` + +## Examples + +For more examples of how to use [`to_ag_ui()`][pydantic_ai.Agent.to_ag_ui] see +[`pydantic_ai_examples.ag_ui`](https://github.com/pydantic/pydantic-ai/tree/main/examples/pydantic_ai_examples/ag_ui), +which includes a server for use with the +[AG-UI Dojo](https://docs.ag-ui.com/tutorials/debugging#the-ag-ui-dojo). diff --git a/docs/api/ag_ui.md b/docs/api/ag_ui.md new file mode 100644 index 0000000000..bb0ffd429e --- /dev/null +++ b/docs/api/ag_ui.md @@ -0,0 +1,3 @@ +# `pydantic_ai.ag_ui` + +::: pydantic_ai.ag_ui diff --git a/docs/examples/ag-ui.md b/docs/examples/ag-ui.md new file mode 100644 index 0000000000..d893899da5 --- /dev/null +++ b/docs/examples/ag-ui.md @@ -0,0 +1,204 @@ +# Agent User Interaction (AG-UI) + +Example of using Pydantic AI agents with the [AG-UI Dojo](https://github.com/ag-ui-protocol/ag-ui/tree/main/typescript-sdk/apps/dojo) example app. + +See the [AG-UI docs](../ag-ui.md) for more information about the AG-UI integration. + +Demonstrates: + +- [AG-UI](../ag-ui.md) +- [Tools](../tools.md) + +## Prerequisites + +- An [OpenAI API key](https://help.openai.com/en/articles/4936850-where-do-i-find-my-openai-api-key) + +## Running the Example + +With [dependencies installed and environment variables set](./index.md#usage) +you will need two command line windows. + +### Pydantic AI AG-UI backend + +Setup your OpenAI API Key + +```bash +export OPENAI_API_KEY= +``` + +Start the Pydantic AI AG-UI example backend. 
+ +```bash +python/uv-run -m pydantic_ai_examples.ag_ui +``` + +### AG-UI Dojo example frontend + +Next run the AG-UI Dojo example frontend. + +1. Clone the [AG-UI repository](https://github.com/ag-ui-protocol/ag-ui) + + ```shell + git clone https://github.com/ag-ui-protocol/ag-ui.git + ``` + +2. Change into to the `ag-ui/typescript-sdk` directory + + ```shell + cd ag-ui/typescript-sdk + ``` + +3. Run the Dojo app following the [official instructions](https://github.com/ag-ui-protocol/ag-ui/tree/main/typescript-sdk/apps/dojo#development-setup) +4. Visit +5. Select View `Pydantic AI` from the sidebar + +## Feature Examples + +### Agentic Chat + +This demonstrates a basic agent interaction including Pydantic AI server side +tools and AG-UI client side tools. + +View the [Agentic Chat example](http://localhost:3000/pydantic-ai/feature/agentic_chat). + +#### Agent Tools + +- `time` - Pydantic AI tool to check the current time for a time zone +- `background` - AG-UI tool to set the background color of the client window + +#### Agent Prompts + +```text +What is the time in New York? +``` + +```text +Change the background to blue +``` + +A complex example which mixes both AG-UI and Pydantic AI tools: + +```text +Perform the following steps, waiting for the response of each step before continuing: +1. Get the time +2. Set the background to red +3. Get the time +4. Report how long the background set took by diffing the two times +``` + +#### Agentic Chat - Code + +```snippet {path="/examples/pydantic_ai_examples/ag_ui/api/agentic_chat.py"}``` + +### Agentic Generative UI + +Demonstrates a long running task where the agent sends updates to the frontend +to let the user know what's happening. + +View the [Agentic Generative UI example](http://localhost:3000/pydantic-ai/feature/agentic_generative_ui). + +#### Plan Prompts + +```text +Create a plan for breakfast and execute it +``` + +#### Agentic Generative UI - Code + +```snippet {path="/examples/pydantic_ai_examples/ag_ui/api/agentic_generative_ui.py"}``` + +### Human in the Loop + +Demonstrates simple human in the loop workflow where the agent comes up with a +plan and the user can approve it using checkboxes. + +#### Task Planning Tools + +- `generate_task_steps` - AG-UI tool to generate and confirm steps + +#### Task Planning Prompt + +```text +Generate a list of steps for cleaning a car for me to review +``` + +#### Human in the Loop - Code + +```snippet {path="/examples/pydantic_ai_examples/ag_ui/api/human_in_the_loop.py"}``` + +### Predictive State Updates + +Demonstrates how to use the predictive state updates feature to update the state +of the UI based on agent responses, including user interaction via user +confirmation. + +View the [Predictive State Updates example](http://localhost:3000/pydantic-ai/feature/predictive_state_updates). + +#### Story Tools + +- `write_document` - AG-UI tool to write the document to a window +- `document_predict_state` - Pydantic AI tool that enables document state + prediction for the `write_document` tool + +This also shows how to use custom instructions based on shared state information. + +#### Story Example + +Starting document text + +```markdown +Bruce was a good dog, +``` + +Agent prompt + +```text +Help me complete my story about bruce the dog, is should be no longer than a sentence. 
+```
+
+#### Predictive State Updates - Code
+
+```snippet {path="/examples/pydantic_ai_examples/ag_ui/api/predictive_state_updates.py"}```
+
+### Shared State
+
+Demonstrates how to use the shared state between the UI and the agent.
+
+State sent to the agent is detected by a function-based instruction. This then
+validates the data using a custom Pydantic model before using it to create the
+instructions for the agent to follow and send results to the client using an AG-UI tool.
+
+View the [Shared State example](http://localhost:3000/pydantic-ai/feature/shared_state).
+
+#### Recipe Tools
+
+- `display_recipe` - AG-UI tool to display the recipe in a graphical format
+
+#### Recipe Example
+
+1. Customise the basic settings of your recipe
+2. Click `Improve with AI`
+
+#### Shared State - Code
+
+```snippet {path="/examples/pydantic_ai_examples/ag_ui/api/shared_state.py"}```
+
+### Tool Based Generative UI
+
+Demonstrates customised rendering for tool output with user confirmation.
+
+View the [Tool Based Generative UI example](http://localhost:3000/pydantic-ai/feature/tool_based_generative_ui).
+
+#### Haiku Tools
+
+- `generate_haiku` - AG-UI tool to display a haiku in English and Japanese
+
+#### Haiku Prompt
+
+```text
+Generate a haiku about formula 1
+```
+
+#### Tool Based Generative UI - Code
+
+```snippet {path="/examples/pydantic_ai_examples/ag_ui/api/tool_based_generative_ui.py"}```
diff --git a/docs/install.md b/docs/install.md
index d1a6909c84..610b247223 100644
--- a/docs/install.md
+++ b/docs/install.md
@@ -56,6 +56,7 @@ pip/uv-add "pydantic-ai-slim[openai]"
* `cohere` - installs `cohere` [PyPI ↗](https://pypi.org/project/cohere){:target="_blank"}
* `duckduckgo` - installs `ddgs` [PyPI ↗](https://pypi.org/project/ddgs){:target="_blank"}
* `tavily` - installs `tavily-python` [PyPI ↗](https://pypi.org/project/tavily-python){:target="_blank"}
+* `ag-ui` - installs `ag-ui-protocol` [PyPI ↗](https://pypi.org/project/ag-ui-protocol){:target="_blank"} and `starlette` [PyPI ↗](https://pypi.org/project/starlette){:target="_blank"}
See the [models](models/index.md) documentation for information on which optional dependencies are required for each model.
diff --git a/docs/tools.md b/docs/tools.md
index 6744b3cd13..4b40e78818 100644
--- a/docs/tools.md
+++ b/docs/tools.md
@@ -17,7 +17,7 @@ For more advanced use cases, the [toolsets](toolsets.md) feature lets you manage
!!! info "Function tools vs. RAG"
    Function tools are basically the "R" of RAG (Retrieval-Augmented Generation) — they augment what the model can do by letting it request extra information.
-    The main semantic difference between PydanticAI Tools and RAG is RAG is synonymous with vector search, while PydanticAI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58))
+    The main semantic difference between Pydantic AI Tools and RAG is RAG is synonymous with vector search, while Pydantic AI tools are more general-purpose. (Note: we may add support for vector search functionality in the future, particularly an API for generating embeddings. See [#58](https://github.com/pydantic/pydantic-ai/issues/58))
!!! info "Function Tools vs. Structured Outputs"
    As the name suggests, function tools use the model's "tools" or "functions" API to let the model know what is available to call.
Tools or functions are also used to define the schema(s) for [structured output](output.md) when using the default [tool output mode](output.md#tool-output), thus a model might have access to many tools, some of which call function tools while others end the run and produce a final output.
diff --git a/examples/pydantic_ai_examples/ag_ui/__init__.py b/examples/pydantic_ai_examples/ag_ui/__init__.py
new file mode 100644
index 0000000000..2eb4f87e4b
--- /dev/null
+++ b/examples/pydantic_ai_examples/ag_ui/__init__.py
@@ -0,0 +1,41 @@
+"""Example usage of the AG-UI adapter for Pydantic AI.
+
+This provides a FastAPI application that demonstrates how to use the
+Pydantic AI agent with the AG-UI protocol. It includes examples for
+each of the AG-UI dojo features:
+- Agentic Chat
+- Human in the Loop
+- Agentic Generative UI
+- Tool Based Generative UI
+- Shared State
+- Predictive State Updates
+"""
+
+from __future__ import annotations
+
+from fastapi import FastAPI
+
+from .api import (
+    agentic_chat_app,
+    agentic_generative_ui_app,
+    human_in_the_loop_app,
+    predictive_state_updates_app,
+    shared_state_app,
+    tool_based_generative_ui_app,
+)
+
+app = FastAPI(title='Pydantic AI AG-UI server')
+app.mount('/agentic_chat', agentic_chat_app, 'Agentic Chat')
+app.mount('/agentic_generative_ui', agentic_generative_ui_app, 'Agentic Generative UI')
+app.mount('/human_in_the_loop', human_in_the_loop_app, 'Human in the Loop')
+app.mount(
+    '/predictive_state_updates',
+    predictive_state_updates_app,
+    'Predictive State Updates',
+)
+app.mount('/shared_state', shared_state_app, 'Shared State')
+app.mount(
+    '/tool_based_generative_ui',
+    tool_based_generative_ui_app,
+    'Tool Based Generative UI',
+)
diff --git a/examples/pydantic_ai_examples/ag_ui/__main__.py b/examples/pydantic_ai_examples/ag_ui/__main__.py
new file mode 100644
index 0000000000..81a598a9ef
--- /dev/null
+++ b/examples/pydantic_ai_examples/ag_ui/__main__.py
@@ -0,0 +1,9 @@
+"""Very simple CLI to run the AG-UI example.
+
+See https://ai.pydantic.dev/examples/ag-ui/ for more information.
+""" + +if __name__ == '__main__': + import uvicorn + + uvicorn.run('pydantic_ai_examples.ag_ui:app', port=9000) diff --git a/examples/pydantic_ai_examples/ag_ui/api/__init__.py b/examples/pydantic_ai_examples/ag_ui/api/__init__.py new file mode 100644 index 0000000000..2f89543177 --- /dev/null +++ b/examples/pydantic_ai_examples/ag_ui/api/__init__.py @@ -0,0 +1,19 @@ +"""Example API for a AG-UI compatible Pydantic AI Agent UI.""" + +from __future__ import annotations + +from .agentic_chat import app as agentic_chat_app +from .agentic_generative_ui import app as agentic_generative_ui_app +from .human_in_the_loop import app as human_in_the_loop_app +from .predictive_state_updates import app as predictive_state_updates_app +from .shared_state import app as shared_state_app +from .tool_based_generative_ui import app as tool_based_generative_ui_app + +__all__ = [ + 'agentic_chat_app', + 'agentic_generative_ui_app', + 'human_in_the_loop_app', + 'predictive_state_updates_app', + 'shared_state_app', + 'tool_based_generative_ui_app', +] diff --git a/examples/pydantic_ai_examples/ag_ui/api/agentic_chat.py b/examples/pydantic_ai_examples/ag_ui/api/agentic_chat.py new file mode 100644 index 0000000000..c91a84ad2a --- /dev/null +++ b/examples/pydantic_ai_examples/ag_ui/api/agentic_chat.py @@ -0,0 +1,25 @@ +"""Agentic Chat feature.""" + +from __future__ import annotations + +from datetime import datetime +from zoneinfo import ZoneInfo + +from pydantic_ai import Agent + +agent = Agent('openai:gpt-4o-mini') +app = agent.to_ag_ui() + + +@agent.tool_plain +async def current_time(timezone: str = 'UTC') -> str: + """Get the current time in ISO format. + + Args: + timezone: The timezone to use. + + Returns: + The current time in ISO format string. + """ + tz: ZoneInfo = ZoneInfo(timezone) + return datetime.now(tz=tz).isoformat() diff --git a/examples/pydantic_ai_examples/ag_ui/api/agentic_generative_ui.py b/examples/pydantic_ai_examples/ag_ui/api/agentic_generative_ui.py new file mode 100644 index 0000000000..44496a4159 --- /dev/null +++ b/examples/pydantic_ai_examples/ag_ui/api/agentic_generative_ui.py @@ -0,0 +1,119 @@ +"""Agentic Generative UI feature.""" + +from __future__ import annotations + +from textwrap import dedent +from typing import Any, Literal + +from pydantic import BaseModel, Field + +from ag_ui.core import EventType, StateDeltaEvent, StateSnapshotEvent +from pydantic_ai import Agent + +StepStatus = Literal['pending', 'completed'] + + +class Step(BaseModel): + """Represents a step in a plan.""" + + description: str = Field(description='The description of the step') + status: StepStatus = Field( + default='pending', + description='The status of the step (e.g., pending, completed)', + ) + + +class Plan(BaseModel): + """Represents a plan with multiple steps.""" + + steps: list[Step] = Field(default_factory=list, description='The steps in the plan') + + +class JSONPatchOp(BaseModel): + """A class representing a JSON Patch operation (RFC 6902).""" + + op: Literal['add', 'remove', 'replace', 'move', 'copy', 'test'] = Field( + description='The operation to perform: add, remove, replace, move, copy, or test', + ) + path: str = Field(description='JSON Pointer (RFC 6901) to the target location') + value: Any = Field( + default=None, + description='The value to apply (for add, replace operations)', + ) + from_: str | None = Field( + default=None, + alias='from', + description='Source path (for move, copy operations)', + ) + + +agent = Agent( + 'openai:gpt-4o-mini', + instructions=dedent( + """ + When 
planning use tools only, without any other messages. + IMPORTANT: + - Use the `create_plan` tool to set the initial state of the steps + - Use the `update_plan_step` tool to update the status of each step + - Do NOT repeat the plan or summarise it in a message + - Do NOT confirm the creation or updates in a message + - Do NOT ask the user for additional information or next steps + + Only one plan can be active at a time, so do not call the `create_plan` tool + again until all the steps in current plan are completed. + """ + ), +) + + +@agent.tool_plain +def create_plan(steps: list[str]) -> StateSnapshotEvent: + """Create a plan with multiple steps. + + Args: + steps: List of step descriptions to create the plan. + + Returns: + StateSnapshotEvent containing the initial state of the steps. + """ + plan: Plan = Plan( + steps=[Step(description=step) for step in steps], + ) + return StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + snapshot=plan.model_dump(), + ) + + +@agent.tool_plain +def update_plan_step( + index: int, description: str | None = None, status: StepStatus | None = None +) -> StateDeltaEvent: + """Update the plan with new steps or changes. + + Args: + index: The index of the step to update. + description: The new description for the step. + status: The new status for the step. + + Returns: + StateDeltaEvent containing the changes made to the plan. + """ + changes: list[JSONPatchOp] = [] + if description is not None: + changes.append( + JSONPatchOp( + op='replace', path=f'/steps/{index}/description', value=description + ) + ) + if status is not None: + changes.append( + JSONPatchOp(op='replace', path=f'/steps/{index}/status', value=status) + ) + return StateDeltaEvent( + type=EventType.STATE_DELTA, + delta=changes, + ) + + +app = agent.to_ag_ui() diff --git a/examples/pydantic_ai_examples/ag_ui/api/human_in_the_loop.py b/examples/pydantic_ai_examples/ag_ui/api/human_in_the_loop.py new file mode 100644 index 0000000000..3f48462976 --- /dev/null +++ b/examples/pydantic_ai_examples/ag_ui/api/human_in_the_loop.py @@ -0,0 +1,26 @@ +"""Human in the Loop Feature. + +No special handling is required for this feature. +""" + +from __future__ import annotations + +from textwrap import dedent + +from pydantic_ai import Agent + +agent = Agent( + 'openai:gpt-4o-mini', + instructions=dedent( + """ + When planning tasks use tools only, without any other messages. 
+ IMPORTANT: + - Use the `generate_task_steps` tool to display the suggested steps to the user + - Never repeat the plan, or send a message detailing steps + - If accepted, confirm the creation of the plan and the number of selected (enabled) steps only + - If not accepted, ask the user for more information, DO NOT use the `generate_task_steps` tool again + """ + ), +) + +app = agent.to_ag_ui() diff --git a/examples/pydantic_ai_examples/ag_ui/api/predictive_state_updates.py b/examples/pydantic_ai_examples/ag_ui/api/predictive_state_updates.py new file mode 100644 index 0000000000..6769858d8d --- /dev/null +++ b/examples/pydantic_ai_examples/ag_ui/api/predictive_state_updates.py @@ -0,0 +1,77 @@ +"""Predictive State feature.""" + +from __future__ import annotations + +from textwrap import dedent + +from pydantic import BaseModel + +from ag_ui.core import CustomEvent, EventType +from pydantic_ai import Agent, RunContext +from pydantic_ai.ag_ui import StateDeps + + +class DocumentState(BaseModel): + """State for the document being written.""" + + document: str = '' + + +agent = Agent('openai:gpt-4o-mini', deps_type=StateDeps[DocumentState]) + + +# Tools which return AG-UI events will be sent to the client as part of the +# event stream, single events and iterables of events are supported. +@agent.tool_plain +def document_predict_state() -> list[CustomEvent]: + """Enable document state prediction. + + Returns: + CustomEvent containing the event to enable state prediction. + """ + return [ + CustomEvent( + type=EventType.CUSTOM, + name='PredictState', + value=[ + { + 'state_key': 'document', + 'tool': 'write_document', + 'tool_argument': 'document', + }, + ], + ), + ] + + +@agent.instructions() +def story_instructions(ctx: RunContext[StateDeps[DocumentState]]) -> str: + """Provide instructions for writing document if present. + + Args: + ctx: The run context containing document state information. + + Returns: + Instructions string for the document writing agent. + """ + return dedent( + f"""You are a helpful assistant for writing documents. + + Before you start writing, you MUST call the `document_predict_state` + tool to enable state prediction. + + To present the document to the user for review, you MUST use the + `write_document` tool. + + When you have written the document, DO NOT repeat it as a message. + If accepted briefly summarize the changes you made, 2 sentences + max, otherwise ask the user to clarify what they want to change. 
+ + This is the current document: + + {ctx.deps.state.document} + """ + ) + + +app = agent.to_ag_ui(deps=StateDeps(DocumentState())) diff --git a/examples/pydantic_ai_examples/ag_ui/api/shared_state.py b/examples/pydantic_ai_examples/ag_ui/api/shared_state.py new file mode 100644 index 0000000000..d3985b3622 --- /dev/null +++ b/examples/pydantic_ai_examples/ag_ui/api/shared_state.py @@ -0,0 +1,137 @@ +"""Shared State feature.""" + +from __future__ import annotations + +from enum import StrEnum +from textwrap import dedent + +from pydantic import BaseModel, Field + +from ag_ui.core import EventType, StateSnapshotEvent +from pydantic_ai import Agent, RunContext +from pydantic_ai.ag_ui import StateDeps + + +class SkillLevel(StrEnum): + """The level of skill required for the recipe.""" + + BEGINNER = 'Beginner' + INTERMEDIATE = 'Intermediate' + ADVANCED = 'Advanced' + + +class SpecialPreferences(StrEnum): + """Special preferences for the recipe.""" + + HIGH_PROTEIN = 'High Protein' + LOW_CARB = 'Low Carb' + SPICY = 'Spicy' + BUDGET_FRIENDLY = 'Budget-Friendly' + ONE_POT_MEAL = 'One-Pot Meal' + VEGETARIAN = 'Vegetarian' + VEGAN = 'Vegan' + + +class CookingTime(StrEnum): + """The cooking time of the recipe.""" + + FIVE_MIN = '5 min' + FIFTEEN_MIN = '15 min' + THIRTY_MIN = '30 min' + FORTY_FIVE_MIN = '45 min' + SIXTY_PLUS_MIN = '60+ min' + + +class Ingredient(BaseModel): + """A class representing an ingredient in a recipe.""" + + icon: str = Field( + default='ingredient', + description="The icon emoji (not emoji code like '\x1f35e', but the actual emoji like 🥕) of the ingredient", + ) + name: str + amount: str + + +class Recipe(BaseModel): + """A class representing a recipe.""" + + skill_level: SkillLevel = Field( + default=SkillLevel.BEGINNER, + description='The skill level required for the recipe', + ) + special_preferences: list[SpecialPreferences] = Field( + default_factory=list, + description='Any special preferences for the recipe', + ) + cooking_time: CookingTime = Field( + default=CookingTime.FIVE_MIN, description='The cooking time of the recipe' + ) + ingredients: list[Ingredient] = Field( + default_factory=list, + description='Ingredients for the recipe', + ) + instructions: list[str] = Field( + default_factory=list, description='Instructions for the recipe' + ) + + +class RecipeSnapshot(BaseModel): + """A class representing the state of the recipe.""" + + recipe: Recipe = Field( + default_factory=Recipe, description='The current state of the recipe' + ) + + +agent = Agent('openai:gpt-4o-mini', deps_type=StateDeps[RecipeSnapshot]) + + +@agent.tool_plain +def display_recipe(recipe: Recipe) -> StateSnapshotEvent: + """Display the recipe to the user. + + Args: + recipe: The recipe to display. + + Returns: + StateSnapshotEvent containing the recipe snapshot. + """ + return StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + snapshot={'recipe': recipe}, + ) + + +@agent.instructions +def recipe_instructions(ctx: RunContext[StateDeps[RecipeSnapshot]]) -> str: + """Instructions for the recipe generation agent. + + Args: + ctx: The run context containing recipe state information. + + Returns: + Instructions string for the recipe generation agent. + """ + return dedent( + f""" + You are a helpful assistant for creating recipes. 
+ + IMPORTANT: + - Create a complete recipe using the existing ingredients + - Append new ingredients to the existing ones + - Use the `display_recipe` tool to present the recipe to the user + - Do NOT repeat the recipe in the message, use the tool instead + + Once you have created the updated recipe and displayed it to the user, + summarise the changes in one sentence, don't describe the recipe in + detail or send it as a message to the user. + + The current state of the recipe is: + + {ctx.deps.state.recipe.model_dump_json(indent=2)} + """, + ) + + +app = agent.to_ag_ui(deps=StateDeps(RecipeSnapshot())) diff --git a/examples/pydantic_ai_examples/ag_ui/api/tool_based_generative_ui.py b/examples/pydantic_ai_examples/ag_ui/api/tool_based_generative_ui.py new file mode 100644 index 0000000000..88dfee0437 --- /dev/null +++ b/examples/pydantic_ai_examples/ag_ui/api/tool_based_generative_ui.py @@ -0,0 +1,11 @@ +"""Tool Based Generative UI feature. + +No special handling is required for this feature. +""" + +from __future__ import annotations + +from pydantic_ai import Agent + +agent = Agent('openai:gpt-4o-mini') +app = agent.to_ag_ui() diff --git a/examples/pydantic_ai_examples/py.typed b/examples/pydantic_ai_examples/py.typed new file mode 100644 index 0000000000..e69de29bb2 diff --git a/examples/pyproject.toml b/examples/pyproject.toml index 29720813b9..e1e5252e3b 100644 --- a/examples/pyproject.toml +++ b/examples/pyproject.toml @@ -48,7 +48,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,groq,anthropic]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,groq,anthropic,ag-ui]=={{ version }}", "pydantic-evals=={{ version }}", "asyncpg>=0.30.0", "fastapi>=0.115.4", diff --git a/mkdocs.yml b/mkdocs.yml index 796d6601f5..1860a00c0e 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -49,9 +49,11 @@ nav: - mcp/server.md - mcp/run-python.md - A2A: a2a.md + - AG-UI: ag-ui.md - cli.md - Examples: - examples/index.md + - examples/ag-ui.md - examples/pydantic-model.md - examples/weather-agent.md - examples/bank-support.md @@ -63,7 +65,9 @@ nav: - examples/chat-app.md - examples/question-graph.md - examples/slack-lead-qualifier.md + - API Reference: + - api/ag_ui.md - api/agent.md - api/tools.md - api/toolsets.md diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py new file mode 100644 index 0000000000..a43d8bda49 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -0,0 +1,675 @@ +"""Provides an AG-UI protocol adapter for the Pydantic AI agent. + +This package provides seamless integration between pydantic-ai agents and ag-ui +for building interactive AI applications with streaming event-based communication. 
+""" + +from __future__ import annotations + +import json +import uuid +from collections.abc import Iterable, Mapping, Sequence +from dataclasses import dataclass, field +from http import HTTPStatus +from typing import ( + Any, + Callable, + Final, + Generic, + Protocol, + TypeVar, + runtime_checkable, +) + +try: + from ag_ui.core import ( + AssistantMessage, + BaseEvent, + DeveloperMessage, + EventType, + Message, + RunAgentInput, + RunErrorEvent, + RunFinishedEvent, + RunStartedEvent, + State, + SystemMessage, + TextMessageContentEvent, + TextMessageEndEvent, + TextMessageStartEvent, + ThinkingTextMessageContentEvent, + ThinkingTextMessageEndEvent, + ThinkingTextMessageStartEvent, + ToolCallArgsEvent, + ToolCallEndEvent, + ToolCallResultEvent, + ToolCallStartEvent, + ToolMessage, + UserMessage, + ) + from ag_ui.encoder import EventEncoder +except ImportError as e: # pragma: no cover + raise ImportError( + 'Please install the `ag-ui-protocol` package to use `Agent.to_ag_ui()` method, ' + 'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`' + ) from e + +try: + from starlette.applications import Starlette + from starlette.middleware import Middleware + from starlette.requests import Request + from starlette.responses import Response, StreamingResponse + from starlette.routing import BaseRoute + from starlette.types import ExceptionHandler, Lifespan +except ImportError as e: # pragma: no cover + raise ImportError( + 'Please install the `starlette` package to use `Agent.to_ag_ui()` method, ' + 'you can use the `ag-ui` optional group — `pip install "pydantic-ai-slim[ag-ui]"`' + ) from e + +from collections.abc import AsyncGenerator + +from pydantic import BaseModel, ValidationError + +from ._agent_graph import CallToolsNode, ModelRequestNode +from .agent import Agent, AgentRun, RunOutputDataT +from .messages import ( + AgentStreamEvent, + FunctionToolResultEvent, + ModelMessage, + ModelRequest, + ModelResponse, + PartDeltaEvent, + PartStartEvent, + SystemPromptPart, + TextPart, + TextPartDelta, + ThinkingPart, + ThinkingPartDelta, + ToolCallPart, + ToolCallPartDelta, + ToolReturnPart, + UserPromptPart, +) +from .models import KnownModelName, Model +from .output import DeferredToolCalls, OutputDataT, OutputSpec +from .settings import ModelSettings +from .tools import AgentDepsT, ToolDefinition +from .toolsets import AbstractToolset +from .toolsets.deferred import DeferredToolset +from .usage import Usage, UsageLimits + +__all__ = [ + 'SSE_CONTENT_TYPE', + 'StateDeps', + 'StateHandler', + 'AGUIApp', +] + +SSE_CONTENT_TYPE: Final[str] = 'text/event-stream' +"""Content type header value for Server-Sent Events (SSE).""" + + +class AGUIApp(Generic[AgentDepsT, OutputDataT], Starlette): + """ASGI application for running Pydantic AI agents with AG-UI protocol support.""" + + def __init__( + self, + agent: Agent[AgentDepsT, OutputDataT], + *, + # Agent.iter parameters. + output_type: OutputSpec[OutputDataT] | None = None, + model: Model | KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + # Starlette parameters. 
+ debug: bool = False, + routes: Sequence[BaseRoute] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: Mapping[Any, ExceptionHandler] | None = None, + on_startup: Sequence[Callable[[], Any]] | None = None, + on_shutdown: Sequence[Callable[[], Any]] | None = None, + lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, + ) -> None: + """Initialise the AG-UI application. + + Args: + agent: The Pydantic AI `Agent` to adapt. + + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has + no output validators since output validators would expect an argument that matches the agent's + output type. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. + + debug: Boolean indicating if debug tracebacks should be returned on errors. + routes: A list of routes to serve incoming HTTP and WebSocket requests. + middleware: A list of middleware to run for every request. A starlette application will always + automatically include two middleware classes. `ServerErrorMiddleware` is added as the very + outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. + `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled + exception cases occurring in the routing or endpoints. + exception_handlers: A mapping of either integer status codes, or exception class types onto + callables which handle the exceptions. Exception handler callables should be of the form + `handler(request, exc) -> response` and may be either standard functions, or async functions. + on_startup: A list of callables to run on application startup. Startup handler callables do not + take any arguments, and may be either standard functions, or async functions. + on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do + not take any arguments, and may be either standard functions, or async functions. + lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. + This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or + the other, not both. 
+ """ + super().__init__( + debug=debug, + routes=routes, + middleware=middleware, + exception_handlers=exception_handlers, + on_startup=on_startup, + on_shutdown=on_shutdown, + lifespan=lifespan, + ) + adapter = _Adapter(agent=agent) + + async def endpoint(request: Request) -> Response | StreamingResponse: + """Endpoint to run the agent with the provided input data.""" + accept = request.headers.get('accept', SSE_CONTENT_TYPE) + try: + input_data = RunAgentInput.model_validate(await request.json()) + except ValidationError as e: # pragma: no cover + return Response( + content=json.dumps(e.json()), + media_type='application/json', + status_code=HTTPStatus.UNPROCESSABLE_ENTITY, + ) + + return StreamingResponse( + adapter.run( + input_data, + accept, + output_type=output_type, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ), + media_type=SSE_CONTENT_TYPE, + ) + + self.router.add_route('/', endpoint, methods=['POST'], name='run_agent') + + +@dataclass(repr=False) +class _Adapter(Generic[AgentDepsT, OutputDataT]): + """An agent adapter providing AG-UI protocol support for Pydantic AI agents. + + This class manages the agent runs, tool calls, state storage and providing + an adapter for running agents with Server-Sent Event (SSE) streaming + responses using the AG-UI protocol. + + Args: + agent: The Pydantic AI `Agent` to adapt. + """ + + agent: Agent[AgentDepsT, OutputDataT] = field(repr=False) + + async def run( + self, + run_input: RunAgentInput, + accept: str = SSE_CONTENT_TYPE, + *, + output_type: OutputSpec[RunOutputDataT] | None = None, + model: Model | KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + ) -> AsyncGenerator[str, None]: + """Run the agent with streaming response using AG-UI protocol events. + + The first two arguments are specific to `Adapter` the rest map directly to the `Agent.iter` method. + + Args: + run_input: The AG-UI run input containing thread_id, run_id, messages, etc. + accept: The accept header value for the run. + + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has no + output validators since output validators would expect an argument that matches the agent's output type. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. + + Yields: + Streaming SSE-formatted event chunks. + """ + encoder = EventEncoder(accept=accept) + if run_input.tools: + # AG-UI tools can't be prefixed as that would result in a mismatch between the tool names in the + # Pydantic AI events and actual AG-UI tool names, preventing the tool from being called. If any + # conflicts arise, the AG-UI tool should be renamed or a `PrefixedToolset` used for local toolsets. 
+ toolset = DeferredToolset[AgentDepsT]( + [ + ToolDefinition( + name=tool.name, + description=tool.description, + parameters_json_schema=tool.parameters, + ) + for tool in run_input.tools + ] + ) + toolsets = [*toolsets, toolset] if toolsets else [toolset] + + try: + yield encoder.encode( + RunStartedEvent( + thread_id=run_input.thread_id, + run_id=run_input.run_id, + ), + ) + + if not run_input.messages: + raise _NoMessagesError + + if isinstance(deps, StateHandler): + deps.state = run_input.state + + history = _History.from_ag_ui(run_input.messages) + + async with self.agent.iter( + user_prompt=None, + output_type=[output_type or self.agent.output_type, DeferredToolCalls], + message_history=history.messages, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + ) as run: + async for event in self._agent_stream(run, history): + yield encoder.encode(event) + except _RunError as e: + yield encoder.encode( + RunErrorEvent(message=e.message, code=e.code), + ) + except Exception as e: # pragma: no cover + yield encoder.encode( + RunErrorEvent(message=str(e)), + ) + raise e + else: + yield encoder.encode( + RunFinishedEvent( + thread_id=run_input.thread_id, + run_id=run_input.run_id, + ), + ) + + async def _agent_stream( + self, + run: AgentRun[AgentDepsT, Any], + history: _History, + ) -> AsyncGenerator[BaseEvent, None]: + """Run the agent streaming responses using AG-UI protocol events. + + Args: + run: The agent run to process. + history: The history of messages and tool calls to use for the run. + + Yields: + AG-UI Server-Sent Events (SSE). + """ + async for node in run: + if isinstance(node, ModelRequestNode): + stream_ctx = _RequestStreamContext() + async with node.stream(run.ctx) as request_stream: + async for agent_event in request_stream: + async for msg in self._handle_model_request_event(stream_ctx, agent_event): + yield msg + + if stream_ctx.part_end: # pragma: no branch + yield stream_ctx.part_end + stream_ctx.part_end = None + elif isinstance(node, CallToolsNode): + async with node.stream(run.ctx) as handle_stream: + async for event in handle_stream: + if isinstance(event, FunctionToolResultEvent) and isinstance(event.result, ToolReturnPart): + async for msg in self._handle_tool_result_event(event.result, history.prompt_message_id): + yield msg + + async def _handle_model_request_event( + self, + stream_ctx: _RequestStreamContext, + agent_event: AgentStreamEvent, + ) -> AsyncGenerator[BaseEvent, None]: + """Handle an agent event and yield AG-UI protocol events. + + Args: + stream_ctx: The request stream context to manage state. + agent_event: The agent event to process. + + Yields: + AG-UI Server-Sent Events (SSE) based on the agent event. + """ + if isinstance(agent_event, PartStartEvent): + if stream_ctx.part_end: + # End the previous part. 
+                yield stream_ctx.part_end
+                stream_ctx.part_end = None
+
+            part = agent_event.part
+            if isinstance(part, TextPart):
+                message_id = stream_ctx.new_message_id()
+                yield TextMessageStartEvent(
+                    message_id=message_id,
+                )
+                stream_ctx.part_end = TextMessageEndEvent(
+                    message_id=message_id,
+                )
+                if part.content:  # pragma: no branch
+                    yield TextMessageContentEvent(
+                        message_id=message_id,
+                        delta=part.content,
+                    )
+            elif isinstance(part, ToolCallPart):  # pragma: no branch
+                yield ToolCallStartEvent(
+                    tool_call_id=part.tool_call_id,
+                    tool_call_name=part.tool_name,
+                )
+                stream_ctx.part_end = ToolCallEndEvent(
+                    tool_call_id=part.tool_call_id,
+                )
+
+            elif isinstance(part, ThinkingPart):  # pragma: no branch
+                yield ThinkingTextMessageStartEvent(
+                    type=EventType.THINKING_TEXT_MESSAGE_START,
+                )
+                # Always send the content even if it's empty, as it may be
+                # used to indicate the start of thinking.
+                yield ThinkingTextMessageContentEvent(
+                    type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
+                    delta=part.content or '',
+                )
+                stream_ctx.part_end = ThinkingTextMessageEndEvent(
+                    type=EventType.THINKING_TEXT_MESSAGE_END,
+                )
+
+        elif isinstance(agent_event, PartDeltaEvent):
+            delta = agent_event.delta
+            if isinstance(delta, TextPartDelta):
+                yield TextMessageContentEvent(
+                    message_id=stream_ctx.message_id,
+                    delta=delta.content_delta,
+                )
+            elif isinstance(delta, ToolCallPartDelta):  # pragma: no branch
+                assert delta.tool_call_id, '`ToolCallPartDelta.tool_call_id` must be set'
+                yield ToolCallArgsEvent(
+                    tool_call_id=delta.tool_call_id,
+                    delta=delta.args_delta if isinstance(delta.args_delta, str) else json.dumps(delta.args_delta),
+                )
+            elif isinstance(delta, ThinkingPartDelta):  # pragma: no branch
+                if delta.content_delta:  # pragma: no branch
+                    yield ThinkingTextMessageContentEvent(
+                        type=EventType.THINKING_TEXT_MESSAGE_CONTENT,
+                        delta=delta.content_delta,
+                    )
+
+    async def _handle_tool_result_event(
+        self,
+        result: ToolReturnPart,
+        prompt_message_id: str,
+    ) -> AsyncGenerator[BaseEvent, None]:
+        """Convert a tool call result to AG-UI events.
+
+        Args:
+            result: The tool call result to process.
+            prompt_message_id: The message ID of the prompt that initiated the tool call.
+
+        Yields:
+            AG-UI Server-Sent Events (SSE).
+        """
+        yield ToolCallResultEvent(
+            message_id=prompt_message_id,
+            type=EventType.TOOL_CALL_RESULT,
+            role='tool',
+            tool_call_id=result.tool_call_id,
+            content=result.model_response_str(),
+        )
+
+        # Now check for AG-UI events returned by the tool calls.
+        content = result.content
+        if isinstance(content, BaseEvent):
+            yield content
+        elif isinstance(content, (str, bytes)):  # pragma: no branch
+            # Avoid iterable check for strings and bytes.
+            pass
+        elif isinstance(content, Iterable):  # pragma: no branch
+            for item in content:  # type: ignore[reportUnknownMemberType]
+                if isinstance(item, BaseEvent):  # pragma: no branch
+                    yield item
+
+
+@dataclass
+class _History:
+    """A simple history representation for AG-UI protocol."""
+
+    prompt_message_id: str  # The ID of the last user message.
+    messages: list[ModelMessage]
+
+    @classmethod
+    def from_ag_ui(cls, messages: list[Message]) -> _History:
+        """Convert an AG-UI history to a Pydantic AI one.
+
+        Args:
+            messages: List of AG-UI messages to convert.
+
+        Returns:
+            A `_History` holding the prompt message ID and the list of Pydantic AI model messages.
+        """
+        prompt_message_id = ''
+        result: list[ModelMessage] = []
+        tool_calls: dict[str, str] = {}  # Tool call ID to tool name mapping.
+ for msg in messages: + if isinstance(msg, UserMessage): + prompt_message_id = msg.id + result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)])) + elif isinstance(msg, AssistantMessage): + if msg.tool_calls: + for tool_call in msg.tool_calls: + tool_calls[tool_call.id] = tool_call.function.name + + result.append( + ModelResponse( + parts=[ + ToolCallPart( + tool_name=tool_call.function.name, + tool_call_id=tool_call.id, + args=tool_call.function.arguments, + ) + for tool_call in msg.tool_calls + ] + ) + ) + + if msg.content: + result.append(ModelResponse(parts=[TextPart(content=msg.content)])) + elif isinstance(msg, SystemMessage): + result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)])) + elif isinstance(msg, ToolMessage): + tool_name = tool_calls.get(msg.tool_call_id) + if tool_name is None: # pragma: no cover + raise _ToolCallNotFoundError(tool_call_id=msg.tool_call_id) + + result.append( + ModelRequest( + parts=[ + ToolReturnPart( + tool_name=tool_name, + content=msg.content, + tool_call_id=msg.tool_call_id, + ) + ] + ) + ) + elif isinstance(msg, DeveloperMessage): # pragma: no branch + result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)])) + + return cls( + prompt_message_id=prompt_message_id, + messages=result, + ) + + +@runtime_checkable +class StateHandler(Protocol): + """Protocol for state handlers in agent runs.""" + + @property + def state(self) -> State: + """Get the current state of the agent run.""" + ... + + @state.setter + def state(self, state: State) -> None: + """Set the state of the agent run. + + This method is called to update the state of the agent run with the + provided state. + + Args: + state: The run state. + + Raises: + InvalidStateError: If `state` does not match the expected model. + """ + ... + + +StateT = TypeVar('StateT', bound=BaseModel) +"""Type variable for the state type, which must be a subclass of `BaseModel`.""" + + +class StateDeps(Generic[StateT]): + """Provides AG-UI state management. + + This class is used to manage the state of an agent run. It allows setting + the state of the agent run with a specific type of state model, which must + be a subclass of `BaseModel`. + + The state is set using the `state` setter by the `Adapter` when the run starts. + + Implements the `StateHandler` protocol. + """ + + def __init__(self, default: StateT) -> None: + """Initialize the state with the provided state type.""" + self._state = default + + @property + def state(self) -> StateT: + """Get the current state of the agent run. + + Returns: + The current run state. + """ + return self._state + + @state.setter + def state(self, state: State) -> None: + """Set the state of the agent run. + + This method is called to update the state of the agent run with the + provided state. + + Implements the `StateHandler` protocol. + + Args: + state: The run state, which must be `None` or model validate for the state type. + + Raises: + InvalidStateError: If `state` does not validate. + """ + if state is None: + # If state is None, we keep the current state, which will be the default state. + return + + try: + self._state = type(self._state).model_validate(state) + except ValidationError as e: # pragma: no cover + raise _InvalidStateError from e + + +@dataclass(repr=False) +class _RequestStreamContext: + """Data class to hold request stream context.""" + + message_id: str = '' + part_end: BaseEvent | None = None + + def new_message_id(self) -> str: + """Generate a new message ID for the request stream. 
+ + Assigns a new UUID to the `message_id` and returns it. + + Returns: + A new message ID. + """ + self.message_id = str(uuid.uuid4()) + return self.message_id + + +@dataclass +class _RunError(Exception): + """Exception raised for errors during agent runs.""" + + message: str + code: str + + def __str__(self) -> str: # pragma: no cover + return self.message + + +@dataclass +class _NoMessagesError(_RunError): + """Exception raised when no messages are found in the input.""" + + message: str = 'no messages found in the input' + code: str = 'no_messages' + + +@dataclass +class _InvalidStateError(_RunError, ValidationError): + """Exception raised when an invalid state is provided.""" + + message: str = 'invalid state provided' + code: str = 'invalid_state' + + +class _ToolCallNotFoundError(_RunError, ValueError): + """Exception raised when an tool result is present without a matching call.""" + + def __init__(self, tool_call_id: str) -> None: + """Initialize the exception with the tool call ID.""" + super().__init__( # pragma: no cover + message=f'Tool call with ID {tool_call_id} not found in the history.', + code='tool_call_not_found', + ) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index d0275bcbbf..b2e1667ddc 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -5,7 +5,7 @@ import json import warnings from asyncio import Lock -from collections.abc import AsyncIterator, Awaitable, Iterator, Sequence +from collections.abc import AsyncIterator, Awaitable, Iterator, Mapping, Sequence from contextlib import AbstractAsyncContextManager, AsyncExitStack, asynccontextmanager, contextmanager from contextvars import ContextVar from copy import deepcopy @@ -55,6 +55,7 @@ from .toolsets.combined import CombinedToolset from .toolsets.function import FunctionToolset from .toolsets.prepared import PreparedToolset +from .usage import Usage, UsageLimits # Re-exporting like this improves auto-import behavior in PyCharm capture_run_messages = _agent_graph.capture_run_messages @@ -69,11 +70,12 @@ from fasta2a.schema import AgentProvider, Skill from fasta2a.storage import Storage from starlette.middleware import Middleware - from starlette.routing import Route + from starlette.routing import BaseRoute, Route from starlette.types import ExceptionHandler, Lifespan from pydantic_ai.mcp import MCPServer + from .ag_ui import AGUIApp __all__ = ( 'Agent', @@ -1863,6 +1865,105 @@ async def run_mcp_servers( async with self: yield + def to_ag_ui( + self, + *, + # Agent.iter parameters + output_type: OutputSpec[OutputDataT] | None = None, + model: models.Model | models.KnownModelName | str | None = None, + deps: AgentDepsT = None, + model_settings: ModelSettings | None = None, + usage_limits: UsageLimits | None = None, + usage: Usage | None = None, + infer_name: bool = True, + toolsets: Sequence[AbstractToolset[AgentDepsT]] | None = None, + # Starlette + debug: bool = False, + routes: Sequence[BaseRoute] | None = None, + middleware: Sequence[Middleware] | None = None, + exception_handlers: Mapping[Any, ExceptionHandler] | None = None, + on_startup: Sequence[Callable[[], Any]] | None = None, + on_shutdown: Sequence[Callable[[], Any]] | None = None, + lifespan: Lifespan[AGUIApp[AgentDepsT, OutputDataT]] | None = None, + ) -> AGUIApp[AgentDepsT, OutputDataT]: + """Convert the agent to an AG-UI application. + + This allows you to use the agent with a compatible AG-UI frontend. 
+ + Example: + ```python + from pydantic_ai import Agent + + agent = Agent('openai:gpt-4o') + app = agent.to_ag_ui() + ``` + + The `app` is an ASGI application that can be used with any ASGI server. + + To run the application, you can use the following command: + + ```bash + uvicorn app:app --host 0.0.0.0 --port 8000 + ``` + + See [AG-UI docs](../ag-ui.md) for more information. + + Args: + output_type: Custom output type to use for this run, `output_type` may only be used if the agent has + no output validators since output validators would expect an argument that matches the agent's + output type. + model: Optional model to use for this run, required if `model` was not set when creating the agent. + deps: Optional dependencies to use for this run. + model_settings: Optional settings to use for this model's request. + usage_limits: Optional limits on model request count or token usage. + usage: Optional usage to start with, useful for resuming a conversation or agents used in tools. + infer_name: Whether to try to infer the agent name from the call frame if it's not set. + toolsets: Optional list of toolsets to use for this agent, defaults to the agent's toolset. + + debug: Boolean indicating if debug tracebacks should be returned on errors. + routes: A list of routes to serve incoming HTTP and WebSocket requests. + middleware: A list of middleware to run for every request. A starlette application will always + automatically include two middleware classes. `ServerErrorMiddleware` is added as the very + outermost middleware, to handle any uncaught errors occurring anywhere in the entire stack. + `ExceptionMiddleware` is added as the very innermost middleware, to deal with handled + exception cases occurring in the routing or endpoints. + exception_handlers: A mapping of either integer status codes, or exception class types onto + callables which handle the exceptions. Exception handler callables should be of the form + `handler(request, exc) -> response` and may be either standard functions, or async functions. + on_startup: A list of callables to run on application startup. Startup handler callables do not + take any arguments, and may be either standard functions, or async functions. + on_shutdown: A list of callables to run on application shutdown. Shutdown handler callables do + not take any arguments, and may be either standard functions, or async functions. + lifespan: A lifespan context function, which can be used to perform startup and shutdown tasks. + This is a newer style that replaces the `on_startup` and `on_shutdown` handlers. Use one or + the other, not both. + + Returns: + An ASGI application for running Pydantic AI agents with AG-UI protocol support. 
+ """ + from .ag_ui import AGUIApp + + return AGUIApp( + agent=self, + # Agent.iter parameters + output_type=output_type, + model=model, + deps=deps, + model_settings=model_settings, + usage_limits=usage_limits, + usage=usage, + infer_name=infer_name, + toolsets=toolsets, + # Starlette + debug=debug, + routes=routes, + middleware=middleware, + exception_handlers=exception_handlers, + on_startup=on_startup, + on_shutdown=on_shutdown, + lifespan=lifespan, + ) + def to_a2a( self, *, diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 0a1f0f3cc0..e8476b554f 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -214,21 +214,39 @@ class DeltaToolCall: """Incremental change to the tool call ID.""" +@dataclass +class DeltaThinkingPart: + """Incremental change to a thinking part. + + Used to describe a chunk when streaming thinking responses. + """ + + content: str | None = None + """Incremental change to the thinking content.""" + signature: str | None = None + """Incremental change to the thinking signature.""" + + DeltaToolCalls: TypeAlias = dict[int, DeltaToolCall] """A mapping of tool call IDs to incremental changes.""" +DeltaThinkingCalls: TypeAlias = dict[int, DeltaThinkingPart] +"""A mapping of thinking call IDs to incremental changes.""" + # TODO: Change the signature to Callable[[list[ModelMessage], ModelSettings, ModelRequestParameters], ...] FunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], Union[ModelResponse, Awaitable[ModelResponse]]] """A function used to generate a non-streamed response.""" # TODO: Change signature as indicated above -StreamFunctionDef: TypeAlias = Callable[[list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls]]] +StreamFunctionDef: TypeAlias = Callable[ + [list[ModelMessage], AgentInfo], AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]] +] """A function used to generate a streamed response. -While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls]]`, it should -really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls]`, +While this is defined as having return type of `AsyncIterator[Union[str, DeltaToolCalls, DeltaThinkingCalls]]`, it should +really be considered as `Union[AsyncIterator[str], AsyncIterator[DeltaToolCalls], AsyncIterator[DeltaThinkingCalls]]`, -E.g. you need to yield all text or all `DeltaToolCalls`, not mix them. +E.g. you need to yield all text, all `DeltaToolCalls`, or all `DeltaThinkingCalls`, not mix them. 
""" @@ -237,7 +255,7 @@ class FunctionStreamedResponse(StreamedResponse): """Implementation of `StreamedResponse` for [FunctionModel][pydantic_ai.models.function.FunctionModel].""" _model_name: str - _iter: AsyncIterator[str | DeltaToolCalls] + _iter: AsyncIterator[str | DeltaToolCalls | DeltaThinkingCalls] _timestamp: datetime = field(default_factory=_utils.now_utc) def __post_init__(self): @@ -249,20 +267,31 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: response_tokens = _estimate_string_tokens(item) self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) - else: - delta_tool_calls = item - for dtc_index, delta_tool_call in delta_tool_calls.items(): - if delta_tool_call.json_args: - response_tokens = _estimate_string_tokens(delta_tool_call.json_args) - self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) - maybe_event = self._parts_manager.handle_tool_call_delta( - vendor_part_id=dtc_index, - tool_name=delta_tool_call.name, - args=delta_tool_call.json_args, - tool_call_id=delta_tool_call.tool_call_id, - ) - if maybe_event is not None: - yield maybe_event + elif isinstance(item, dict) and item: + for dtc_index, delta in item.items(): + if isinstance(delta, DeltaThinkingPart): + if delta.content: # pragma: no branch + response_tokens = _estimate_string_tokens(delta.content) + self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) + yield self._parts_manager.handle_thinking_delta( + vendor_part_id=dtc_index, + content=delta.content, + signature=delta.signature, + ) + elif isinstance(delta, DeltaToolCall): + if delta.json_args: + response_tokens = _estimate_string_tokens(delta.json_args) + self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) + maybe_event = self._parts_manager.handle_tool_call_delta( + vendor_part_id=dtc_index, + tool_name=delta.name, + args=delta.json_args, + tool_call_id=delta.tool_call_id, + ) + if maybe_event is not None: + yield maybe_event + else: + assert_never(delta) @property def model_name(self) -> str: @@ -299,12 +328,9 @@ def _estimate_usage(messages: Iterable[ModelMessage]) -> usage.Usage: if isinstance(part, TextPart): response_tokens += _estimate_string_tokens(part.content) elif isinstance(part, ThinkingPart): - # NOTE: We don't send ThinkingPart to the providers yet. - # If you are unsatisfied with this, please open an issue. 
- pass + response_tokens += _estimate_string_tokens(part.content) elif isinstance(part, ToolCallPart): - call = part - response_tokens += 1 + _estimate_string_tokens(call.args_as_json_str()) + response_tokens += 1 + _estimate_string_tokens(part.args_as_json_str()) else: assert_never(part) else: diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 05b90e3142..46b6278263 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -428,7 +428,7 @@ def _get_python_type(cls, value: dict[str, Any]) -> str: if value_type == 'object': additional_properties = value.get('additionalProperties', {}) if isinstance(additional_properties, bool): - return 'bool' # pragma: no cover + return 'bool' # pragma: lax no cover additional_properties_type = additional_properties.get('type') if ( additional_properties_type in SIMPLE_JSON_TYPE_MAPPING diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 279b0e8f87..3dfc4a7660 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -82,6 +82,8 @@ mcp = ["mcp>=1.9.4; python_version >= '3.10'"] evals = ["pydantic-evals=={{ version }}"] # A2A a2a = ["fasta2a>=0.4.1"] +# AG-UI +ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"] [dependency-groups] dev = [ diff --git a/pyproject.toml b/pyproject.toml index c3d4ecc60e..841f186ef2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/tests/conftest.py b/tests/conftest.py index f94f5f0477..3ae576c630 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,6 +2,7 @@ import asyncio import importlib.util +import logging import os import re import secrets @@ -28,6 +29,11 @@ __all__ = 'IsDatetime', 'IsFloat', 'IsNow', 'IsStr', 'IsInt', 'IsInstance', 'TestEnv', 'ClientWithHandler', 'try_import' +# Configure VCR logger to WARNING as it is too verbose by default +# specifically, it logs every request and response including binary +# content in Cassette.append, which is causing log downloads from +# GitHub action to fail. 
+logging.getLogger('vcr.cassette').setLevel(logging.WARNING) pydantic_ai.models.ALLOW_MODEL_REQUESTS = False diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py new file mode 100644 index 0000000000..e14ead6626 --- /dev/null +++ b/tests/test_ag_ui.py @@ -0,0 +1,939 @@ +"""Tests for AG-UI implementation.""" + +# pyright: reportPossiblyUnboundVariable=none +from __future__ import annotations + +import contextlib +import json +import uuid +from collections.abc import AsyncIterator +from http import HTTPStatus +from typing import Any + +import httpx +import pytest +from asgi_lifespan import LifespanManager +from inline_snapshot import snapshot +from pydantic import BaseModel + +from pydantic_ai.agent import Agent +from pydantic_ai.messages import ModelMessage +from pydantic_ai.models.function import ( + AgentInfo, + DeltaThinkingCalls, + DeltaThinkingPart, + DeltaToolCall, + DeltaToolCalls, + FunctionModel, +) +from pydantic_ai.output import OutputDataT +from pydantic_ai.tools import AgentDepsT + +from .conftest import IsStr + +has_ag_ui: bool = False +with contextlib.suppress(ImportError): + from ag_ui.core import ( + AssistantMessage, + CustomEvent, + DeveloperMessage, + EventType, + FunctionCall, + Message, + RunAgentInput, + StateSnapshotEvent, + SystemMessage, + Tool, + ToolCall, + ToolMessage, + UserMessage, + ) + from ag_ui.encoder import EventEncoder + + from pydantic_ai.ag_ui import ( + SSE_CONTENT_TYPE, + StateDeps, + _Adapter, # type: ignore[reportPrivateUsage] + ) + + has_ag_ui = True + + +pytestmark = [ + pytest.mark.anyio, + pytest.mark.skipif(not has_ag_ui, reason='ag-ui-protocol not installed'), +] + +SIMPLE_RESULT: Any = snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': 'success '}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': '(no tool calls)'}, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + ] +) + + +async def collect_events_from_adapter( + adapter: _Adapter[AgentDepsT, OutputDataT], *run_inputs: RunAgentInput, deps: AgentDepsT = None +) -> list[dict[str, Any]]: + """Helper function to collect events from an AG-UI adapter run.""" + events = list[dict[str, Any]]() + for run_input in run_inputs: + async for event in adapter.run(run_input, deps=deps): + events.append(json.loads(event.removeprefix('data: '))) + return events + + +class StateInt(BaseModel): + """Example state class for testing purposes.""" + + value: int = 0 + + +def get_weather(name: str = 'get_weather') -> Tool: + return Tool( + name=name, + description='Get the weather for a given location', + parameters={ + 'type': 'object', + 'properties': { + 'location': { + 'type': 'string', + 'description': 'The location to get the weather for', + }, + }, + 'required': ['location'], + }, + ) + + +def current_time() -> str: + """Get the current time in ISO format. + + Returns: + The current UTC time in ISO format string. + """ + return '2023-06-21T12:08:45.485981+00:00' + + +async def send_snapshot() -> StateSnapshotEvent: + """Display the recipe to the user. + + Returns: + StateSnapshotEvent. + """ + return StateSnapshotEvent( + type=EventType.STATE_SNAPSHOT, + snapshot={'key': 'value'}, + ) + + +async def send_custom() -> list[CustomEvent]: + """Display the recipe to the user. + + Returns: + StateSnapshotEvent. 
+ """ + return [ + CustomEvent( + type=EventType.CUSTOM, + name='custom_event1', + value={'key1': 'value1'}, + ), + CustomEvent( + type=EventType.CUSTOM, + name='custom_event2', + value={'key2': 'value2'}, + ), + ] + + +def uuid_str() -> str: + """Generate a random UUID string.""" + return uuid.uuid4().hex + + +def create_input( + *messages: Message, tools: list[Tool] | None = None, thread_id: str | None = None, state: Any = None +) -> RunAgentInput: + """Create a RunAgentInput for testing.""" + thread_id = thread_id or uuid_str() + return RunAgentInput( + thread_id=thread_id, + run_id=uuid_str(), + messages=list(messages), + state=state, + context=[], + tools=tools or [], + forwarded_props=None, + ) + + +async def simple_stream(messages: list[ModelMessage], agent_info: AgentInfo) -> AsyncIterator[str]: + """A simple function that returns a text response without tool calls.""" + yield 'success ' + yield '(no tool calls)' + + +async def test_basic_user_message() -> None: + """Test basic user message with text response.""" + agent = Agent( + model=FunctionModel(stream_function=simple_stream), + ) + adapter = _Adapter(agent=agent) + run_input = create_input( + UserMessage( + id='msg_1', + content='Hello, how are you?', + ) + ) + + events = await collect_events_from_adapter(adapter, run_input) + + assert events == SIMPLE_RESULT + + +async def test_empty_messages() -> None: + """Test handling of empty messages.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[str]: # pragma: no cover + raise NotImplementedError + yield 'no messages' + + agent = Agent( + model=FunctionModel(stream_function=stream_function), + ) + adapter = _Adapter(agent=agent) + run_input = create_input() + events = await collect_events_from_adapter(adapter, run_input) + + assert events == snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'RUN_ERROR', 'message': 'no messages found in the input', 'code': 'no_messages'}, + ] + ) + + +async def test_multiple_messages() -> None: + """Test with multiple different message types.""" + agent = Agent( + model=FunctionModel(stream_function=simple_stream), + ) + adapter = _Adapter(agent=agent) + run_input = create_input( + UserMessage( + id='msg_1', + content='First message', + ), + AssistantMessage( + id='msg_2', + content='Assistant response', + ), + SystemMessage( + id='msg_3', + content='System message', + ), + DeveloperMessage( + id='msg_4', + content='Developer note', + ), + UserMessage( + id='msg_5', + content='Second message', + ), + ) + + events = await collect_events_from_adapter(adapter, run_input) + + assert events == SIMPLE_RESULT + + +async def test_messages_with_history() -> None: + """Test with multiple user messages (conversation history).""" + agent = Agent( + model=FunctionModel(stream_function=simple_stream), + ) + adapter = _Adapter(agent=agent) + run_input = create_input( + UserMessage( + id='msg_1', + content='First message', + ), + UserMessage( + id='msg_2', + content='Second message', + ), + ) + + events = await collect_events_from_adapter(adapter, run_input) + + assert events == SIMPLE_RESULT + + +async def test_tool_ag_ui() -> None: + """Test AG-UI tool call.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + if len(messages) == 1: + # First call - make a tool call + yield {0: DeltaToolCall(name='get_weather')} + yield {0: DeltaToolCall(json_args='{"location": "Paris"}')} + else: + 
# Second call - return text result + yield '{"get_weather": "Tool result"}' + + agent = Agent( + model=FunctionModel(stream_function=stream_function), + tools=[send_snapshot, send_custom, current_time], + ) + adapter = _Adapter(agent=agent) + thread_id = uuid_str() + run_inputs = [ + create_input( + UserMessage( + id='msg_1', + content='Please call get_weather for Paris', + ), + tools=[get_weather()], + thread_id=thread_id, + ), + create_input( + UserMessage( + id='msg_1', + content='Please call get_weather for Paris', + ), + AssistantMessage( + id='msg_2', + tool_calls=[ + ToolCall( + id='pyd_ai_00000000000000000000000000000003', + type='function', + function=FunctionCall( + name='get_weather', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_3', + content='Tool result', + tool_call_id='pyd_ai_00000000000000000000000000000003', + ), + thread_id=thread_id, + ), + ] + + events = await collect_events_from_adapter(adapter, *run_inputs) + + assert events == snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TOOL_CALL_START', 'toolCallId': IsStr(), 'toolCallName': 'get_weather'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': '{"location": "Paris"}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': '{"get_weather": "Tool result"}'}, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + ] + ) + + +async def test_tool_ag_ui_multiple() -> None: + """Test multiple AG-UI tool calls in sequence.""" + run_count = 0 + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + nonlocal run_count + run_count += 1 + + if run_count == 1: + # First run - make multiple tool calls + yield {0: DeltaToolCall(name='get_weather')} + yield {0: DeltaToolCall(json_args='{"location": "Paris"}')} + yield {1: DeltaToolCall(name='get_weather_parts')} + yield {1: DeltaToolCall(json_args='{"location": "')} + yield {1: DeltaToolCall(json_args='Paris"}')} + else: + # Second run - process tool results + yield '{"get_weather": "Tool result", "get_weather_parts": "Tool result"}' + + agent = Agent( + model=FunctionModel(stream_function=stream_function), + ) + adapter = _Adapter(agent=agent) + tool_call_id1 = uuid_str() + tool_call_id2 = uuid_str() + run_inputs = [ + create_input( + UserMessage( + id='msg_1', + content='Please call get_weather and get_weather_parts for Paris', + ), + tools=[get_weather(), get_weather('get_weather_parts')], + ), + create_input( + UserMessage( + id='msg_1', + content='Please call get_weather for Paris', + ), + AssistantMessage( + id='msg_2', + tool_calls=[ + ToolCall( + id=tool_call_id1, + type='function', + function=FunctionCall( + name='get_weather', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_3', + content='Tool result', + tool_call_id=tool_call_id1, + ), + AssistantMessage( + id='msg_4', + tool_calls=[ + ToolCall( + id=tool_call_id2, + type='function', + function=FunctionCall( + name='get_weather_parts', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_5', + content='Tool result', + tool_call_id=tool_call_id2, + ), + 
tools=[get_weather(), get_weather('get_weather_parts')], + ), + ] + + events = await collect_events_from_adapter(adapter, *run_inputs) + + assert events == snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TOOL_CALL_START', 'toolCallId': IsStr(), 'toolCallName': 'get_weather'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': '{"location": "Paris"}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': IsStr()}, + {'type': 'TOOL_CALL_START', 'toolCallId': IsStr(), 'toolCallName': 'get_weather_parts'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': '{"location": "'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': 'Paris"}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + { + 'type': 'TEXT_MESSAGE_CONTENT', + 'messageId': IsStr(), + 'delta': '{"get_weather": "Tool result", "get_weather_parts": "Tool result"}', + }, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + ] + ) + + +async def test_tool_ag_ui_parts() -> None: + """Test AG-UI tool call with streaming/parts (same as tool_call_with_args_streaming).""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + if len(messages) == 1: + # First call - make a tool call with streaming args + yield {0: DeltaToolCall(name='get_weather')} + yield {0: DeltaToolCall(json_args='{"location":"')} + yield {0: DeltaToolCall(json_args='Paris"}')} + else: + # Second call - return text result + yield '{"get_weather": "Tool result"}' + + agent = Agent(model=FunctionModel(stream_function=stream_function)) + adapter = _Adapter(agent=agent) + run_inputs = [ + create_input( + UserMessage( + id='msg_1', + content='Please call get_weather_parts for Paris', + ), + tools=[get_weather('get_weather_parts')], + ), + create_input( + UserMessage( + id='msg_1', + content='Please call get_weather_parts for Paris', + ), + AssistantMessage( + id='msg_2', + tool_calls=[ + ToolCall( + id='pyd_ai_00000000000000000000000000000003', + type='function', + function=FunctionCall( + name='get_weather_parts', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_3', + content='Tool result', + tool_call_id='pyd_ai_00000000000000000000000000000003', + ), + tools=[get_weather('get_weather_parts')], + ), + ] + events = await collect_events_from_adapter(adapter, *run_inputs) + + assert events == snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TOOL_CALL_START', 'toolCallId': IsStr(), 'toolCallName': 'get_weather'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': '{"location":"'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': 'Paris"}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': IsStr()}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': '{"get_weather": "Tool result"}'}, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 
'messageId': IsStr(), 'delta': '{"get_weather": "Tool result"}'}, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + ] + ) + + +async def test_tool_local_single_event() -> None: + """Test local tool call that returns a single event.""" + + encoder = EventEncoder() + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + if len(messages) == 1: + # First call - make a tool call + yield {0: DeltaToolCall(name='send_snapshot')} + yield {0: DeltaToolCall(json_args='{}')} + else: + # Second call - return text result + yield encoder.encode(await send_snapshot()) + + agent = Agent( + model=FunctionModel(stream_function=stream_function), + tools=[send_snapshot], + ) + adapter = _Adapter(agent=agent) + run_input = create_input( + UserMessage( + id='msg_1', + content='Please call send_snapshot', + ), + ) + events = await collect_events_from_adapter(adapter, run_input) + + assert events == snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TOOL_CALL_START', 'toolCallId': IsStr(), 'toolCallName': 'send_snapshot'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': '{}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': IsStr()}, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': 'msg_1', + 'toolCallId': IsStr(), + 'content': '{"type":"STATE_SNAPSHOT","timestamp":null,"raw_event":null,"snapshot":{"key":"value"}}', + 'role': 'tool', + }, + {'type': 'STATE_SNAPSHOT', 'snapshot': {'key': 'value'}}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': IsStr()}, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + ] + ) + + +async def test_tool_local_multiple_events() -> None: + """Test local tool call that returns multiple events.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + if len(messages) == 1: + # First call - make a tool call + yield {0: DeltaToolCall(name='send_custom')} + yield {0: DeltaToolCall(json_args='{}')} + else: + # Second call - return text result + yield 'success send_custom called' + + agent = Agent( + model=FunctionModel(stream_function=stream_function), + tools=[send_custom], + ) + adapter = _Adapter(agent=agent) + run_input = create_input( + UserMessage( + id='msg_1', + content='Please call send_custom', + ), + ) + events = await collect_events_from_adapter(adapter, run_input) + + assert events == snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TOOL_CALL_START', 'toolCallId': IsStr(), 'toolCallName': 'send_custom'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': '{}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': IsStr()}, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': 'msg_1', + 'toolCallId': IsStr(), + 'content': IsStr(), + 'role': 'tool', + }, + {'type': 'CUSTOM', 'name': 'custom_event1', 'value': {'key1': 'value1'}}, + {'type': 'CUSTOM', 'name': 'custom_event2', 'value': {'key2': 'value2'}}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': 'success send_custom called'}, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + ] + ) + + 
+async def test_tool_local_parts() -> None: + """Test local tool call with streaming/parts.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + if len(messages) == 1: + # First call - make a tool call with streaming args + yield {0: DeltaToolCall(name='current_time')} + yield {0: DeltaToolCall(json_args='{}')} + else: + # Second call - return text result + yield 'success current_time called' + + agent = Agent( + model=FunctionModel(stream_function=stream_function), + tools=[send_snapshot, send_custom, current_time], + ) + + adapter = _Adapter(agent=agent) + run_input = create_input( + UserMessage( + id='msg_1', + content='Please call current_time', + ), + ) + + events = await collect_events_from_adapter(adapter, run_input) + + assert events == snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TOOL_CALL_START', 'toolCallId': IsStr(), 'toolCallName': 'current_time'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': '{}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': IsStr()}, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': 'msg_1', + 'toolCallId': IsStr(), + 'content': '2023-06-21T12:08:45.485981+00:00', + 'role': 'tool', + }, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': 'success current_time called'}, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + ] + ) + + +async def test_thinking() -> None: + """Test thinking events - now supported by FunctionModel.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaThinkingCalls | str]: + yield {0: DeltaThinkingPart(content='Thinking ')} + yield {0: DeltaThinkingPart(content='about the weather')} + yield 'Thought about the weather' + + agent = Agent( + model=FunctionModel(stream_function=stream_function), + ) + adapter = _Adapter(agent=agent) + run_input = create_input( + UserMessage( + id='msg_1', + content='Think about the weather', + ), + ) + + events = await collect_events_from_adapter(adapter, run_input) + + assert events == snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'THINKING_TEXT_MESSAGE_START'}, + {'type': 'THINKING_TEXT_MESSAGE_CONTENT', 'delta': 'Thinking '}, + {'type': 'THINKING_TEXT_MESSAGE_CONTENT', 'delta': 'about the weather'}, + {'type': 'THINKING_TEXT_MESSAGE_END'}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': 'Thought about the weather'}, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + ] + ) + + +async def test_tool_local_then_ag_ui() -> None: + """Test mixed local and AG-UI tool calls.""" + + async def stream_function( + messages: list[ModelMessage], agent_info: AgentInfo + ) -> AsyncIterator[DeltaToolCalls | str]: + if len(messages) == 1: + # First - call local tool (current_time) + yield {0: DeltaToolCall(name='current_time')} + yield {0: DeltaToolCall(json_args='{}')} + # Then - call AG-UI tool (get_weather) + yield {1: DeltaToolCall(name='get_weather')} + yield {1: DeltaToolCall(json_args='{"location": "Paris"}')} + else: + # Final response with results + yield 'current time is 2023-06-21T12:08:45.485981+00:00 and the weather in Paris is bright and 
sunny' + + tool_call_id1 = uuid_str() + tool_call_id2 = uuid_str() + agent = Agent( + model=FunctionModel(stream_function=stream_function), + tools=[current_time], + ) + adapter = _Adapter(agent=agent) + run_inputs = [ + create_input( + UserMessage( + id='msg_1', + content='Please tell me the time and then call get_weather for Paris', + ), + tools=[get_weather()], + ), + create_input( + UserMessage( + id='msg_1', + content='Please call get_weather for Paris', + ), + AssistantMessage( + id='msg_2', + tool_calls=[ + ToolCall( + id=tool_call_id1, + type='function', + function=FunctionCall( + name='current_time', + arguments='{}', + ), + ), + ], + ), + ToolMessage( + id='msg_3', + content='Tool result', + tool_call_id=tool_call_id1, + ), + AssistantMessage( + id='msg_4', + tool_calls=[ + ToolCall( + id=tool_call_id2, + type='function', + function=FunctionCall( + name='get_weather', + arguments='{"location": "Paris"}', + ), + ), + ], + ), + ToolMessage( + id='msg_5', + content='Bright and sunny', + tool_call_id=tool_call_id2, + ), + tools=[get_weather()], + ), + ] + events = await collect_events_from_adapter(adapter, *run_inputs) + + assert events == snapshot( + [ + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TOOL_CALL_START', 'toolCallId': IsStr(), 'toolCallName': 'current_time'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': '{}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': IsStr()}, + {'type': 'TOOL_CALL_START', 'toolCallId': IsStr(), 'toolCallName': 'get_weather'}, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': IsStr(), 'delta': '{"location": "Paris"}'}, + {'type': 'TOOL_CALL_END', 'toolCallId': IsStr()}, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': 'msg_1', + 'toolCallId': IsStr(), + 'content': '2023-06-21T12:08:45.485981+00:00', + 'role': 'tool', + }, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'RUN_STARTED', 'threadId': IsStr(), 'runId': IsStr()}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + { + 'type': 'TEXT_MESSAGE_CONTENT', + 'messageId': IsStr(), + 'delta': 'current time is 2023-06-21T12:08:45.485981+00:00 and the weather in Paris is bright and sunny', + }, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': IsStr(), 'runId': IsStr()}, + ] + ) + + +async def test_request_with_state() -> None: + """Test request with state modification.""" + + agent: Agent[StateDeps[StateInt], str] = Agent( + model=FunctionModel(stream_function=simple_stream), + deps_type=StateDeps[StateInt], # type: ignore[reportUnknownArgumentType] + ) + adapter = _Adapter(agent=agent) + run_inputs = [ + create_input( + UserMessage( + id='msg_1', + content='Hello, how are you?', + ), + state=StateInt(value=41), + ), + create_input( + UserMessage( + id='msg_2', + content='Hello, how are you?', + ), + ), + create_input( + UserMessage( + id='msg_3', + content='Hello, how are you?', + ), + state=StateInt(value=42), + ), + ] + + deps = StateDeps(StateInt()) + + last_value = deps.state.value + for run_input in run_inputs: + events = list[dict[str, Any]]() + async for event in adapter.run(run_input, deps=deps): + events.append(json.loads(event.removeprefix('data: '))) + + assert events == SIMPLE_RESULT + assert deps.state.value == run_input.state.value if run_input.state is not None else last_value + last_value = deps.state.value + + assert deps.state.value == 42 + + +async def test_concurrent_runs() -> None: + """Test concurrent execution of multiple runs.""" + import 
asyncio + + agent = Agent( + model=FunctionModel(stream_function=simple_stream), + ) + adapter = _Adapter(agent=agent) + concurrent_tasks: list[asyncio.Task[list[dict[str, Any]]]] = [] + + for i in range(5): # Test with 5 concurrent runs + run_input = create_input( + UserMessage( + id=f'msg_{i}', + content=f'Message {i}', + ), + thread_id=f'test_thread_{i}', + ) + + task = asyncio.create_task(collect_events_from_adapter(adapter, run_input)) + concurrent_tasks.append(task) + + results = await asyncio.gather(*concurrent_tasks) + + # Verify all runs completed successfully + for i, events in enumerate(results): + assert events == [ + {'type': 'RUN_STARTED', 'threadId': f'test_thread_{i}', 'runId': IsStr()}, + {'type': 'TEXT_MESSAGE_START', 'messageId': IsStr(), 'role': 'assistant'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': 'success '}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': IsStr(), 'delta': '(no tool calls)'}, + {'type': 'TEXT_MESSAGE_END', 'messageId': IsStr()}, + {'type': 'RUN_FINISHED', 'threadId': f'test_thread_{i}', 'runId': IsStr()}, + ] + + +@pytest.mark.anyio +async def test_to_ag_ui() -> None: + """Test the agent.to_ag_ui method.""" + + agent = Agent(model=FunctionModel(stream_function=simple_stream)) + app = agent.to_ag_ui() + async with LifespanManager(app): + transport = httpx.ASGITransport(app) + async with httpx.AsyncClient(transport=transport) as client: + client.base_url = 'http://localhost:8000' + run_input = create_input( + UserMessage( + id='msg_1', + content='Hello, world!', + ), + ) + async with client.stream( + 'POST', + '/', + content=run_input.model_dump_json(), + headers={'Content-Type': 'application/json', 'Accept': SSE_CONTENT_TYPE}, + ) as response: + assert response.status_code == HTTPStatus.OK, f'Unexpected status code: {response.status_code}' + events: list[dict[str, Any]] = [] + async for line in response.aiter_lines(): + if line: + events.append(json.loads(line.removeprefix('data: '))) + + assert events == SIMPLE_RESULT diff --git a/uv.lock b/uv.lock index 28e85eaa67..a47bd1445a 100644 --- a/uv.lock +++ b/uv.lock @@ -27,6 +27,18 @@ members = [ "pydantic-graph", ] +[[package]] +name = "ag-ui-protocol" +version = "0.1.8" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "pydantic" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/27/de/0bddf7f26d5f38274c99401735c82ad59df9cead6de42f4bb2ad837286fe/ag_ui_protocol-0.1.8.tar.gz", hash = "sha256:eb745855e9fc30964c77e953890092f8bd7d4bbe6550d6413845428dd0faac0b", size = 5323, upload-time = "2025-07-15T10:55:36.389Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c4/00/40c6b0313c25d1ab6fac2ecba1cd5b15b1cd3c3a71b3d267ad890e405889/ag_ui_protocol-0.1.8-py3-none-any.whl", hash = "sha256:1567ccb067b7b8158035b941a985e7bb185172d660d4542f3f9c6fff77b55c6e", size = 7066, upload-time = "2025-07-15T10:55:35.075Z" }, +] + [[package]] name = "aiofiles" version = "23.2.1" @@ -2988,7 +3000,7 @@ wheels = [ name = "pydantic-ai" source = { editable = "." 
} dependencies = [ - { name = "pydantic-ai-slim", extra = ["anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] }, ] [package.optional-dependencies] @@ -3027,7 +3039,7 @@ requires-dist = [ { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["a2a", "examples", "logfire"] @@ -3062,7 +3074,7 @@ dependencies = [ { name = "logfire", extra = ["asyncpg", "fastapi", "httpx", "sqlite3"] }, { name = "mcp", extra = ["cli"], marker = "python_full_version >= '3.10'" }, { name = "modal" }, - { name = "pydantic-ai-slim", extra = ["anthropic", "groq", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "groq", "openai", "vertexai"] }, { name = "pydantic-evals" }, { name = "python-multipart" }, { name = "rich" }, @@ -3078,7 +3090,7 @@ requires-dist = [ { name = "logfire", extras = ["asyncpg", "fastapi", "httpx", "sqlite3"], specifier = ">=2.6" }, { name = "mcp", extras = ["cli"], marker = "python_full_version >= '3.10'", specifier = ">=1.4.1" }, { name = "modal", specifier = ">=1.0.4" }, - { name = "pydantic-ai-slim", extras = ["anthropic", "groq", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "groq", "openai", "vertexai"], editable = "pydantic_ai_slim" }, { name = "pydantic-evals", editable = "pydantic_evals" }, { name = "python-multipart", specifier = ">=0.0.17" }, { name = "rich", specifier = ">=13.9.2" }, @@ -3103,6 +3115,10 @@ dependencies = [ a2a = [ { name = "fasta2a" }, ] +ag-ui = [ + { name = "ag-ui-protocol" }, + { name = "starlette" }, +] anthropic = [ { name = "anthropic" }, ] @@ -3175,6 +3191,7 @@ dev = [ [package.metadata] requires-dist = [ + { name = "ag-ui-protocol", marker = "extra == 'ag-ui'", specifier = ">=0.1.8" }, { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.52.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.37.24" }, @@ -3200,10 +3217,11 @@ requires-dist = [ { name = "pydantic-graph", editable = "pydantic_graph" }, { name = "requests", marker = "extra == 'vertexai'", specifier = ">=2.32.2" }, { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, + { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["a2a", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", 
"cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] [package.metadata.requires-dev] dev = [ From 01c550c2001b3e5018c3dff6d4aa46e956eae0a8 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 18 Jul 2025 11:14:56 -0600 Subject: [PATCH 38/89] chore: Fix inconsistent docs example output and fix variable name (#2248) --- docs/toolsets.md | 8 ++++++-- pydantic_ai_slim/pydantic_ai/result.py | 14 +++++++------- 2 files changed, 13 insertions(+), 9 deletions(-) diff --git a/docs/toolsets.md b/docs/toolsets.md index 066d54ef90..5caac22c09 100644 --- a/docs/toolsets.md +++ b/docs/toolsets.md @@ -341,6 +341,7 @@ It is is a no-op by default, but enables some useful abilities: You can subclass `WrapperToolset` to change the wrapped toolset's tool execution behavior by overriding the [`call_tool()`][pydantic_ai.toolsets.AbstractToolset.call_tool] method. ```python {title="logging_toolset.py" requires="function_toolset.py,combined_toolset.py,renamed_toolset.py,prepared_toolset.py"} +import asyncio from typing_extensions import Any from prepared_toolset import prepared_toolset @@ -356,6 +357,8 @@ class LoggingToolset(WrapperToolset): async def call_tool(self, name: str, tool_args: dict[str, Any], ctx: RunContext, tool: ToolsetTool) -> Any: LOG.append(f'Calling tool {name!r} with args: {tool_args!r}') try: + await asyncio.sleep(0.1 * len(LOG)) # (1)! + result = await super().call_tool(name, tool_args, ctx, tool) LOG.append(f'Finished calling tool {name!r} with result: {result!r}') except Exception as e: @@ -367,7 +370,7 @@ class LoggingToolset(WrapperToolset): logging_toolset = LoggingToolset(prepared_toolset) -agent = Agent(TestModel(), toolsets=[logging_toolset]) # (1)! +agent = Agent(TestModel(), toolsets=[logging_toolset]) # (2)! result = agent.run_sync('Call all the tools') print(LOG) """ @@ -384,7 +387,8 @@ print(LOG) """ ``` -1. We use [`TestModel`][pydantic_ai.models.test.TestModel] here as it will automatically call each tool. +1. All docs examples are tested in CI and their their output is verified, so we need `LOG` to always have the same order whenever this code is run. Since the tools could finish in any order, we sleep an increasing amount of time based on which number tool call we are to ensure that they finish (and log) in the same order they were called in. +2. We use [`TestModel`][pydantic_ai.models.test.TestModel] here as it will automatically call each tool. 
_(This example is complete, it can be run "as is")_ diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index 163189ac0b..d8439fb5d7 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -49,7 +49,7 @@ class AgentStream(Generic[AgentDepsT, OutputDataT]): _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _run_ctx: RunContext[AgentDepsT] _usage_limits: UsageLimits | None - _toolset: ToolManager[AgentDepsT] + _tool_manager: ToolManager[AgentDepsT] _agent_stream_iterator: AsyncIterator[AgentStreamEvent] | None = field(default=None, init=False) _final_result_event: FinalResultEvent | None = field(default=None, init=False) @@ -111,8 +111,8 @@ async def _validate_response( raise exceptions.UnexpectedModelBehavior( # pragma: no cover f'Invalid response, unable to find tool call for {output_tool_name!r}' ) - return await self._toolset.handle_call(tool_call, allow_partial=allow_partial) - elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial) + elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts): if not self._output_schema.allows_deferred_tool_calls: raise exceptions.UserError( # pragma: no cover 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' @@ -154,7 +154,7 @@ def _get_final_result_event(e: _messages.ModelResponseStreamEvent) -> _messages. ): # pragma: no branch return _messages.FinalResultEvent(tool_name=None, tool_call_id=None) elif isinstance(new_part, _messages.ToolCallPart) and ( - tool_def := self._toolset.get_tool_def(new_part.tool_name) + tool_def := self._tool_manager.get_tool_def(new_part.tool_name) ): if tool_def.kind == 'output': return _messages.FinalResultEvent( @@ -196,7 +196,7 @@ class StreamedRunResult(Generic[AgentDepsT, OutputDataT]): _output_validators: list[OutputValidator[AgentDepsT, OutputDataT]] _output_tool_name: str | None _on_complete: Callable[[], Awaitable[None]] - _toolset: ToolManager[AgentDepsT] + _tool_manager: ToolManager[AgentDepsT] _initial_run_ctx_usage: Usage = field(init=False) is_complete: bool = field(default=False, init=False) @@ -443,8 +443,8 @@ async def validate_structured_output( raise exceptions.UnexpectedModelBehavior( # pragma: no cover f'Invalid response, unable to find tool call for {self._output_tool_name!r}' ) - return await self._toolset.handle_call(tool_call, allow_partial=allow_partial) - elif deferred_tool_calls := self._toolset.get_deferred_tool_calls(message.parts): + return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial) + elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts): if not self._output_schema.allows_deferred_tool_calls: raise exceptions.UserError( 'A deferred tool call was present, but `DeferredToolCalls` is not among output types. To resolve this, add `DeferredToolCalls` to the list of output types for this agent.' 
From 7d50564d2f0e5e44f39917dc26e82e2a9045fa0a Mon Sep 17 00:00:00 2001 From: Marcelo Trylesinski Date: Mon, 21 Jul 2025 12:15:24 +0200 Subject: [PATCH 39/89] Add `async with self` in `agent_to_a2a` (#2266) --- pydantic_ai_slim/pydantic_ai/_a2a.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_a2a.py b/pydantic_ai_slim/pydantic_ai/_a2a.py index 1f033994ef..8a916b5cc7 100644 --- a/pydantic_ai_slim/pydantic_ai/_a2a.py +++ b/pydantic_ai_slim/pydantic_ai/_a2a.py @@ -59,12 +59,12 @@ @asynccontextmanager -async def worker_lifespan(app: FastA2A, worker: Worker) -> AsyncIterator[None]: +async def worker_lifespan(app: FastA2A, worker: Worker, agent: Agent[AgentDepsT, OutputDataT]) -> AsyncIterator[None]: """Custom lifespan that runs the worker during application startup. This ensures the worker is started and ready to process tasks as soon as the application starts. """ - async with app.task_manager: + async with app.task_manager, agent: async with worker.run(): yield @@ -93,7 +93,7 @@ def agent_to_a2a( broker = broker or InMemoryBroker() worker = AgentWorker(agent=agent, broker=broker, storage=storage) - lifespan = lifespan or partial(worker_lifespan, worker=worker) + lifespan = lifespan or partial(worker_lifespan, worker=worker, agent=agent) return FastA2A( storage=storage, From 9ca4bca64753bb077166e79b6103ea6fcda3752b Mon Sep 17 00:00:00 2001 From: Aditya Vardhan <76904033+adtyavrdhn@users.noreply.github.com> Date: Mon, 21 Jul 2025 19:55:23 +0530 Subject: [PATCH 40/89] Fix include_content not working as expected (#2206) --- pydantic_ai_slim/pydantic_ai/agent.py | 17 +++++++++-------- .../pydantic_ai/models/instrumented.py | 7 ++++++- tests/test_logfire.py | 3 --- 3 files changed, 15 insertions(+), 12 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py index b2e1667ddc..9a68bc6775 100644 --- a/pydantic_ai_slim/pydantic_ai/agent.py +++ b/pydantic_ai_slim/pydantic_ai/agent.py @@ -843,14 +843,15 @@ async def get_instructions(run_context: RunContext[AgentDepsT]) -> str | None: agent_run = AgentRun(graph_run) yield agent_run if (final_result := agent_run.result) is not None and run_span.is_recording(): - run_span.set_attribute( - 'final_result', - ( - final_result.output - if isinstance(final_result.output, str) - else json.dumps(InstrumentedModel.serialize_any(final_result.output)) - ), - ) + if instrumentation_settings and instrumentation_settings.include_content: + run_span.set_attribute( + 'final_result', + ( + final_result.output + if isinstance(final_result.output, str) + else json.dumps(InstrumentedModel.serialize_any(final_result.output)) + ), + ) finally: try: if instrumentation_settings and run_span.is_recording(): diff --git a/pydantic_ai_slim/pydantic_ai/models/instrumented.py b/pydantic_ai_slim/pydantic_ai/models/instrumented.py index f40340998b..233020f6f5 100644 --- a/pydantic_ai_slim/pydantic_ai/models/instrumented.py +++ b/pydantic_ai_slim/pydantic_ai/models/instrumented.py @@ -156,7 +156,12 @@ def messages_to_otel_events(self, messages: list[ModelMessage]) -> list[Event]: events: list[Event] = [] instructions = InstrumentedModel._get_instructions(messages) # pyright: ignore [reportPrivateUsage] if instructions is not None: - events.append(Event('gen_ai.system.message', body={'content': instructions, 'role': 'system'})) + events.append( + Event( + 'gen_ai.system.message', + body={**({'content': instructions} if self.include_content else {}), 'role': 'system'}, + ) + ) for 
message_index, message in enumerate(messages): message_events: list[Event] = [] diff --git a/tests/test_logfire.py b/tests/test_logfire.py index 799724179a..d12e028fa3 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -448,7 +448,6 @@ class MyOutput: snapshot( [ { - 'content': 'Here are some instructions', 'role': 'system', 'event.name': 'gen_ai.system.message', }, @@ -480,7 +479,6 @@ class MyOutput: ] ) ), - 'final_result': '{"content": "a"}', 'logfire.json_schema': IsJson( snapshot( { @@ -497,7 +495,6 @@ class MyOutput: snapshot( [ { - 'content': 'Here are some instructions', 'role': 'system', 'gen_ai.system': 'test', 'event.name': 'gen_ai.system.message', From 09bd7dd0b4cc20015fd4f01514b987e876b412c6 Mon Sep 17 00:00:00 2001 From: Daniel <38250010+Kigstn@users.noreply.github.com> Date: Mon, 21 Jul 2025 19:32:36 +0200 Subject: [PATCH 41/89] Support streamable HTTP in mcp-run-python (#2230) Co-authored-by: daniel.jaekel --- docs/mcp/run-python.md | 6 +- mcp-run-python/deno.json | 2 +- mcp-run-python/deno.lock | 45 ++++---- mcp-run-python/src/main.ts | 162 ++++++++++++++++++++++++++--- mcp-run-python/test_mcp_servers.py | 28 ++++- 5 files changed, 206 insertions(+), 37 deletions(-) diff --git a/docs/mcp/run-python.md b/docs/mcp/run-python.md index f99a159827..50ac9cfd5f 100644 --- a/docs/mcp/run-python.md +++ b/docs/mcp/run-python.md @@ -21,7 +21,7 @@ The MCP Run Python server is distributed as a [JSR package](https://jsr.io/@pyda ```bash {title="terminal"} deno run \ -N -R=node_modules -W=node_modules --node-modules-dir=auto \ - jsr:@pydantic/mcp-run-python [stdio|sse|warmup] + jsr:@pydantic/mcp-run-python [stdio|streamable_http|sse|warmup] ``` where: @@ -34,6 +34,10 @@ where: - `stdio` runs the server with the [Stdio MCP transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#stdio) — suitable for running the process as a subprocess locally +- `streamable_http` runs the server with the + [Streamable HTTP MCP transport](https://modelcontextprotocol.io/specification/2025-03-26/basic/transports#streamable-http) + — running the server as an HTTP server to connect locally or remotely. 
+ This supports stateful requests, but does not require the client to hold a stateful connection like SSE - `sse` runs the server with the [SSE MCP transport](https://spec.modelcontextprotocol.io/specification/2024-11-05/basic/transports/#http-with-sse) — running the server as an HTTP server to connect locally or remotely diff --git a/mcp-run-python/deno.json b/mcp-run-python/deno.json index 0f01643123..b84e0546c4 100644 --- a/mcp-run-python/deno.json +++ b/mcp-run-python/deno.json @@ -13,7 +13,7 @@ "build-publish": "deno task build && deno publish" }, "imports": { - "@modelcontextprotocol/sdk": "npm:@modelcontextprotocol/sdk@^1.8.0", + "@modelcontextprotocol/sdk": "npm:@modelcontextprotocol/sdk@^1.15.1", "@std/cli": "jsr:@std/cli@^1.0.15", "@std/path": "jsr:@std/path@^1.0.8", // do NOT upgrade above this version until there is a workaround for https://github.com/pyodide/pyodide/pull/5621 diff --git a/mcp-run-python/deno.lock b/mcp-run-python/deno.lock index 46bb54872a..0bd8680730 100644 --- a/mcp-run-python/deno.lock +++ b/mcp-run-python/deno.lock @@ -5,7 +5,7 @@ "jsr:@std/cli@^1.0.15": "1.0.15", "jsr:@std/path@*": "1.0.8", "jsr:@std/path@^1.0.8": "1.0.8", - "npm:@modelcontextprotocol/sdk@^1.8.0": "1.8.0_express@5.1.0_zod@3.24.2", + "npm:@modelcontextprotocol/sdk@^1.15.1": "1.15.1_express@5.1.0_zod@3.24.2", "npm:@types/node@*": "22.12.0", "npm:@types/node@22.12.0": "22.12.0", "npm:eslint@*": "9.23.0", @@ -94,13 +94,15 @@ "@humanwhocodes/retry@0.4.2": { "integrity": "sha512-xeO57FpIu4p1Ri3Jq/EXq4ClRm86dVF2z/+kvFnyqVYRavTZmaFaUBbWCOuuTh0o/g7DSsk6kc2vrS4Vl5oPOQ==" }, - "@modelcontextprotocol/sdk@1.8.0_express@5.1.0_zod@3.24.2": { - "integrity": "sha512-e06W7SwrontJDHwCawNO5SGxG+nU9AAx+jpHHZqGl/WrDBdWOpvirC+s58VpJTB5QemI4jTRcjWT4Pt3Q1NPQQ==", + "@modelcontextprotocol/sdk@1.15.1_express@5.1.0_zod@3.24.2": { + "integrity": "sha512-W/XlN9c528yYn+9MQkVjxiTPgPxoxt+oczfjHBDsJx0+59+O7B75Zhsp0B16Xbwbz8ANISDajh6+V7nIcPMc5w==", "dependencies": [ + "ajv", "content-type", "cors", "cross-spawn", "eventsource", + "eventsource-parser", "express", "express-rate-limit", "pkce-challenge", @@ -231,8 +233,8 @@ "cookie-signature@1.2.2": { "integrity": "sha512-D76uU73ulSXrD1UXF4KE2TMxVVwhsnCgfAyTg9k8P6KGZjlXKrOLe4dJQKI3Bxi5wjesZoFXJWElNWBjPZMbhg==" }, - "cookie@0.7.1": { - "integrity": "sha512-6DnInpx7SJ2AK3+CTUE/ZM0vWTUboZCegxhC2xiIydHR9jNuTAASBrfEpHhiGOZw/nX51bHt6YQl8jsGo4y/0w==" + "cookie@0.7.2": { + "integrity": "sha512-yki5XnKuf750l50uGTllt6kKILY4nQ1eNIQatoXEByZ5dWgnKqbnqmTrBE5B4N7lrMJKQ2ytWMiTO2o0v6Ew/w==" }, "cors@2.8.5": { "integrity": "sha512-KIHbLJqu73RGr/hnbrO9uBeixNGuvSQjul/jdFvS/KFSIH1hWVd1ng7zOHx+YrEfInLG7q4n6GHQ9cDtxv/P6g==", @@ -376,17 +378,17 @@ "etag@1.8.1": { "integrity": "sha512-aIL5Fx7mawVa300al2BnEE4iNvo1qETxLrPI/o05L7z6go7fCw1J6EQmbK4FmJ2AS7kgVF/KEZWufBfdClMcPg==" }, - "eventsource-parser@3.0.1": { - "integrity": "sha512-VARTJ9CYeuQYb0pZEPbzi740OWFgpHe7AYJ2WFZVnUDUQp5Dk2yJUgF36YsZ81cOyxT0QxmXD2EQpapAouzWVA==" + "eventsource-parser@3.0.3": { + "integrity": "sha512-nVpZkTMM9rF6AQ9gPJpFsNAMt48wIzB5TQgiTLdHiuO8XEDhUgZEhqKlZWXbIzo9VmJ/HvysHqEaVeD5v9TPvA==" }, - "eventsource@3.0.6": { - "integrity": "sha512-l19WpE2m9hSuyP06+FbuUUf1G+R0SFLrtQfbRb9PRr+oimOfxQhgGCbVaXg5IvZyyTThJsxh6L/srkMiCeBPDA==", + "eventsource@3.0.7": { + "integrity": "sha512-CRT1WTyuQoD771GW56XEZFQ/ZoSfWid1alKGDYMmkt2yl8UXrVR4pspqWNEcqKvVIzg6PAltWjxcSSPrboA4iA==", "dependencies": [ "eventsource-parser" ] }, - "express-rate-limit@7.5.0_express@5.1.0": { - "integrity": 
"sha512-eB5zbQh5h+VenMPM3fh+nw1YExi5nMr6HUCR62ELSP11huvxm/Uir1H1QEyTkk5QX6A58pX6NmaTMceKZ0Eodg==", + "express-rate-limit@7.5.1_express@5.1.0": { + "integrity": "sha512-7iN8iPMDzOMHPUYllBEsQdWVB6fPDMPqwjBaFrgr4Jgr/+okjvzAy+UHlYYL/Vs0OsOrMkwS6PJDkFlJwoxUnw==", "dependencies": [ "express" ] @@ -418,7 +420,7 @@ "router", "send", "serve-static", - "statuses", + "statuses@2.0.2", "type-is", "vary" ] @@ -446,7 +448,7 @@ "escape-html", "on-finished", "parseurl", - "statuses" + "statuses@2.0.2" ] }, "find-up@5.0.0": { @@ -527,7 +529,7 @@ "depd", "inherits", "setprototypeof", - "statuses", + "statuses@2.0.1", "toidentifier" ] }, @@ -701,8 +703,8 @@ "path-to-regexp@8.2.0": { "integrity": "sha512-TdrF7fW9Rphjq4RjrW0Kp2AW0Ahwu9sRGTkS6bvDi0SCwZlEZYmcfDbEsTz8RVk0EHIS/Vd1bv3JhG+1xZuAyQ==" }, - "pkce-challenge@4.1.0": { - "integrity": "sha512-ZBmhE1C9LcPoH9XZSdwiPtbPHZROwAnMy+kIFQVrnMCxY4Cudlz3gBOpzilgc0jOgRaiT3sIWfpMomW2ar2orQ==" + "pkce-challenge@5.0.0": { + "integrity": "sha512-ueGLflrrnvwB3xuo/uGob5pd5FN7l0MsLf0Z87o/UQmRtwjvfylfc9MurIxRAWywCYTgrvpXBcqjV4OfCYGCIQ==" }, "prelude-ls@1.2.1": { "integrity": "sha512-vkcDPrRZo1QZLbn5RLGPpg/WmIQ65qoWWhcGKf/b5eplkkarX0m9z8ppCat4mlOqUsWpyNuYgO3VRyrYHSzX5g==" @@ -773,7 +775,7 @@ "ms", "on-finished", "range-parser", - "statuses" + "statuses@2.0.2" ] }, "serve-static@2.2.0": { @@ -836,6 +838,9 @@ "statuses@2.0.1": { "integrity": "sha512-RwNA9Z/7PrK06rYLIzFMlaF+l73iwpzsqRIFgbMLbTcLD6cOao82TaWefPXQvB2fOC4AjuYSEndS7N/mTCbkdQ==" }, + "statuses@2.0.2": { + "integrity": "sha512-DvEy55V3DB7uknRo+4iOGT5fP1slR8wQohVdknigZPMpMstaKJQWhwiYBACJE3Ul2pTnATihhBYnRhZQHGBiRw==" + }, "strip-json-comments@3.1.1": { "integrity": "sha512-6fPc+R4ihwqP6N/aIv2f1gMH8lOVtWQHoqC4yK6oSDVVocumAsfCqjkXnqiYMhmMwS/mEHLp7Vehlt3ql6lEig==" }, @@ -896,8 +901,8 @@ "yocto-queue@0.1.0": { "integrity": "sha512-rVksvsnNCdJ/ohGc6xgPwyN8eheCxsiLM8mxuE/t/mOVqJewPuO1miLpTHQiRgTKCLexL4MeAFVagts7HmNZ2Q==" }, - "zod-to-json-schema@3.24.5_zod@3.24.2": { - "integrity": "sha512-/AuWwMP+YqiPbsJx5D6TfgRTc4kTLjsh5SOcd4bLsfUg2RcEXrFMJl1DGgdHy2aCfsIA/cr/1JM0xcB2GZji8g==", + "zod-to-json-schema@3.24.6_zod@3.24.2": { + "integrity": "sha512-h/z3PKvcTcTetyjl1fkj79MHNEjm+HpD6NXheWjzOekY7kV+lwDYnHw+ivHkijnCSMz1yJaWBD9vu/Fcmk+vEg==", "dependencies": [ "zod" ] @@ -910,7 +915,7 @@ "dependencies": [ "jsr:@std/cli@^1.0.15", "jsr:@std/path@^1.0.8", - "npm:@modelcontextprotocol/sdk@^1.8.0", + "npm:@modelcontextprotocol/sdk@^1.15.1", "npm:pyodide@0.27.6", "npm:zod@^3.24.2" ] diff --git a/mcp-run-python/src/main.ts b/mcp-run-python/src/main.ts index 6eb051f93f..caf1cb9896 100644 --- a/mcp-run-python/src/main.ts +++ b/mcp-run-python/src/main.ts @@ -2,14 +2,18 @@ import './polyfill.ts' import http from 'node:http' +import { randomUUID } from 'node:crypto' import { parseArgs } from '@std/cli/parse-args' import { SSEServerTransport } from '@modelcontextprotocol/sdk/server/sse.js' import { StdioServerTransport } from '@modelcontextprotocol/sdk/server/stdio.js' +import { StreamableHTTPServerTransport } from '@modelcontextprotocol/sdk/server/streamableHttp.js' +import { isInitializeRequest } from '@modelcontextprotocol/sdk/types.js' import { type LoggingLevel, SetLevelRequestSchema } from '@modelcontextprotocol/sdk/types.js' import { McpServer } from '@modelcontextprotocol/sdk/server/mcp.js' import { z } from 'zod' import { asXml, runCode } from './runCode.ts' +import { Buffer } from 'node:buffer' const VERSION = '0.0.13' @@ -17,6 +21,13 @@ export async function main() { const { args } = Deno if (args.length === 1 && args[0] === 'stdio') { 
await runStdio() + } else if (args.length >= 1 && args[0] === 'streamable_http') { + const flags = parseArgs(Deno.args, { + string: ['port'], + default: { port: '3001' }, + }) + const port = parseInt(flags.port) + runStreamableHttp(port) } else if (args.length >= 1 && args[0] === 'sse') { const flags = parseArgs(Deno.args, { string: ['port'], @@ -31,7 +42,7 @@ export async function main() { `\ Invalid arguments. -Usage: deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto jsr:@pydantic/mcp-run-python [stdio|sse|warmup] +Usage: deno run -N -R=node_modules -W=node_modules --node-modules-dir=auto jsr:@pydantic/mcp-run-python [stdio|streamable_http|sse|warmup] options: --port Port to run the SSE server on (default: 3001)`, @@ -103,6 +114,138 @@ print('python code here') return server } +/* + * Define some QOL functions for both the SSE and Streamable HTTP server implementation + */ +function httpGetUrl(req: http.IncomingMessage): URL { + return new URL( + req.url ?? '', + `http://${req.headers.host ?? 'unknown'}`, + ) +} + +function httpGetBody(req: http.IncomingMessage): Promise { + // https://nodejs.org/en/learn/modules/anatomy-of-an-http-transaction#request-body + return new Promise((resolve) => { + // deno-lint-ignore no-explicit-any + const bodyParts: any[] = [] + let body + req.on('data', (chunk) => { + bodyParts.push(chunk) + }).on('end', () => { + body = Buffer.concat(bodyParts).toString() + resolve(JSON.parse(body)) + }) + }) +} + +function httpSetTextResponse(res: http.ServerResponse, status: number, text: string) { + res.setHeader('Content-Type', 'text/plain') + res.statusCode = status + res.end(`${text}\n`) +} + +function httpSetJsonResponse(res: http.ServerResponse, status: number, text: string, code: number) { + res.setHeader('Content-Type', 'application/json') + res.statusCode = status + res.write(JSON.stringify({ + jsonrpc: '2.0', + error: { + code: code, + message: text, + }, + id: null, + })) + res.end() +} + +/* + * Run the MCP server using the Streamable HTTP transport + */ +function runStreamableHttp(port: number) { + // https://github.com/modelcontextprotocol/typescript-sdk?tab=readme-ov-file#with-session-management + const mcpServer = createServer() + const transports: { [sessionId: string]: StreamableHTTPServerTransport } = {} + + const server = http.createServer(async (req, res) => { + const url = httpGetUrl(req) + let pathMatch = false + function match(method: string, path: string): boolean { + if (url.pathname === path) { + pathMatch = true + return req.method === method + } + return false + } + + // Reusable handler for GET and DELETE requests + async function handleSessionRequest() { + const sessionId = req.headers['mcp-session-id'] as string | undefined + if (!sessionId || !transports[sessionId]) { + httpSetTextResponse(res, 400, 'Invalid or missing session ID') + return + } + + const transport = transports[sessionId] + await transport.handleRequest(req, res) + } + + // Handle different request methods and paths + if (match('POST', '/mcp')) { + // Check for existing session ID + const sessionId = req.headers['mcp-session-id'] as string | undefined + let transport: StreamableHTTPServerTransport + + const body = await httpGetBody(req) + + if (sessionId && transports[sessionId]) { + // Reuse existing transport + transport = transports[sessionId] + } else if (!sessionId && isInitializeRequest(body)) { + // New initialization request + transport = new StreamableHTTPServerTransport({ + sessionIdGenerator: () => randomUUID(), + onsessioninitialized: 
(sessionId) => { + // Store the transport by session ID + transports[sessionId] = transport + }, + }) + + // Clean up transport when closed + transport.onclose = () => { + if (transport.sessionId) { + delete transports[transport.sessionId] + } + } + + await mcpServer.connect(transport) + } else { + httpSetJsonResponse(res, 400, 'Bad Request: No valid session ID provided', -32000) + return + } + + // Handle the request + await transport.handleRequest(req, res, body) + } else if (match('GET', '/mcp')) { + // Handle server-to-client notifications via SSE + await handleSessionRequest() + } else if (match('DELETE', '/mcp')) { + // Handle requests for session termination + await handleSessionRequest() + } else if (pathMatch) { + httpSetTextResponse(res, 405, 'Method not allowed') + } else { + httpSetTextResponse(res, 404, 'Page not found') + } + }) + + server.listen(port, () => { + console.log( + `Running MCP Run Python version ${VERSION} with Streamable HTTP transport on port ${port}`, + ) + }) +} + /* * Run the MCP server using the SSE transport, e.g. over HTTP. */ @@ -111,10 +254,7 @@ function runSse(port: number) { const transports: { [sessionId: string]: SSEServerTransport } = {} const server = http.createServer(async (req, res) => { - const url = new URL( - req.url ?? '', - `http://${req.headers.host ?? 'unknown'}`, - ) + const url = httpGetUrl(req) let pathMatch = false function match(method: string, path: string): boolean { if (url.pathname === path) { @@ -123,12 +263,6 @@ function runSse(port: number) { } return false } - function textResponse(status: number, text: string) { - res.setHeader('Content-Type', 'text/plain') - res.statusCode = status - res.end(`${text}\n`) - } - // console.log(`${req.method} ${url}`) if (match('GET', '/sse')) { const transport = new SSEServerTransport('/messages', res) @@ -143,12 +277,12 @@ function runSse(port: number) { if (transport) { await transport.handlePostMessage(req, res) } else { - textResponse(400, `No transport found for sessionId '${sessionId}'`) + httpSetTextResponse(res, 400, `No transport found for sessionId '${sessionId}'`) } } else if (pathMatch) { - textResponse(405, 'Method not allowed') + httpSetTextResponse(res, 405, 'Method not allowed') } else { - textResponse(404, 'Page not found') + httpSetTextResponse(res, 404, 'Page not found') } }) diff --git a/mcp-run-python/test_mcp_servers.py b/mcp-run-python/test_mcp_servers.py index 23ca98380e..3fd72927f1 100644 --- a/mcp-run-python/test_mcp_servers.py +++ b/mcp-run-python/test_mcp_servers.py @@ -13,6 +13,7 @@ from mcp import ClientSession, StdioServerParameters, types from mcp.client.sse import sse_client from mcp.client.stdio import stdio_client +from mcp.client.streamable_http import streamablehttp_client if TYPE_CHECKING: from mcp import ClientSession @@ -33,13 +34,38 @@ def anyio_backend(): return 'asyncio' -@pytest.fixture(name='mcp_session', params=['stdio', 'sse']) +@pytest.fixture(name='mcp_session', params=['stdio', 'sse', 'streamable_http']) async def fixture_mcp_session(request: pytest.FixtureRequest) -> AsyncIterator[ClientSession]: if request.param == 'stdio': server_params = StdioServerParameters(command='deno', args=[*DENO_ARGS, 'stdio']) async with stdio_client(server_params) as (read, write): async with ClientSession(read, write) as session: yield session + elif request.param == 'streamable_http': + port = 3101 + p = subprocess.Popen(['deno', *DENO_ARGS, 'streamable_http', f'--port={port}']) + try: + url = f'http://localhost:{port}/mcp' + + async with AsyncClient() as 
client: + for _ in range(10): + try: + await client.get(url, timeout=0.01) + except HTTPError: + await asyncio.sleep(0.1) + else: + break + + async with streamablehttp_client(url) as (read_stream, write_stream, _): + async with ClientSession(read_stream, write_stream) as session: + yield session + + finally: + p.terminate() + exit_code = p.wait() + if exit_code > 0: + pytest.fail(f'Process exited with code {exit_code}') + else: port = 3101 From 091c499d2febe8aa49cfc3c4c77c8efa7266ea7a Mon Sep 17 00:00:00 2001 From: Samuel Colvin Date: Mon, 21 Jul 2025 12:17:58 -0700 Subject: [PATCH 42/89] change `format_as_xml` defaults (#2228) --- pydantic_ai_slim/pydantic_ai/format_prompt.py | 9 ++-- tests/test_examples.py | 2 +- tests/test_format_as_xml.py | 48 +++++++++---------- 3 files changed, 27 insertions(+), 32 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/format_prompt.py b/pydantic_ai_slim/pydantic_ai/format_prompt.py index 311f4277d9..34f06a40a8 100644 --- a/pydantic_ai_slim/pydantic_ai/format_prompt.py +++ b/pydantic_ai_slim/pydantic_ai/format_prompt.py @@ -13,9 +13,8 @@ def format_as_xml( obj: Any, - root_tag: str = 'examples', - item_tag: str = 'example', - include_root_tag: bool = True, + root_tag: str | None = None, + item_tag: str = 'item', none_str: str = 'null', indent: str | None = ' ', ) -> str: @@ -32,8 +31,6 @@ def format_as_xml( root_tag: Outer tag to wrap the XML in, use `None` to omit the outer tag. item_tag: Tag to use for each item in an iterable (e.g. list), this is overridden by the class name for dataclasses and Pydantic models. - include_root_tag: Whether to include the root tag in the output - (The root tag is always included if it includes a body - e.g. when the input is a simple value). none_str: String to use for `None` values. indent: Indentation string to use for pretty printing. 
@@ -55,7 +52,7 @@ def format_as_xml( ``` """ el = _ToXml(item_tag=item_tag, none_str=none_str).to_xml(obj, root_tag) - if not include_root_tag and el.text is None: + if root_tag is None and el.text is None: join = '' if indent is None else '\n' return join.join(_rootless_xml_elements(el, indent)) else: diff --git a/tests/test_examples.py b/tests/test_examples.py index 0fbe64bc3a..4b6bc27bcd 100644 --- a/tests/test_examples.py +++ b/tests/test_examples.py @@ -502,7 +502,7 @@ async def model_logic( # noqa: C901 ) elif m.content.startswith('Write a list of 5 very rude things that I might say'): raise UnexpectedModelBehavior('Safety settings triggered', body='') - elif m.content.startswith('\n '): + elif m.content.startswith('\n John Doe'): return ModelResponse( parts=[ToolCallPart(tool_name='final_result_EmailOk', args={}, tool_call_id='pyd_ai_tool_call_id')] ) diff --git a/tests/test_format_as_xml.py b/tests/test_format_as_xml.py index 0781164aab..37053a67f9 100644 --- a/tests/test_format_as_xml.py +++ b/tests/test_format_as_xml.py @@ -123,35 +123,35 @@ class ExamplePydanticModel(BaseModel): ), ], ) -def test(input_obj: Any, output: str): - assert format_as_xml(input_obj) == output +def test_root_tag(input_obj: Any, output: str): + assert format_as_xml(input_obj, root_tag='examples', item_tag='example') == output @pytest.mark.parametrize( 'input_obj,output', [ - pytest.param('a string', snapshot('a string'), id='string'), - pytest.param('a foo', snapshot('a <ex>foo</ex>'), id='string'), - pytest.param(42, snapshot('42'), id='int'), + pytest.param('a string', snapshot('a string'), id='string'), + pytest.param('a foo', snapshot('a <ex>foo</ex>'), id='string'), + pytest.param(42, snapshot('42'), id='int'), pytest.param( [1, 2, 3], snapshot("""\ -1 -2 -3\ +1 +2 +3\ """), id='list[int]', ), pytest.param( [[1, 2], [3]], snapshot("""\ - - 1 - 2 - - - 3 -\ + + 1 + 2 + + + 3 +\ """), id='list[list[int]]', ), @@ -166,24 +166,22 @@ def test(input_obj: Any, output: str): pytest.param( [datetime(2025, 1, 1, 12, 13), date(2025, 1, 2)], snapshot("""\ -2025-01-01T12:13:00 -2025-01-02\ +2025-01-01T12:13:00 +2025-01-02\ """), id='list[date]', ), ], ) def test_no_root(input_obj: Any, output: str): - assert format_as_xml(input_obj, include_root_tag=False) == output + assert format_as_xml(input_obj) == output def test_no_indent(): - assert format_as_xml([1, 2, 3], indent=None) == snapshot( - '123' - ) - assert format_as_xml([1, 2, 3], indent=None, include_root_tag=False) == snapshot( - '123' + assert format_as_xml([1, 2, 3], indent=None, root_tag='example') == snapshot( + '123' ) + assert format_as_xml([1, 2, 3], indent=None) == snapshot('123') def test_invalid_value(): @@ -197,8 +195,8 @@ def test_invalid_key(): def test_set(): - assert '1' in format_as_xml({1, 2, 3}) + assert '1' in format_as_xml({1, 2, 3}, item_tag='example') def test_custom_null(): - assert format_as_xml(None, none_str='nil') == snapshot('nil') + assert format_as_xml(None, none_str='nil') == snapshot('nil') From 87871b3f73411204a34e50feadb4b2bcc38414fd Mon Sep 17 00:00:00 2001 From: Aditya Vardhan <76904033+adtyavrdhn@users.noreply.github.com> Date: Tue, 22 Jul 2025 04:22:46 +0530 Subject: [PATCH 43/89] Fix LLMJudge input handling to preserve BinaryContent as separate message part instead of stringifying (#2173) Co-authored-by: Douwe Maan --- .../evaluators/llm_as_a_judge.py | 97 ++++----- tests/evals/test_llm_as_a_judge.py | 196 +++++++++++++++++- 2 files changed, 234 insertions(+), 59 deletions(-) diff --git 
a/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py b/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py index 99eee601cb..0f1b0ac101 100644 --- a/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py +++ b/pydantic_evals/pydantic_evals/evaluators/llm_as_a_judge.py @@ -1,5 +1,6 @@ from __future__ import annotations +from collections.abc import Sequence from textwrap import dedent from typing import Any @@ -7,6 +8,7 @@ from pydantic_core import to_json from pydantic_ai import Agent, models +from pydantic_ai.messages import MultiModalContentTypes, UserContent from pydantic_ai.settings import ModelSettings __all__ = ( @@ -62,16 +64,7 @@ async def judge_output( If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. """ - user_prompt = dedent( - f""" - - {_stringify(output)} - - - {rubric} - - """ - ) + user_prompt = _build_prompt(output=output, rubric=rubric) return ( await _judge_output_agent.run(user_prompt, model=model or _default_model, model_settings=model_settings) ).output @@ -112,19 +105,8 @@ async def judge_input_output( If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. """ - user_prompt = dedent( - f""" - - - {_stringify(output)} - - - {rubric} - - """ - ) + user_prompt = _build_prompt(inputs=inputs, output=output, rubric=rubric) + return ( await _judge_input_output_agent.run(user_prompt, model=model or _default_model, model_settings=model_settings) ).output @@ -168,22 +150,7 @@ async def judge_input_output_expected( If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. """ - user_prompt = dedent( - f""" - - - {_stringify(expected_output)} - - - {_stringify(output)} - - - {rubric} - - """ - ) + user_prompt = _build_prompt(inputs=inputs, output=output, rubric=rubric, expected_output=expected_output) return ( await _judge_input_output_expected_agent.run( @@ -227,19 +194,7 @@ async def judge_output_expected( If the model is not specified, a default model is used. The default model starts as 'openai:gpt-4o', but this can be changed using the `set_default_judge_model` function. 
""" - user_prompt = dedent( - f""" - - {_stringify(expected_output)} - - - {_stringify(output)} - - - {rubric} - - """ - ) + user_prompt = _build_prompt(output=output, rubric=rubric, expected_output=expected_output) return ( await _judge_output_expected_agent.run( user_prompt, model=model or _default_model, model_settings=model_settings @@ -265,3 +220,41 @@ def _stringify(value: Any) -> str: return to_json(value).decode() except Exception: return repr(value) + + +def _build_prompt( + output: Any, + rubric: str, + inputs: Any | None = None, + expected_output: Any | None = None, +) -> str | Sequence[str | UserContent]: + """Build a prompt that includes input, output, and rubric.""" + sections: list[str | UserContent] = [] + + if inputs is not None: + if isinstance(inputs, str): + sections.append(f'') + else: + sections.append('') + + sections.append(f'\n{_stringify(output)}\n') + sections.append(f'\n{rubric}\n') + + if expected_output is not None: + sections.append(f'\n{_stringify(expected_output)}\n') + + if inputs is None or isinstance(inputs, str): + return '\n\n'.join(sections) # type: ignore[arg-type] + else: + return sections diff --git a/tests/evals/test_llm_as_a_judge.py b/tests/evals/test_llm_as_a_judge.py index 4e18c5b13d..404c1f81a8 100644 --- a/tests/evals/test_llm_as_a_judge.py +++ b/tests/evals/test_llm_as_a_judge.py @@ -1,9 +1,10 @@ from __future__ import annotations as _annotations import pytest +from inline_snapshot import snapshot from pytest_mock import MockerFixture -from ..conftest import try_import +from ..conftest import BinaryContent, try_import with try_import() as imports_successful: from pydantic_ai.settings import ModelSettings @@ -141,6 +142,54 @@ async def test_judge_input_output_mock(mocker: MockerFixture): assert '\nOutput contains input\n' in call_args[0] +async def test_judge_input_output_binary_content_list_mock(mocker: MockerFixture, image_content: BinaryContent): + """Test judge_input_output function with mocked agent.""" + # Mock the agent run method + mock_result = mocker.MagicMock() + mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) + mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + + result = await judge_input_output([image_content, image_content], 'Hello world', 'Output contains input') + assert isinstance(result, GradingOutput) + assert result.reason == 'Test passed' + assert result.pass_ is True + assert result.score == 1.0 + + # Verify the agent was called with correct prompt + mock_run.assert_called_once() + raw_prompt = mock_run.call_args[0][0] + + # 1) It must be a list + assert isinstance(raw_prompt, list), 'Expected prompt to be a list when passing binary' + + # 2) The BinaryContent you passed in should be one of the elements + assert image_content in raw_prompt, 'Expected the exact BinaryContent instance to be in the prompt list' + + +async def test_judge_input_output_binary_content_mock(mocker: MockerFixture, image_content: BinaryContent): + """Test judge_input_output function with mocked agent.""" + # Mock the agent run method + mock_result = mocker.MagicMock() + mock_result.output = GradingOutput(reason='Test passed', pass_=True, score=1.0) + mock_run = mocker.patch('pydantic_ai.Agent.run', return_value=mock_result) + + result = await judge_input_output(image_content, 'Hello world', 'Output contains input') + assert isinstance(result, GradingOutput) + assert result.reason == 'Test passed' + assert result.pass_ is True + assert result.score == 1.0 + + # Verify the agent was called 
with correct prompt + mock_run.assert_called_once() + raw_prompt = mock_run.call_args[0][0] + + # 1) It must be a list + assert isinstance(raw_prompt, list), 'Expected prompt to be a list when passing binary' + + # 2) The BinaryContent you passed in should be one of the elements + assert image_content in raw_prompt, 'Expected the exact BinaryContent instance to be in the prompt list' + + @pytest.mark.anyio async def test_judge_input_output_with_model_settings_mock(mocker: MockerFixture): """Test judge_input_output function with model_settings and mocked agent.""" @@ -172,7 +221,7 @@ async def test_judge_input_output_with_model_settings_mock(mocker: MockerFixture @pytest.mark.anyio -async def test_judge_input_output_expected_mock(mocker: MockerFixture): +async def test_judge_input_output_expected_mock(mocker: MockerFixture, image_content: BinaryContent): """Test judge_input_output_expected function with mocked agent.""" # Mock the agent run method mock_result = mocker.MagicMock() @@ -187,16 +236,29 @@ async def test_judge_input_output_expected_mock(mocker: MockerFixture): assert result.score == 1.0 # Verify the agent was called with correct prompt - mock_run.assert_called_once() call_args = mock_run.call_args[0] assert '' in call_args[0] assert '\nHello\n' in call_args[0] assert '\nHello world\n' in call_args[0] assert '\nOutput contains input\n' in call_args[0] + result = await judge_input_output_expected(image_content, 'Hello world', 'Hello', 'Output contains input') + assert isinstance(result, GradingOutput) + assert result.reason == 'Test passed' + assert result.pass_ is True + assert result.score == 1.0 + + call_args = mock_run.call_args[0] + assert image_content in call_args[0] + assert '\nHello\n' in call_args[0] + assert '\nHello world\n' in call_args[0] + assert '\nOutput contains input\n' in call_args[0] + @pytest.mark.anyio -async def test_judge_input_output_expected_with_model_settings_mock(mocker: MockerFixture): +async def test_judge_input_output_expected_with_model_settings_mock( + mocker: MockerFixture, image_content: BinaryContent +): """Test judge_input_output_expected function with model_settings and mocked agent.""" mock_result = mocker.MagicMock() mock_result.output = GradingOutput(reason='Test passed with settings', pass_=True, score=1.0) @@ -216,7 +278,6 @@ async def test_judge_input_output_expected_with_model_settings_mock(mocker: Mock assert result.pass_ is True assert result.score == 1.0 - mock_run.assert_called_once() call_args, call_kwargs = mock_run.call_args assert '' in call_args[0] assert '\nHello\n' in call_args[0] @@ -226,6 +287,108 @@ async def test_judge_input_output_expected_with_model_settings_mock(mocker: Mock # Check if 'model' kwarg is passed, its value will be the default model or None assert 'model' in call_kwargs + result = await judge_input_output_expected( + image_content, + 'Hello world with settings', + 'Hello', + 'Output contains input with settings', + model_settings=test_model_settings, + ) + + assert isinstance(result, GradingOutput) + assert result.reason == 'Test passed with settings' + assert result.pass_ is True + assert result.score == 1.0 + + call_args, call_kwargs = mock_run.call_args + assert image_content in call_args[0] + assert '\nHello\n' in call_args[0] + assert '\nHello world with settings\n' in call_args[0] + assert '\nOutput contains input with settings\n' in call_args[0] + assert call_kwargs['model_settings'] == test_model_settings + # Check if 'model' kwarg is passed, its value will be the default model or None + assert 
'model' in call_kwargs + + result = await judge_input_output_expected( + 123, + 'Hello world with settings', + 'Hello', + 'Output contains input with settings', + model_settings=test_model_settings, + ) + + assert isinstance(result, GradingOutput) + assert result.reason == 'Test passed with settings' + assert result.pass_ is True + assert result.score == 1.0 + + call_args, call_kwargs = mock_run.call_args + + assert call_args == snapshot( + ( + [ + '', + """\ + +Hello world with settings +\ +""", + """\ + +Output contains input with settings +\ +""", + """\ + +Hello +\ +""", + ], + ) + ) + + result = await judge_input_output_expected( + [123], + 'Hello world with settings', + 'Hello', + 'Output contains input with settings', + model_settings=test_model_settings, + ) + + assert isinstance(result, GradingOutput) + assert result.reason == 'Test passed with settings' + assert result.pass_ is True + assert result.score == 1.0 + + call_args, call_kwargs = mock_run.call_args + + assert call_args == snapshot( + ( + [ + '', + """\ + +Hello world with settings +\ +""", + """\ + +Output contains input with settings +\ +""", + """\ + +Hello +\ +""", + ], + ) + ) + @pytest.mark.anyio async def test_judge_output_expected_mock(mocker: MockerFixture): @@ -243,7 +406,6 @@ async def test_judge_output_expected_mock(mocker: MockerFixture): assert result.score == 1.0 # Verify the agent was called with correct prompt - mock_run.assert_called_once() call_args = mock_run.call_args[0] assert '
(Union)") --- ModelRequest ModelRequest("ModelRequest(parts=list[...])") --- ModelMessage ModelResponsePart("ModelResponsePart
(Union)") --- ModelResponse From 94b43053ed648e98e1be78abdb73c82ae9d4423f Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 24 Jul 2025 14:54:14 -0600 Subject: [PATCH 69/89] Fix docs build failure by adding MoonshotAIProvider to API docs (#2304) --- docs/api/providers.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/docs/api/providers.md b/docs/api/providers.md index 024e2f6c20..38cc966fd5 100644 --- a/docs/api/providers.md +++ b/docs/api/providers.md @@ -35,3 +35,5 @@ ::: pydantic_ai.providers.vercel.VercelProvider ::: pydantic_ai.providers.huggingface.HuggingFaceProvider + +::: pydantic_ai.providers.moonshotai.MoonshotAIProvider From 4104acaa19e2aac676385883bddc6a9c087fff8e Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 24 Jul 2025 15:08:39 -0600 Subject: [PATCH 70/89] Fix initial tool call args not being streamed with AG-UI (#2303) --- pydantic_ai_slim/pydantic_ai/ag_ui.py | 13 ++++++++++--- tests/test_ag_ui.py | 7 ++++--- 2 files changed, 14 insertions(+), 6 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index c91c0bef8f..2dbe6faf3a 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -380,14 +380,14 @@ async def _handle_model_request_event( yield TextMessageStartEvent( message_id=message_id, ) - stream_ctx.part_end = TextMessageEndEvent( - message_id=message_id, - ) if part.content: # pragma: no branch yield TextMessageContentEvent( message_id=message_id, delta=part.content, ) + stream_ctx.part_end = TextMessageEndEvent( + message_id=message_id, + ) elif isinstance(part, ToolCallPart): # pragma: no branch message_id = stream_ctx.message_id or stream_ctx.new_message_id() yield ToolCallStartEvent( @@ -395,6 +395,11 @@ async def _handle_model_request_event( tool_call_name=part.tool_name, parent_message_id=message_id, ) + if part.args: + yield ToolCallArgsEvent( + tool_call_id=part.tool_call_id, + delta=part.args if isinstance(part.args, str) else json.dumps(part.args), + ) stream_ctx.part_end = ToolCallEndEvent( tool_call_id=part.tool_call_id, ) @@ -403,6 +408,8 @@ async def _handle_model_request_event( yield ThinkingTextMessageStartEvent( type=EventType.THINKING_TEXT_MESSAGE_START, ) + # Always send the content even if it's empty, as it may be + # used to indicate the start of thinking. 
yield ThinkingTextMessageContentEvent( type=EventType.THINKING_TEXT_MESSAGE_CONTENT, delta=part.content, diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 80324521a5..423011b76d 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -303,8 +303,8 @@ async def stream_function( ) -> AsyncIterator[DeltaToolCalls | str]: if len(messages) == 1: # First call - make a tool call - yield {0: DeltaToolCall(name='get_weather')} - yield {0: DeltaToolCall(json_args='{"location": "Paris"}')} + yield {0: DeltaToolCall(name='get_weather', json_args='{"location": ')} + yield {0: DeltaToolCall(json_args='"Paris"}')} else: # Second call - return text result yield '{"get_weather": "Tool result"}' @@ -369,8 +369,9 @@ async def stream_function( { 'type': 'TOOL_CALL_ARGS', 'toolCallId': tool_call_id, - 'delta': '{"location": "Paris"}', + 'delta': '{"location": ', }, + {'type': 'TOOL_CALL_ARGS', 'toolCallId': tool_call_id, 'delta': '"Paris"}'}, {'type': 'TOOL_CALL_END', 'toolCallId': tool_call_id}, { 'type': 'RUN_FINISHED', From 41dd069aff984cd0560baf8335e8511d0d0392c3 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Thu, 24 Jul 2025 15:16:00 -0600 Subject: [PATCH 71/89] Ignore leading whitespace when streaming text, fixing run_stream + Ollama + Qwen3 (#2294) --- .../pydantic_ai/_parts_manager.py | 29 ++++++----------- pydantic_ai_slim/pydantic_ai/messages.py | 2 +- .../pydantic_ai/models/anthropic.py | 14 ++++++-- .../pydantic_ai/models/bedrock.py | 6 ++-- .../pydantic_ai/models/function.py | 6 ++-- pydantic_ai_slim/pydantic_ai/models/gemini.py | 6 +++- pydantic_ai_slim/pydantic_ai/models/google.py | 4 ++- .../pydantic_ai/models/mistral.py | 4 ++- pydantic_ai_slim/pydantic_ai/models/openai.py | 6 +++- pydantic_ai_slim/pydantic_ai/models/test.py | 8 +++-- tests/models/test_bedrock.py | 8 +---- tests/models/test_groq.py | 19 ++++++----- tests/models/test_huggingface.py | 19 ++++++----- tests/models/test_instrumented.py | 8 +++-- tests/models/test_openai.py | 32 +++++++++++++++++++ tests/test_a2a.py | 18 ++++------- 16 files changed, 116 insertions(+), 73 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_parts_manager.py b/pydantic_ai_slim/pydantic_ai/_parts_manager.py index 299a4297d2..223e2de3b5 100644 --- a/pydantic_ai_slim/pydantic_ai/_parts_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_parts_manager.py @@ -15,7 +15,7 @@ from collections.abc import Hashable from dataclasses import dataclass, field, replace -from typing import Any, Literal, Union, overload +from typing import Any, Union from pydantic_ai._thinking_part import END_THINK_TAG, START_THINK_TAG from pydantic_ai.exceptions import UnexpectedModelBehavior @@ -67,23 +67,6 @@ def get_parts(self) -> list[ModelResponsePart]: """ return [p for p in self._parts if not isinstance(p, ToolCallPartDelta)] - @overload - def handle_text_delta( - self, - *, - vendor_part_id: VendorId | None, - content: str, - ) -> ModelResponseStreamEvent: ... - - @overload - def handle_text_delta( - self, - *, - vendor_part_id: VendorId, - content: str, - extract_think_tags: Literal[True], - ) -> ModelResponseStreamEvent | None: ... - def handle_text_delta( self, *, @@ -105,7 +88,9 @@ def handle_text_delta( extract_think_tags: Whether to extract `` tags from the text content and handle them as thinking parts. Returns: - A `PartStartEvent` if a new part was created, or a `PartDeltaEvent` if an existing part was updated. + - A `PartStartEvent` if a new part was created. + - A `PartDeltaEvent` if an existing part was updated. 
+ - `None` if no new event is emitted (e.g., the first text part was all whitespace). Raises: UnexpectedModelBehavior: If attempting to apply text content to a part that is not a TextPart. @@ -144,6 +129,12 @@ def handle_text_delta( return self.handle_thinking_delta(vendor_part_id=vendor_part_id, content='') if existing_text_part_and_index is None: + # If the first text delta is all whitespace, don't emit a new part yet. + # This is a workaround for models that emit `\n\n\n` ahead of tool calls (e.g. Ollama + Qwen3), + # which we don't want to end up treating as a final result. + if content.isspace(): + return None + # There is no existing text part that should be updated, so create a new one new_part_index = len(self._parts) part = TextPart(content=content) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 4f98d995a3..379d70efd7 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -690,7 +690,7 @@ class ThinkingPart: def has_content(self) -> bool: """Return `True` if the thinking content is non-empty.""" - return bool(self.content) # pragma: no cover + return bool(self.content) __repr__ = _utils.dataclasses_no_defaults_repr diff --git a/pydantic_ai_slim/pydantic_ai/models/anthropic.py b/pydantic_ai_slim/pydantic_ai/models/anthropic.py index a627415689..02f9111c2d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/anthropic.py +++ b/pydantic_ai_slim/pydantic_ai/models/anthropic.py @@ -470,7 +470,7 @@ class AnthropicStreamedResponse(StreamedResponse): _response: AsyncIterable[BetaRawMessageStreamEvent] _timestamp: datetime - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 current_block: BetaContentBlock | None = None async for event in self._response: @@ -479,7 +479,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if isinstance(event, BetaRawContentBlockStartEvent): current_block = event.content_block if isinstance(current_block, BetaTextBlock) and current_block.text: - yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=current_block.text) + maybe_event = self._parts_manager.handle_text_delta( + vendor_part_id='content', content=current_block.text + ) + if maybe_event is not None: # pragma: no branch + yield maybe_event elif isinstance(current_block, BetaThinkingBlock): yield self._parts_manager.handle_thinking_delta( vendor_part_id='thinking', @@ -498,7 +502,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: elif isinstance(event, BetaRawContentBlockDeltaEvent): if isinstance(event.delta, BetaTextDelta): - yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=event.delta.text) + maybe_event = self._parts_manager.handle_text_delta( + vendor_part_id='content', content=event.delta.text + ) + if maybe_event is not None: # pragma: no branch + yield maybe_event elif isinstance(event.delta, BetaThinkingDelta): yield self._parts_manager.handle_thinking_delta( vendor_part_id='thinking', content=event.delta.thinking diff --git a/pydantic_ai_slim/pydantic_ai/models/bedrock.py b/pydantic_ai_slim/pydantic_ai/models/bedrock.py index f16f9d1119..b63ed4e1f9 100644 --- a/pydantic_ai_slim/pydantic_ai/models/bedrock.py +++ b/pydantic_ai_slim/pydantic_ai/models/bedrock.py @@ -572,7 +572,7 @@ class BedrockStreamedResponse(StreamedResponse): _event_stream: 
EventStream[ConverseStreamOutputTypeDef] _timestamp: datetime = field(default_factory=_utils.now_utc) - async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: + async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: # noqa: C901 """Return an async iterator of [`ModelResponseStreamEvent`][pydantic_ai.messages.ModelResponseStreamEvent]s. This method should be implemented by subclasses to translate the vendor-specific stream of events into @@ -618,7 +618,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: UserWarning, ) if 'text' in delta: - yield self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text']) + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=index, content=delta['text']) + if maybe_event is not None: + yield maybe_event if 'toolUse' in delta: tool_use = delta['toolUse'] maybe_event = self._parts_manager.handle_tool_call_delta( diff --git a/pydantic_ai_slim/pydantic_ai/models/function.py b/pydantic_ai_slim/pydantic_ai/models/function.py index 4efc6d7109..c48873f046 100644 --- a/pydantic_ai_slim/pydantic_ai/models/function.py +++ b/pydantic_ai_slim/pydantic_ai/models/function.py @@ -264,7 +264,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if isinstance(item, str): response_tokens = _estimate_string_tokens(item) self._usage += usage.Usage(response_tokens=response_tokens, total_tokens=response_tokens) - yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=item) + if maybe_event is not None: # pragma: no branch + yield maybe_event elif isinstance(item, dict) and item: for dtc_index, delta in item.items(): if isinstance(delta, DeltaThinkingPart): @@ -286,7 +288,7 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: args=delta.json_args, tool_call_id=delta.tool_call_id, ) - if maybe_event is not None: + if maybe_event is not None: # pragma: no branch yield maybe_event else: assert_never(delta) diff --git a/pydantic_ai_slim/pydantic_ai/models/gemini.py b/pydantic_ai_slim/pydantic_ai/models/gemini.py index 7c371a9439..4ac07f8ada 100644 --- a/pydantic_ai_slim/pydantic_ai/models/gemini.py +++ b/pydantic_ai_slim/pydantic_ai/models/gemini.py @@ -438,7 +438,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if 'text' in gemini_part: # Using vendor_part_id=None means we can produce multiple text parts if their deltas are sprinkled # amongst the tool call deltas - yield self._parts_manager.handle_text_delta(vendor_part_id=None, content=gemini_part['text']) + maybe_event = self._parts_manager.handle_text_delta( + vendor_part_id=None, content=gemini_part['text'] + ) + if maybe_event is not None: # pragma: no branch + yield maybe_event elif 'function_call' in gemini_part: # Here, we assume all function_call parts are complete and don't have deltas. 
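To make the new behavior concrete, here is a rough, self-contained sketch (a simplified stand-in, not the actual `ModelResponsePartsManager` code) of the rule the `handle_text_delta` change above introduces, together with the `None` guard the model adapters now need:

```python
def classify_text_delta(has_text_part: bool, content: str) -> str | None:
    """Simplified stand-in for handle_text_delta's first-delta handling."""
    if not has_text_part:
        if content.isspace():
            # A leading whitespace-only delta (e.g. '\n\n' from Ollama + Qwen3 before a
            # tool call) no longer starts a text part, so no event is emitted.
            return None
        return 'PartStartEvent'
    return 'PartDeltaEvent'


# Adapters now guard against None instead of yielding unconditionally:
deltas = [(False, '\n\n'), (False, 'Hello'), (True, ' world')]
for has_part, delta in deltas:
    maybe_event = classify_text_delta(has_part, delta)
    if maybe_event is not None:
        print(maybe_event)
#> PartStartEvent
#> PartDeltaEvent
```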
diff --git a/pydantic_ai_slim/pydantic_ai/models/google.py b/pydantic_ai_slim/pydantic_ai/models/google.py index ad5da243c1..082f5ba566 100644 --- a/pydantic_ai_slim/pydantic_ai/models/google.py +++ b/pydantic_ai_slim/pydantic_ai/models/google.py @@ -458,7 +458,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: if part.thought: yield self._parts_manager.handle_thinking_delta(vendor_part_id='thinking', content=part.text) else: - yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text) + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=part.text) + if maybe_event is not None: # pragma: no branch + yield maybe_event elif part.function_call: maybe_event = self._parts_manager.handle_tool_call_delta( vendor_part_id=uuid4(), diff --git a/pydantic_ai_slim/pydantic_ai/models/mistral.py b/pydantic_ai_slim/pydantic_ai/models/mistral.py index 0104e2055e..ca73558bca 100644 --- a/pydantic_ai_slim/pydantic_ai/models/mistral.py +++ b/pydantic_ai_slim/pydantic_ai/models/mistral.py @@ -601,7 +601,9 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: tool_call_id=maybe_tool_call_part.tool_call_id, ) else: - yield self._parts_manager.handle_text_delta(vendor_part_id='content', content=text) + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id='content', content=text) + if maybe_event is not None: # pragma: no branch + yield maybe_event # Handle the explicit tool calls for index, dtc in enumerate(choice.delta.tool_calls or []): diff --git a/pydantic_ai_slim/pydantic_ai/models/openai.py b/pydantic_ai_slim/pydantic_ai/models/openai.py index 1881862cd9..35dca2e03d 100644 --- a/pydantic_ai_slim/pydantic_ai/models/openai.py +++ b/pydantic_ai_slim/pydantic_ai/models/openai.py @@ -1144,7 +1144,11 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: ) elif isinstance(chunk, responses.ResponseTextDeltaEvent): - yield self._parts_manager.handle_text_delta(vendor_part_id=chunk.content_index, content=chunk.delta) + maybe_event = self._parts_manager.handle_text_delta( + vendor_part_id=chunk.content_index, content=chunk.delta + ) + if maybe_event is not None: # pragma: no branch + yield maybe_event elif isinstance(chunk, responses.ResponseTextDoneEvent): pass # there's nothing we need to do here diff --git a/pydantic_ai_slim/pydantic_ai/models/test.py b/pydantic_ai_slim/pydantic_ai/models/test.py index a80d551ff1..eebe00d440 100644 --- a/pydantic_ai_slim/pydantic_ai/models/test.py +++ b/pydantic_ai_slim/pydantic_ai/models/test.py @@ -269,10 +269,14 @@ async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]: mid = len(text) // 2 words = [text[:mid], text[mid:]] self._usage += _get_string_usage('') - yield self._parts_manager.handle_text_delta(vendor_part_id=i, content='') + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content='') + if maybe_event is not None: # pragma: no branch + yield maybe_event for word in words: self._usage += _get_string_usage(word) - yield self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) + maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=i, content=word) + if maybe_event is not None: # pragma: no branch + yield maybe_event elif isinstance(part, ToolCallPart): yield self._parts_manager.handle_tool_call_part( vendor_part_id=i, tool_name=part.tool_name, args=part.args, tool_call_id=part.tool_call_id diff --git a/tests/models/test_bedrock.py 
b/tests/models/test_bedrock.py index 228f8a9f0f..fad3530758 100644 --- a/tests/models/test_bedrock.py +++ b/tests/models/test_bedrock.py @@ -1127,15 +1127,9 @@ async def test_bedrock_model_thinking_part_stream(allow_model_requests: None, be PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' all together.\n')), PartStartEvent( index=1, - part=TextPart( - content="""\ - - -""" - ), + part=TextPart(content='Crossing the'), ), FinalResultEvent(tool_name=None, tool_call_id=None), - PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='Crossing the')), PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' street safely involves')), PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' careful')), PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' observation')), diff --git a/tests/models/test_groq.py b/tests/models/test_groq.py index 152f4358b0..4c5de38ee7 100644 --- a/tests/models/test_groq.py +++ b/tests/models/test_groq.py @@ -934,17 +934,16 @@ async def test_groq_model_thinking_part_iter(allow_model_requests: None, groq_ap 7: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' come')), 8: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' up')), 9: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' with')), - 589: PartStartEvent(index=1, part=TextPart(content='\n\n')), + 589: PartStartEvent(index=1, part=TextPart(content='**')), 590: FinalResultEvent(tool_name=None, tool_call_id=None), - 591: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='**')), - 592: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='Ur')), - 593: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='ugu')), - 594: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='ayan')), - 595: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' Alf')), - 596: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='aj')), - 597: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='ores')), - 598: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' Recipe')), + 591: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='Ur')), + 592: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='ugu')), + 593: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='ayan')), + 594: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' Alf')), + 595: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='aj')), + 596: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='ores')), + 597: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' Recipe')), }, - length=997, + length=996, ) ) diff --git a/tests/models/test_huggingface.py b/tests/models/test_huggingface.py index 1bdc60bee5..b7351d574c 100644 --- a/tests/models/test_huggingface.py +++ b/tests/models/test_huggingface.py @@ -1031,17 +1031,16 @@ async def test_hf_model_thinking_part_iter(allow_model_requests: None, huggingfa 5: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' user')), 6: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' is')), 7: PartDeltaEvent(index=0, delta=ThinkingPartDelta(content_delta=' asking')), - 413: PartStartEvent(index=1, part=TextPart(content='\n\n')), + 413: PartStartEvent(index=1, part=TextPart(content='Cross')), 414: FinalResultEvent(tool_name=None, tool_call_id=None), - 415: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='Cross')), - 416: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='ing')), - 417: PartDeltaEvent(index=1, 
delta=TextPartDelta(content_delta=' the')),
-            418: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' street')),
-            419: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' safely')),
-            420: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' requires')),
-            421: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' attent')),
-            422: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='iveness')),
+            415: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='ing')),
+            416: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' the')),
+            417: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' street')),
+            418: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' safely')),
+            419: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' requires')),
+            420: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta=' attent')),
+            421: PartDeltaEvent(index=1, delta=TextPartDelta(content_delta='iveness')),
         },
-        length=1060,
+        length=1059,
     )
 )
diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py
index b952bf7166..c926e1c3aa 100644
--- a/tests/models/test_instrumented.py
+++ b/tests/models/test_instrumented.py
@@ -96,8 +96,12 @@ async def request_stream(
     class MyResponseStream(StreamedResponse):
         async def _get_event_iterator(self) -> AsyncIterator[ModelResponseStreamEvent]:
             self._usage = Usage(request_tokens=300, response_tokens=400)
-            yield self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1')
-            yield self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2')
+            maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text1')
+            if maybe_event is not None:  # pragma: no branch
+                yield maybe_event
+            maybe_event = self._parts_manager.handle_text_delta(vendor_part_id=0, content='text2')
+            if maybe_event is not None:  # pragma: no branch
+                yield maybe_event
 
     @property
     def model_name(self) -> str:
diff --git a/tests/models/test_openai.py b/tests/models/test_openai.py
index 3adc67abe0..880c70eff3 100644
--- a/tests/models/test_openai.py
+++ b/tests/models/test_openai.py
@@ -611,6 +611,38 @@ async def test_stream_tool_call_with_empty_text(allow_model_requests: None):
     assert await result.get_output() == snapshot({'first': 'One', 'second': 'Two'})
 
 
+async def test_stream_text_empty_think_tag_and_text_before_tool_call(allow_model_requests: None):
+    # Ollama + Qwen3 will emit `<think>\n\n</think>\n\n` ahead of tool calls,
+    # which we don't want to end up treating as a final result.
+    stream = [
+        text_chunk('<think>'),
+        text_chunk('\n'),
+        text_chunk('</think>'),
+        text_chunk('\n\n'),
+        struc_chunk('final_result', None),
+        struc_chunk(None, '{"first": "One'),
+        struc_chunk(None, '", "second": "Two"'),
+        struc_chunk(None, '}'),
+        chunk([]),
+    ]
+    mock_client = MockOpenAI.create_mock_stream(stream)
+    m = OpenAIModel('gpt-4o', provider=OpenAIProvider(openai_client=mock_client))
+    agent = Agent(m, output_type=[str, MyTypedDict])
+
+    async with agent.run_stream('') as result:
+        assert not result.is_complete
+        assert [c async for c in result.stream(debounce_by=None)] == snapshot(
+            [
+                {},
+                {'first': 'One'},
+                {'first': 'One', 'second': 'Two'},
+                {'first': 'One', 'second': 'Two'},
+                {'first': 'One', 'second': 'Two'},
+            ]
+        )
+        assert await result.get_output() == snapshot({'first': 'One', 'second': 'Two'})
+
+
 async def test_no_content(allow_model_requests: None):
     stream = [chunk([ChoiceDelta()]), chunk([ChoiceDelta()])]
     mock_client = MockOpenAI.create_mock_stream(stream)
diff --git a/tests/test_a2a.py b/tests/test_a2a.py
index 535e3b1e91..c13d1da54d 100644
--- a/tests/test_a2a.py
+++ b/tests/test_a2a.py
@@ -593,10 +593,10 @@ def track_messages(messages: list[ModelMessage], info: AgentInfo) -> ModelRespon
     assert result2['context_id'] == context_id
 
     # Wait for second task to complete
-    await anyio.sleep(0.1)
-    task2 = await a2a_client.get_task(task2_id)
-    assert 'result' in task2
-    assert task2['result']['status']['state'] == 'completed'
+    while task2 := await a2a_client.get_task(task2_id):  # pragma: no branch
+        if 'result' in task2 and task2['result']['status']['state'] == 'completed':
+            break
+        await anyio.sleep(0.1)
 
     # Verify the model received the full history on the second call
     assert len(messages_received) == 2
@@ -800,14 +800,10 @@ async def test_a2a_multiple_messages():
         }
     )
 
-    task = None
-    tries = 0
-    while tries < 10:  # pragma: no branch
-        await anyio.sleep(0.1)
-        task = await a2a_client.get_task(task_id)
-        tries += 1
-        if 'result' in task and task['result']['status']['state'] == 'completed':  # pragma: no branch
+    while task := await a2a_client.get_task(task_id):  # pragma: no branch
+        if 'result' in task and task['result']['status']['state'] == 'completed':
             break
+        await anyio.sleep(0.1)
 
     assert task == snapshot(
         {

From ff0f6a9ccc7a5439b924e3db43db5bae36d75c75 Mon Sep 17 00:00:00 2001
From: Kevin Turcios
Date: Thu, 24 Jul 2025 14:47:20 -0700
Subject: [PATCH 72/89] refine changes

---
 pydantic_ai_slim/pydantic_ai/ext/langchain.py | 15 +++++----------
 1 file changed, 5 insertions(+), 10 deletions(-)

diff --git a/pydantic_ai_slim/pydantic_ai/ext/langchain.py b/pydantic_ai_slim/pydantic_ai/ext/langchain.py
index cd6f7b2f44..7f3faeafbe 100644
--- a/pydantic_ai_slim/pydantic_ai/ext/langchain.py
+++ b/pydantic_ai_slim/pydantic_ai/ext/langchain.py
@@ -38,9 +38,8 @@ def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
     function_name = langchain_tool.name
     function_description = langchain_tool.description
     inputs = langchain_tool.args.copy()
-    required = []
-    defaults = {}
-    # Only one iteration through inputs for both required and defaults
+    required: list[str] = []
+    defaults: dict[str, Any] = {}
     for name, detail in inputs.items():
         if 'default' in detail:
             defaults[name] = detail['default']
@@ -54,15 +53,11 @@ def tool_from_langchain(langchain_tool: LangChainTool) -> Tool:
     if 'additionalProperties' not in schema:
         schema['additionalProperties'] = False
 
-    # restructure arguments and merge efficiently using only the necessary step
+    # restructures the arguments to match langchain tool run
    def proxy(*args: Any, **kwargs: Any) -> str:
        assert not args, 'This should always be called with kwargs'
-        if defaults:
-            merged = defaults.copy()
-            merged.update(kwargs)
-            return langchain_tool.run(merged)
-        else:
-            return langchain_tool.run(kwargs)
+        kwargs = defaults | kwargs
+        return langchain_tool.run(kwargs)
 
     return Tool.from_schema(
         function=proxy,

From 5cf372ada6a401b64a2e6c30d94f2c41e9ddbce0 Mon Sep 17 00:00:00 2001
From: Christian Hartung
Date: Fri, 25 Jul 2025 14:25:29 -0300
Subject: [PATCH 73/89] fix: close initialized MCP server if any MCP server fails to initialize (#2312)

---
 pydantic_ai_slim/pydantic_ai/agent.py         |  8 ++--
 pydantic_ai_slim/pydantic_ai/mcp.py           | 37 +++++++++----------
 .../pydantic_ai/toolsets/combined.py          |  7 ++--
 tests/test_mcp.py                             | 14 +++++++
 tests/test_toolsets.py                        | 25 +++++++++++++
 5 files changed, 66 insertions(+), 25 deletions(-)

diff --git a/pydantic_ai_slim/pydantic_ai/agent.py b/pydantic_ai_slim/pydantic_ai/agent.py
index 9f6348c6c9..5f22d73294 100644
--- a/pydantic_ai_slim/pydantic_ai/agent.py
+++ b/pydantic_ai_slim/pydantic_ai/agent.py
@@ -1792,9 +1792,11 @@ async def __aenter__(self) -> Self:
         """
         async with self._enter_lock:
             if self._entered_count == 0:
-                self._exit_stack = AsyncExitStack()
-                toolset = self._get_toolset()
-                await self._exit_stack.enter_async_context(toolset)
+                async with AsyncExitStack() as exit_stack:
+                    toolset = self._get_toolset()
+                    await exit_stack.enter_async_context(toolset)
+
+                    self._exit_stack = exit_stack.pop_all()
             self._entered_count += 1
         return self
 
diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py
index c84f4b10bc..77d53f0800 100644
--- a/pydantic_ai_slim/pydantic_ai/mcp.py
+++ b/pydantic_ai_slim/pydantic_ai/mcp.py
@@ -201,25 +201,24 @@ async def __aenter__(self) -> Self:
         """
         async with self._enter_lock:
             if self._running_count == 0:
-                self._exit_stack = AsyncExitStack()
-
-                self._read_stream, self._write_stream = await self._exit_stack.enter_async_context(
-                    self.client_streams()
-                )
-                client = ClientSession(
-                    read_stream=self._read_stream,
-                    write_stream=self._write_stream,
-                    sampling_callback=self._sampling_callback if self.allow_sampling else None,
-                    logging_callback=self.log_handler,
-                    read_timeout_seconds=timedelta(seconds=self.read_timeout),
-                )
-                self._client = await self._exit_stack.enter_async_context(client)
-
-                with anyio.fail_after(self.timeout):
-                    await self._client.initialize()
-
-                if log_level := self.log_level:
-                    await self._client.set_logging_level(log_level)
+                async with AsyncExitStack() as exit_stack:
+                    self._read_stream, self._write_stream = await exit_stack.enter_async_context(self.client_streams())
+                    client = ClientSession(
+                        read_stream=self._read_stream,
+                        write_stream=self._write_stream,
+                        sampling_callback=self._sampling_callback if self.allow_sampling else None,
+                        logging_callback=self.log_handler,
+                        read_timeout_seconds=timedelta(seconds=self.read_timeout),
+                    )
+                    self._client = await exit_stack.enter_async_context(client)
+
+                    with anyio.fail_after(self.timeout):
+                        await self._client.initialize()
+
+                    if log_level := self.log_level:
+                        await self._client.set_logging_level(log_level)
+
+                    self._exit_stack = exit_stack.pop_all()
             self._running_count += 1
         return self
diff --git a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py
index 4b1511fae1..d2ddaa1258 100644
--- a/pydantic_ai_slim/pydantic_ai/toolsets/combined.py
+++ b/pydantic_ai_slim/pydantic_ai/toolsets/combined.py
@@ -43,9 +43,10 @@ def
__post_init__(self): async def __aenter__(self) -> Self: async with self._enter_lock: if self._entered_count == 0: - self._exit_stack = AsyncExitStack() - for toolset in self.toolsets: - await self._exit_stack.enter_async_context(toolset) + async with AsyncExitStack() as exit_stack: + for toolset in self.toolsets: + await exit_stack.enter_async_context(toolset) + self._exit_stack = exit_stack.pop_all() self._entered_count += 1 return self diff --git a/tests/test_mcp.py b/tests/test_mcp.py index de77b3587e..1021b31512 100644 --- a/tests/test_mcp.py +++ b/tests/test_mcp.py @@ -91,6 +91,20 @@ async def test_reentrant_context_manager(): pass +async def test_context_manager_initialization_error() -> None: + """Test if streams are closed if client fails to initialize.""" + server = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + from mcp.client.session import ClientSession + + with patch.object(ClientSession, 'initialize', side_effect=Exception): + with pytest.raises(Exception): + async with server: + pass + + assert server._read_stream._closed # pyright: ignore[reportPrivateUsage] + assert server._write_stream._closed # pyright: ignore[reportPrivateUsage] + + async def test_stdio_server_with_tool_prefix(run_context: RunContext[int]): server = MCPServerStdio('python', ['-m', 'tests.mcp_server'], tool_prefix='foo') async with server: diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index eac0dc78a7..f188d3141a 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -3,6 +3,7 @@ import re from dataclasses import dataclass, replace from typing import TypeVar +from unittest.mock import AsyncMock import pytest from inline_snapshot import snapshot @@ -469,3 +470,27 @@ async def test_context_manager(): async with toolset: assert server1.is_running assert server2.is_running + + +class InitializationError(Exception): + pass + + +async def test_context_manager_failed_initialization(): + """Test if MCP servers stop if any MCP server fails to initialize.""" + try: + from pydantic_ai.mcp import MCPServerStdio + except ImportError: # pragma: lax no cover + pytest.skip('mcp is not installed') + + server1 = MCPServerStdio('python', ['-m', 'tests.mcp_server']) + server2 = AsyncMock() + server2.__aenter__.side_effect = InitializationError + + toolset = CombinedToolset([server1, server2]) + + with pytest.raises(InitializationError): + async with toolset: + pass + + assert server1.is_running is False From 494146807d40dab262695580c75f70d5cb91f7e3 Mon Sep 17 00:00:00 2001 From: David Montague <35119617+dmontagu@users.noreply.github.com> Date: Fri, 25 Jul 2025 10:34:44 -0700 Subject: [PATCH 74/89] Add tenacity utilities/integration for improved retry handling (#2282) Co-authored-by: Douwe Maan --- docs/api/retries.md | 3 + docs/retries.md | 338 ++++++++++++++ mkdocs.yml | 2 + pydantic_ai_slim/pydantic_ai/retries.py | 249 +++++++++++ pydantic_ai_slim/pyproject.toml | 2 + pyproject.toml | 2 +- tests/test_tenacity.py | 572 ++++++++++++++++++++++++ uv.lock | 10 +- 8 files changed, 1174 insertions(+), 4 deletions(-) create mode 100644 docs/api/retries.md create mode 100644 docs/retries.md create mode 100644 pydantic_ai_slim/pydantic_ai/retries.py create mode 100644 tests/test_tenacity.py diff --git a/docs/api/retries.md b/docs/api/retries.md new file mode 100644 index 0000000000..a5eb4931ea --- /dev/null +++ b/docs/api/retries.md @@ -0,0 +1,3 @@ +# `pydantic_ai.retries` + +::: pydantic_ai.retries diff --git a/docs/retries.md b/docs/retries.md new file mode 100644 index 0000000000..31449d3ec3 
--- /dev/null +++ b/docs/retries.md @@ -0,0 +1,338 @@ +# HTTP Request Retries + +Pydantic AI provides retry functionality for HTTP requests made by model providers through custom HTTP transports. +This is particularly useful for handling transient failures like rate limits, network timeouts, or temporary server errors. + +## Overview + +The retry functionality is built on top of the [tenacity](https://github.com/jd/tenacity) library and integrates +seamlessly with httpx clients. You can configure retry behavior for any provider that accepts a custom HTTP client. + +## Installation + +To use the retry transports, you need to install `tenacity`, which you can do via the `retries` dependency group: + +```bash +pip/uv-add 'pydantic-ai-slim[retries]' +``` + +## Usage Example + +Here's an example of adding retry functionality with smart retry handling: + +```python {title="smart_retry_example.py"} +from httpx import AsyncClient, HTTPStatusError +from tenacity import ( + AsyncRetrying, + stop_after_attempt, + wait_exponential, + retry_if_exception_type +) +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIModel +from pydantic_ai.retries import AsyncTenacityTransport, wait_retry_after +from pydantic_ai.providers.openai import OpenAIProvider + +def create_retrying_client(): + """Create a client with smart retry handling for multiple error types.""" + + def should_retry_status(response): + """Raise exceptions for retryable HTTP status codes.""" + if response.status_code in (429, 502, 503, 504): + response.raise_for_status() # This will raise HTTPStatusError + + transport = AsyncTenacityTransport( + controller=AsyncRetrying( + # Retry on HTTP errors and connection issues + retry=retry_if_exception_type((HTTPStatusError, ConnectionError)), + # Smart waiting: respects Retry-After headers, falls back to exponential backoff + wait=wait_retry_after( + fallback_strategy=wait_exponential(multiplier=1, max=60), + max_wait=300 + ), + # Stop after 5 attempts + stop=stop_after_attempt(5), + # Re-raise the last exception if all retries fail + reraise=True + ), + validate_response=should_retry_status + ) + return AsyncClient(transport=transport) + +# Use the retrying client with a model +client = create_retrying_client() +model = OpenAIModel('gpt-4o', provider=OpenAIProvider(http_client=client)) +agent = Agent(model) +``` + +## Wait Strategies + +### wait_retry_after + +The `wait_retry_after` function is a smart wait strategy that automatically respects HTTP `Retry-After` headers: + +```python {title="wait_strategy_example.py"} +from pydantic_ai.retries import wait_retry_after +from tenacity import wait_exponential + +# Basic usage - respects Retry-After headers, falls back to exponential backoff +wait_strategy_1 = wait_retry_after() + +# Custom configuration +wait_strategy_2 = wait_retry_after( + fallback_strategy=wait_exponential(multiplier=2, max=120), + max_wait=600 # Never wait more than 10 minutes +) +``` + +This wait strategy: +- Automatically parses `Retry-After` headers from HTTP 429 responses +- Supports both seconds format (`"30"`) and HTTP date format (`"Wed, 21 Oct 2015 07:28:00 GMT"`) +- Falls back to your chosen strategy when no header is present +- Respects the `max_wait` limit to prevent excessive delays + +## Transport Classes + +### AsyncTenacityTransport + +For asynchronous HTTP clients (recommended for most use cases): + +```python {title="async_transport_example.py"} +from httpx import AsyncClient +from tenacity import AsyncRetrying, stop_after_attempt +from 
pydantic_ai.retries import AsyncTenacityTransport + +# Create the basic components +async_retrying = AsyncRetrying(stop=stop_after_attempt(3), reraise=True) + +def validator(response): + """Treat responses with HTTP status 4xx/5xx as failures that need to be retried. + Without a response validator, only network errors and timeouts will result in a retry. + """ + response.raise_for_status() + +# Create the transport +transport = AsyncTenacityTransport( + controller=async_retrying, # AsyncRetrying instance + validate_response=validator # Optional response validator +) + +# Create a client using the transport: +client = AsyncClient(transport=transport) +``` + +### TenacityTransport + +For synchronous HTTP clients: + +```python {title="sync_transport_example.py"} +from httpx import Client +from tenacity import Retrying, stop_after_attempt +from pydantic_ai.retries import TenacityTransport + +# Create the basic components +retrying = Retrying(stop=stop_after_attempt(3), reraise=True) + +def validator(response): + """Treat responses with HTTP status 4xx/5xx as failures that need to be retried. + Without a response validator, only network errors and timeouts will result in a retry. + """ + response.raise_for_status() + +# Create the transport +transport = TenacityTransport( + controller=retrying, # Retrying instance + validate_response=validator # Optional response validator +) + +# Create a client using the transport +client = Client(transport=transport) +``` + +## Common Retry Patterns + +### Rate Limit Handling with Retry-After Support + +```python {title="rate_limit_handling.py"} +from httpx import AsyncClient, HTTPStatusError +from tenacity import AsyncRetrying, stop_after_attempt, retry_if_exception_type, wait_exponential +from pydantic_ai.retries import AsyncTenacityTransport, wait_retry_after + +def create_rate_limit_client(): + """Create a client that respects Retry-After headers from rate limiting responses.""" + transport = AsyncTenacityTransport( + controller=AsyncRetrying( + retry=retry_if_exception_type(HTTPStatusError), + wait=wait_retry_after( + fallback_strategy=wait_exponential(multiplier=1, max=60), + max_wait=300 # Don't wait more than 5 minutes + ), + stop=stop_after_attempt(10), + reraise=True + ), + validate_response=lambda r: r.raise_for_status() # Raises HTTPStatusError for 4xx/5xx + ) + return AsyncClient(transport=transport) + +# Example usage +client = create_rate_limit_client() +# Client is now ready to use with any HTTP requests and will respect Retry-After headers +``` + +The `wait_retry_after` function automatically detects `Retry-After` headers in 429 (rate limit) responses and waits for the specified time. If no header is present, it falls back to exponential backoff. 
+ +### Network Error Handling + +```python {title="network_error_handling.py"} +import httpx +from tenacity import AsyncRetrying, retry_if_exception_type, wait_exponential, stop_after_attempt +from pydantic_ai.retries import AsyncTenacityTransport + +def create_network_resilient_client(): + """Create a client that handles network errors with retries.""" + transport = AsyncTenacityTransport( + controller=AsyncRetrying( + retry=retry_if_exception_type(( + httpx.TimeoutException, + httpx.ConnectError, + httpx.ReadError + )), + wait=wait_exponential(multiplier=1, max=10), + stop=stop_after_attempt(3), + reraise=True + ) + ) + return httpx.AsyncClient(transport=transport) + +# Example usage +client = create_network_resilient_client() +# Client will now retry on timeout, connection, and read errors +``` + +### Custom Retry Logic + +```python {title="custom_retry_logic.py"} +import httpx +from tenacity import AsyncRetrying, wait_exponential, stop_after_attempt +from pydantic_ai.retries import AsyncTenacityTransport, wait_retry_after + +def create_custom_retry_client(): + """Create a client with custom retry logic.""" + def custom_retry_condition(exception): + """Custom logic to determine if we should retry.""" + if isinstance(exception, httpx.HTTPStatusError): + # Retry on server errors but not client errors + return 500 <= exception.response.status_code < 600 + return isinstance(exception, (httpx.TimeoutException, httpx.ConnectError)) + + transport = AsyncTenacityTransport( + controller=AsyncRetrying( + retry=custom_retry_condition, + # Use wait_retry_after for smart waiting on rate limits, + # with custom exponential backoff as fallback + wait=wait_retry_after( + fallback_strategy=wait_exponential(multiplier=2, max=30), + max_wait=120 + ), + stop=stop_after_attempt(5), + reraise=True + ), + validate_response=lambda r: r.raise_for_status() + ) + return httpx.AsyncClient(transport=transport) + +client = create_custom_retry_client() +# Client will retry server errors (5xx) and network errors, but not client errors (4xx) +``` + +## Using with Different Providers + +The retry transports work with any provider that accepts a custom HTTP client: + +### OpenAI + +```python {title="openai_with_retries.py" requires="smart_retry_example.py"} +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIModel +from pydantic_ai.providers.openai import OpenAIProvider + +from smart_retry_example import create_retrying_client + +client = create_retrying_client() +model = OpenAIModel('gpt-4o', provider=OpenAIProvider(http_client=client)) +agent = Agent(model) +``` + +### Anthropic + +```python {title="anthropic_with_retries.py" requires="smart_retry_example.py"} +from pydantic_ai import Agent +from pydantic_ai.models.anthropic import AnthropicModel +from pydantic_ai.providers.anthropic import AnthropicProvider + +from smart_retry_example import create_retrying_client + +client = create_retrying_client() +model = AnthropicModel('claude-3-5-sonnet-20241022', provider=AnthropicProvider(http_client=client)) +agent = Agent(model) +``` + +### Any OpenAI-Compatible Provider + +```python {title="openai_compatible_with_retries.py" requires="smart_retry_example.py"} +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIModel +from pydantic_ai.providers.openai import OpenAIProvider + +from smart_retry_example import create_retrying_client + +client = create_retrying_client() +model = OpenAIModel( + 'your-model-name', # Replace with actual model name + provider=OpenAIProvider( + 
base_url='https://api.example.com/v1', # Replace with actual API URL + api_key='your-api-key', # Replace with actual API key + http_client=client + ) +) +agent = Agent(model) +``` + +## Best Practices + +1. **Start Conservative**: Begin with a small number of retries (3-5) and reasonable wait times. + +2. **Use Exponential Backoff**: This helps avoid overwhelming servers during outages. + +3. **Set Maximum Wait Times**: Prevent indefinite delays with reasonable maximum wait times. + +4. **Handle Rate Limits Properly**: Respect `Retry-After` headers when possible. + +5. **Log Retry Attempts**: Add logging to monitor retry behavior in production. (This will be picked up by Logfire automatically if you instrument httpx.) + +6. **Consider Circuit Breakers**: For high-traffic applications, consider implementing circuit breaker patterns. + +## Error Handling + +The retry transports will re-raise the last exception if all retry attempts fail. Make sure to handle these appropriately in your application: + +```python {title="error_handling_example.py" requires="smart_retry_example.py"} +from pydantic_ai import Agent +from pydantic_ai.models.openai import OpenAIModel +from pydantic_ai.providers.openai import OpenAIProvider + +from smart_retry_example import create_retrying_client + +client = create_retrying_client() +model = OpenAIModel('gpt-4o', provider=OpenAIProvider(http_client=client)) +agent = Agent(model) +``` + +## Performance Considerations + +- Retries add latency to requests, especially with exponential backoff +- Consider the total timeout for your application when configuring retry behavior +- Monitor retry rates to detect systemic issues +- Use async transports for better concurrency when handling multiple requests + +For more advanced retry configurations, refer to the [tenacity documentation](https://tenacity.readthedocs.io/). diff --git a/mkdocs.yml b/mkdocs.yml index 1860a00c0e..6ce47a1a9b 100644 --- a/mkdocs.yml +++ b/mkdocs.yml @@ -43,6 +43,7 @@ nav: - thinking.md - direct.md - common-tools.md + - retries.md - MCP: - mcp/index.md - mcp/client.md @@ -101,6 +102,7 @@ nav: - api/models/mcp-sampling.md - api/profiles.md - api/providers.md + - api/retries.md - api/pydantic_graph/graph.md - api/pydantic_graph/nodes.md - api/pydantic_graph/persistence.md diff --git a/pydantic_ai_slim/pydantic_ai/retries.py b/pydantic_ai_slim/pydantic_ai/retries.py new file mode 100644 index 0000000000..c82378ac60 --- /dev/null +++ b/pydantic_ai_slim/pydantic_ai/retries.py @@ -0,0 +1,249 @@ +"""Retries utilities based on tenacity, especially for HTTP requests. + +This module provides HTTP transport wrappers and wait strategies that integrate with +the tenacity library to add retry capabilities to HTTP requests. The transports can be +used with HTTP clients that support custom transports (such as httpx), while the wait +strategies can be used with any tenacity retry decorator. 
+ +The module includes: +- TenacityTransport: Synchronous HTTP transport with retry capabilities +- AsyncTenacityTransport: Asynchronous HTTP transport with retry capabilities +- wait_retry_after: Wait strategy that respects HTTP Retry-After headers +""" + +from __future__ import annotations + +from httpx import AsyncBaseTransport, AsyncHTTPTransport, BaseTransport, HTTPTransport, Request, Response + +try: + from tenacity import AsyncRetrying, Retrying +except ImportError as _import_error: + raise ImportError( + 'Please install `tenacity` to use the retries utilities, ' + 'you can use the `retries` optional group — `pip install "pydantic-ai-slim[retries]"`' + ) from _import_error + + +__all__ = ['TenacityTransport', 'AsyncTenacityTransport', 'wait_retry_after'] + +from datetime import datetime, timezone +from email.utils import parsedate_to_datetime +from typing import Callable, cast + +from httpx import HTTPStatusError +from tenacity import RetryCallState, wait_exponential + + +class TenacityTransport(BaseTransport): + """Synchronous HTTP transport with tenacity-based retry functionality. + + This transport wraps another BaseTransport and adds retry capabilities using the tenacity library. + It can be configured to retry requests based on various conditions such as specific exception types, + response status codes, or custom validation logic. + + The transport works by intercepting HTTP requests and responses, allowing the tenacity controller + to determine when and how to retry failed requests. The validate_response function can be used + to convert HTTP responses into exceptions that trigger retries. + + Args: + wrapped: The underlying transport to wrap and add retry functionality to. + controller: The tenacity Retrying instance that defines the retry behavior + (retry conditions, wait strategy, stop conditions, etc.). + validate_response: Optional callable that takes a Response and can raise an exception + to be handled by the controller if the response should trigger a retry. + Common use case is to raise exceptions for certain HTTP status codes. + If None, no response validation is performed. + + Example: + ```python + from httpx import Client, HTTPTransport, HTTPStatusError + from tenacity import Retrying, stop_after_attempt, retry_if_exception_type + from pydantic_ai.retries import TenacityTransport, wait_retry_after + + transport = TenacityTransport( + HTTPTransport(), + Retrying( + retry=retry_if_exception_type(HTTPStatusError), + wait=wait_retry_after(max_wait=300), + stop=stop_after_attempt(5), + reraise=True + ), + validate_response=lambda r: r.raise_for_status() + ) + client = Client(transport=transport) + ``` + """ + + def __init__( + self, + controller: Retrying, + wrapped: BaseTransport | None = None, + validate_response: Callable[[Response], None] | None = None, + ): + self.controller = controller + self.wrapped = wrapped or HTTPTransport() + self.validate_response = validate_response + + def handle_request(self, request: Request) -> Response: + """Handle an HTTP request with retry logic. + + Args: + request: The HTTP request to handle. + + Returns: + The HTTP response. + + Raises: + RuntimeError: If the retry controller did not make any attempts. + Exception: Any exception raised by the wrapped transport or validation function. 
+ """ + for attempt in self.controller: + with attempt: + response = self.wrapped.handle_request(request) + if self.validate_response: + self.validate_response(response) + return response + raise RuntimeError('The retry controller did not make any attempts') # pragma: no cover + + +class AsyncTenacityTransport(AsyncBaseTransport): + """Asynchronous HTTP transport with tenacity-based retry functionality. + + This transport wraps another AsyncBaseTransport and adds retry capabilities using the tenacity library. + It can be configured to retry requests based on various conditions such as specific exception types, + response status codes, or custom validation logic. + + The transport works by intercepting HTTP requests and responses, allowing the tenacity controller + to determine when and how to retry failed requests. The validate_response function can be used + to convert HTTP responses into exceptions that trigger retries. + + Args: + wrapped: The underlying async transport to wrap and add retry functionality to. + controller: The tenacity AsyncRetrying instance that defines the retry behavior + (retry conditions, wait strategy, stop conditions, etc.). + validate_response: Optional callable that takes a Response and can raise an exception + to be handled by the controller if the response should trigger a retry. + Common use case is to raise exceptions for certain HTTP status codes. + If None, no response validation is performed. + + Example: + ```python + from httpx import AsyncClient, HTTPStatusError + from tenacity import AsyncRetrying, stop_after_attempt, retry_if_exception_type + from pydantic_ai.retries import AsyncTenacityTransport, wait_retry_after + + transport = AsyncTenacityTransport( + AsyncRetrying( + retry=retry_if_exception_type(HTTPStatusError), + wait=wait_retry_after(max_wait=300), + stop=stop_after_attempt(5), + reraise=True + ), + validate_response=lambda r: r.raise_for_status() + ) + client = AsyncClient(transport=transport) + ``` + """ + + def __init__( + self, + controller: AsyncRetrying, + wrapped: AsyncBaseTransport | None = None, + validate_response: Callable[[Response], None] | None = None, + ): + self.controller = controller + self.wrapped = wrapped or AsyncHTTPTransport() + self.validate_response = validate_response + + async def handle_async_request(self, request: Request) -> Response: + """Handle an async HTTP request with retry logic. + + Args: + request: The HTTP request to handle. + + Returns: + The HTTP response. + + Raises: + RuntimeError: If the retry controller did not make any attempts. + Exception: Any exception raised by the wrapped transport or validation function. + """ + async for attempt in self.controller: + with attempt: + response = await self.wrapped.handle_async_request(request) + if self.validate_response: + self.validate_response(response) + return response + raise RuntimeError('The retry controller did not make any attempts') # pragma: no cover + + +def wait_retry_after( + fallback_strategy: Callable[[RetryCallState], float] | None = None, max_wait: float = 300 +) -> Callable[[RetryCallState], float]: + """Create a tenacity-compatible wait strategy that respects HTTP Retry-After headers. + + This wait strategy checks if the exception contains an HTTPStatusError with a + Retry-After header, and if so, waits for the time specified in the header. + If no header is present or parsing fails, it falls back to the provided strategy. 
+ + The Retry-After header can be in two formats: + - An integer representing seconds to wait + - An HTTP date string representing when to retry + + Args: + fallback_strategy: Wait strategy to use when no Retry-After header is present + or parsing fails. Defaults to exponential backoff with max 60s. + max_wait: Maximum time to wait in seconds, regardless of header value. + Defaults to 300 (5 minutes). + + Returns: + A wait function that can be used with tenacity retry decorators. + + Example: + ```python + from httpx import AsyncClient, HTTPStatusError + from tenacity import AsyncRetrying, stop_after_attempt, retry_if_exception_type + from pydantic_ai.retries import AsyncTenacityTransport, wait_retry_after + + transport = AsyncTenacityTransport( + AsyncRetrying( + retry=retry_if_exception_type(HTTPStatusError), + wait=wait_retry_after(max_wait=120), + stop=stop_after_attempt(5), + reraise=True + ), + validate_response=lambda r: r.raise_for_status() + ) + client = AsyncClient(transport=transport) + ``` + """ + if fallback_strategy is None: + fallback_strategy = wait_exponential(multiplier=1, max=60) + + def wait_func(state: RetryCallState) -> float: + exc = state.outcome.exception() if state.outcome else None + if isinstance(exc, HTTPStatusError): + retry_after = exc.response.headers.get('retry-after') + if retry_after: + try: + # Try parsing as seconds first + wait_seconds = int(retry_after) + return min(float(wait_seconds), max_wait) + except ValueError: + # Try parsing as HTTP date + try: + retry_time = cast(datetime, parsedate_to_datetime(retry_after)) + assert isinstance(retry_time, datetime) + now = datetime.now(timezone.utc) + wait_seconds = (retry_time - now).total_seconds() + + if wait_seconds > 0: + return min(wait_seconds, max_wait) + except (ValueError, TypeError, AssertionError): + # If date parsing fails, fall back to fallback strategy + pass + + # Use fallback strategy + return fallback_strategy(state) + + return wait_func diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 0e08012323..7407e14e2a 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -84,6 +84,8 @@ evals = ["pydantic-evals=={{ version }}"] a2a = ["fasta2a>=0.4.1"] # AG-UI ag-ui = ["ag-ui-protocol>=0.1.8", "starlette>=0.45.3"] +# Retries +retries = ["tenacity>=8.2.3"] [dependency-groups] dev = [ diff --git a/pyproject.toml b/pyproject.toml index 841f186ef2..9079ae58d8 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ requires-python = ">=3.9" [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals,ag-ui,retries]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] diff --git a/tests/test_tenacity.py b/tests/test_tenacity.py new file mode 100644 index 0000000000..c580543b43 --- /dev/null +++ b/tests/test_tenacity.py @@ -0,0 +1,572 @@ +from __future__ import annotations as _annotations + +import asyncio +import time +from datetime import datetime, timezone +from email.utils import formatdate +from unittest.mock import AsyncMock, Mock + +import httpx +import pytest + +from .conftest import try_import + +with try_import() as imports_successful: + from tenacity import ( + AsyncRetrying, + RetryCallState, + Retrying, + retry_if_exception_type, + 
stop_after_attempt, + wait_fixed, + ) + + from pydantic_ai.retries import AsyncTenacityTransport, TenacityTransport, wait_retry_after + +pytestmark = pytest.mark.skipif(not imports_successful(), reason='install tenacity to run tenacity tests') + + +class TestTenacityTransport: + """Tests for the synchronous TenacityTransport.""" + + def test_successful_request(self): + """Test that successful requests pass through without retry.""" + mock_transport = Mock(spec=httpx.BaseTransport) + mock_response = Mock(spec=httpx.Response) + mock_transport.handle_request.return_value = mock_response + + controller = Retrying(stop=stop_after_attempt(3), reraise=True) + transport = TenacityTransport(controller, mock_transport) + + request = httpx.Request('GET', 'https://example.com') + result = transport.handle_request(request) + + assert result is mock_response + mock_transport.handle_request.assert_called_once_with(request) + + def test_retry_on_exception(self): + """Test that exceptions trigger retries.""" + mock_transport = Mock(spec=httpx.BaseTransport) + mock_response = Mock(spec=httpx.Response) + + # Fail twice, succeed on third attempt + mock_transport.handle_request.side_effect = [ + httpx.ConnectError('Connection failed'), + httpx.ConnectError('Connection failed again'), + mock_response, + ] + + controller = Retrying( + retry=retry_if_exception_type(httpx.ConnectError), + stop=stop_after_attempt(3), + wait=wait_fixed(0.001), # Very short wait for tests + reraise=True, + ) + transport = TenacityTransport(controller, mock_transport) + + request = httpx.Request('GET', 'https://example.com') + result = transport.handle_request(request) + + assert result is mock_response + assert mock_transport.handle_request.call_count == 3 + + def test_retry_exhausted(self): + """Test that retry exhaustion re-raises the last exception.""" + mock_transport = Mock(spec=httpx.BaseTransport) + mock_transport.handle_request.side_effect = httpx.ConnectError('Connection failed') + + controller = Retrying( + retry=retry_if_exception_type(httpx.ConnectError), + stop=stop_after_attempt(2), + wait=wait_fixed(0.001), + reraise=True, + ) + transport = TenacityTransport(controller, mock_transport) + + request = httpx.Request('GET', 'https://example.com') + with pytest.raises(httpx.ConnectError, match='Connection failed'): + transport.handle_request(request) + + assert mock_transport.handle_request.call_count == 2 + + def test_validate_response_success(self): + """Test that validate_response is called and doesn't raise.""" + mock_transport = Mock(spec=httpx.BaseTransport) + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_transport.handle_request.return_value = mock_response + + validate_response = Mock() + controller = Retrying(stop=stop_after_attempt(3), reraise=True) + transport = TenacityTransport(controller, mock_transport, validate_response) + + request = httpx.Request('GET', 'https://example.com') + result = transport.handle_request(request) + + assert result is mock_response + validate_response.assert_called_once_with(mock_response) + + def test_validate_response_triggers_retry(self): + """Test that validate_response can trigger retries.""" + mock_transport = Mock(spec=httpx.BaseTransport) + mock_response_fail = Mock(spec=httpx.Response) + mock_response_fail.status_code = 429 + mock_response_success = Mock(spec=httpx.Response) + mock_response_success.status_code = 200 + + mock_transport.handle_request.side_effect = [mock_response_fail, mock_response_success] + + def 
validate_response(response: httpx.Response): + if response.status_code == 429: + raise httpx.HTTPStatusError('Rate limited', request=request, response=response) + + controller = Retrying( + retry=retry_if_exception_type(httpx.HTTPStatusError), + stop=stop_after_attempt(3), + wait=wait_fixed(0.001), + reraise=True, + ) + transport = TenacityTransport(controller, mock_transport, validate_response) + + request = httpx.Request('GET', 'https://example.com') + result = transport.handle_request(request) + + assert result is mock_response_success + assert mock_transport.handle_request.call_count == 2 + + +class TestAsyncTenacityTransport: + """Tests for the asynchronous AsyncTenacityTransport.""" + + async def test_successful_request(self): + """Test that successful requests pass through without retry.""" + mock_transport = AsyncMock(spec=httpx.AsyncBaseTransport) + mock_response = Mock(spec=httpx.Response) + mock_transport.handle_async_request.return_value = mock_response + + controller = AsyncRetrying(stop=stop_after_attempt(3), reraise=True) + transport = AsyncTenacityTransport(controller, mock_transport) + + request = httpx.Request('GET', 'https://example.com') + result = await transport.handle_async_request(request) + + assert result is mock_response + mock_transport.handle_async_request.assert_called_once_with(request) + + async def test_retry_on_exception(self): + """Test that exceptions trigger retries.""" + mock_transport = AsyncMock(spec=httpx.AsyncBaseTransport) + mock_response = Mock(spec=httpx.Response) + + # Fail twice, succeed on third attempt + mock_transport.handle_async_request.side_effect = [ + httpx.ConnectError('Connection failed'), + httpx.ConnectError('Connection failed again'), + mock_response, + ] + + controller = AsyncRetrying( + retry=retry_if_exception_type(httpx.ConnectError), + stop=stop_after_attempt(3), + wait=wait_fixed(0.001), + reraise=True, + ) + transport = AsyncTenacityTransport(controller, mock_transport) + + request = httpx.Request('GET', 'https://example.com') + result = await transport.handle_async_request(request) + + assert result is mock_response + assert mock_transport.handle_async_request.call_count == 3 + + async def test_retry_exhausted(self): + """Test that retry exhaustion re-raises the last exception.""" + mock_transport = AsyncMock(spec=httpx.AsyncBaseTransport) + mock_transport.handle_async_request.side_effect = httpx.ConnectError('Connection failed') + + controller = AsyncRetrying( + retry=retry_if_exception_type(httpx.ConnectError), + stop=stop_after_attempt(2), + wait=wait_fixed(0.001), + reraise=True, + ) + transport = AsyncTenacityTransport(controller, mock_transport) + + request = httpx.Request('GET', 'https://example.com') + with pytest.raises(httpx.ConnectError, match='Connection failed'): + await transport.handle_async_request(request) + + assert mock_transport.handle_async_request.call_count == 2 + + async def test_validate_response_success(self): + """Test that validate_response is called and doesn't raise.""" + mock_transport = AsyncMock(spec=httpx.AsyncBaseTransport) + mock_response = Mock(spec=httpx.Response) + mock_response.status_code = 200 + mock_transport.handle_async_request.return_value = mock_response + + validate_response = Mock() + controller = AsyncRetrying(stop=stop_after_attempt(3), reraise=True) + transport = AsyncTenacityTransport(controller, mock_transport, validate_response) + + request = httpx.Request('GET', 'https://example.com') + result = await transport.handle_async_request(request) + + assert result is 
mock_response + validate_response.assert_called_once_with(mock_response) + + async def test_validate_response_triggers_retry(self): + """Test that validate_response can trigger retries.""" + mock_transport = AsyncMock(spec=httpx.AsyncBaseTransport) + mock_response_fail = Mock(spec=httpx.Response) + mock_response_fail.status_code = 429 + mock_response_success = Mock(spec=httpx.Response) + mock_response_success.status_code = 200 + + mock_transport.handle_async_request.side_effect = [mock_response_fail, mock_response_success] + + def validate_response(response: httpx.Response): + if response.status_code == 429: + raise httpx.HTTPStatusError('Rate limited', request=request, response=response) + + controller = AsyncRetrying( + retry=retry_if_exception_type(httpx.HTTPStatusError), + stop=stop_after_attempt(3), + wait=wait_fixed(0.001), + reraise=True, + ) + transport = AsyncTenacityTransport(controller, mock_transport, validate_response) + + request = httpx.Request('GET', 'https://example.com') + result = await transport.handle_async_request(request) + + assert result is mock_response_success + assert mock_transport.handle_async_request.call_count == 2 + + +class TestWaitRetryAfter: + """Tests for the wait_retry_after wait strategy.""" + + def test_no_exception_uses_fallback(self): + """Test that fallback strategy is used when there's no exception.""" + fallback = Mock(return_value=5.0) + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=300) + + # Create a retry state with no exception + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = None + + result = wait_func(retry_state) + + assert result == 5.0 + fallback.assert_called_once_with(retry_state) + + def test_non_http_exception_uses_fallback(self): + """Test that fallback strategy is used for non-HTTP exceptions.""" + fallback = Mock(return_value=3.0) + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=300) + + # Create a retry state with a non-HTTP exception + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = ValueError('Some error') + + result = wait_func(retry_state) + + assert result == 3.0 + fallback.assert_called_once_with(retry_state) + + def test_http_exception_no_retry_after_uses_fallback(self): + """Test that fallback strategy is used when there's no Retry-After header.""" + fallback = Mock(return_value=2.0) + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=300) + + # Create HTTP status error without Retry-After header + request = httpx.Request('GET', 'https://example.com') + response = Mock(spec=httpx.Response) + response.headers = {} + http_error = httpx.HTTPStatusError('Rate limited', request=request, response=response) + + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = http_error + + result = wait_func(retry_state) + + assert result == 2.0 + fallback.assert_called_once_with(retry_state) + + def test_retry_after_seconds_format(self): + """Test parsing Retry-After header in seconds format.""" + fallback = Mock() + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=300) + + # Create HTTP status error with Retry-After in seconds + request = httpx.Request('GET', 'https://example.com') + response = Mock(spec=httpx.Response) + response.headers = {'retry-after': '30'} + http_error = httpx.HTTPStatusError('Rate limited', request=request, 
response=response) + + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = http_error + + result = wait_func(retry_state) + + assert result == 30.0 + fallback.assert_not_called() + + def test_retry_after_seconds_respects_max_wait(self): + """Test that max_wait is respected for seconds format.""" + fallback = Mock() + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=60) + + # Create HTTP status error with Retry-After > max_wait + request = httpx.Request('GET', 'https://example.com') + response = Mock(spec=httpx.Response) + response.headers = {'retry-after': '120'} + http_error = httpx.HTTPStatusError('Rate limited', request=request, response=response) + + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = http_error + + result = wait_func(retry_state) + + assert result == 60.0 # Capped at max_wait + fallback.assert_not_called() + + def test_retry_after_http_date_format(self): + """Test parsing Retry-After header in HTTP date format.""" + fallback = Mock() + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=300) + + # Create a future date (30 seconds from now) + future_time = datetime.now(timezone.utc).timestamp() + 30 + http_date = formatdate(future_time, usegmt=True) + + # Create HTTP status error with Retry-After in HTTP date format + request = httpx.Request('GET', 'https://example.com') + response = Mock(spec=httpx.Response) + response.headers = {'retry-after': http_date} + http_error = httpx.HTTPStatusError('Rate limited', request=request, response=response) + + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = http_error + + result = wait_func(retry_state) + + # Should be approximately 30 seconds (allow some tolerance for test timing) + assert 25 <= result <= 35 + fallback.assert_not_called() + + def test_retry_after_http_date_past_time_uses_fallback(self): + """Test that past dates in Retry-After fall back to fallback strategy.""" + fallback = Mock(return_value=1.0) + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=300) + + # Create a past date + past_time = datetime.now(timezone.utc).timestamp() - 30 + http_date = formatdate(past_time, usegmt=True) + + # Create HTTP status error with Retry-After in HTTP date format + request = httpx.Request('GET', 'https://example.com') + response = Mock(spec=httpx.Response) + response.headers = {'retry-after': http_date} + http_error = httpx.HTTPStatusError('Rate limited', request=request, response=response) + + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = http_error + + result = wait_func(retry_state) + + assert result == 1.0 + fallback.assert_called_once_with(retry_state) + + def test_retry_after_http_date_respects_max_wait(self): + """Test that max_wait is respected for HTTP date format.""" + fallback = Mock() + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=60) + + # Create a future date (120 seconds from now, > max_wait) + future_time = datetime.now(timezone.utc).timestamp() + 120 + http_date = formatdate(future_time, usegmt=True) + + # Create HTTP status error with Retry-After in HTTP date format + request = httpx.Request('GET', 'https://example.com') + response = 
Mock(spec=httpx.Response) + response.headers = {'retry-after': http_date} + http_error = httpx.HTTPStatusError('Rate limited', request=request, response=response) + + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = http_error + + result = wait_func(retry_state) + + assert result == 60.0 # Capped at max_wait + fallback.assert_not_called() + + def test_retry_after_invalid_format_uses_fallback(self): + """Test that invalid Retry-After values fall back to fallback strategy.""" + fallback = Mock(return_value=4.0) + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=300) + + # Create HTTP status error with invalid Retry-After + request = httpx.Request('GET', 'https://example.com') + response = Mock(spec=httpx.Response) + response.headers = {'retry-after': 'invalid-value'} + http_error = httpx.HTTPStatusError('Rate limited', request=request, response=response) + + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = http_error + + result = wait_func(retry_state) + + assert result == 4.0 + fallback.assert_called_once_with(retry_state) + + def test_default_fallback_strategy(self): + """Test that default fallback strategy is used when none is provided.""" + wait_func = wait_retry_after(max_wait=300) + + # Create a retry state with no exception to trigger fallback + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = None + retry_state.attempt_number = 1 + + # Should use default exponential backoff, exact value depends on retry state + result = wait_func(retry_state) + + assert result == 1 # first backoff + + def test_default_max_wait(self): + """Test that default max_wait of 300 seconds is used.""" + wait_func = wait_retry_after() # Use all defaults + + # Create HTTP status error with large Retry-After value + request = httpx.Request('GET', 'https://example.com') + response = Mock(spec=httpx.Response) + response.headers = {'retry-after': '600'} # 10 minutes + http_error = httpx.HTTPStatusError('Rate limited', request=request, response=response) + + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = http_error + + result = wait_func(retry_state) + + assert result == 300.0 # Capped at default max_wait + + def test_case_insensitive_header_access(self): + """Test that Retry-After header access is case insensitive.""" + fallback = Mock() + wait_func = wait_retry_after(fallback_strategy=fallback, max_wait=300) + + # Create HTTP status error with uppercase Retry-After header + request = httpx.Request('GET', 'https://example.com') + response = Mock(spec=httpx.Response) + # httpx headers are case-insensitive, so this should work + response.headers = httpx.Headers({'Retry-After': '45'}) + http_error = httpx.HTTPStatusError('Rate limited', request=request, response=response) + + retry_state = Mock(spec=RetryCallState) + retry_state.outcome = Mock() + retry_state.outcome.failed = True + retry_state.outcome.exception.return_value = http_error + + result = wait_func(retry_state) + + assert result == 45.0 + fallback.assert_not_called() + + +class TestIntegration: + """Integration tests combining transports with wait strategies.""" + + async def test_async_transport_with_wait_retry_after(self): + """Test AsyncTenacityTransport with wait_retry_after strategy.""" + mock_transport = 
AsyncMock(spec=httpx.AsyncBaseTransport) + mock_response_fail = Mock(spec=httpx.Response) + mock_response_fail.status_code = 429 + mock_response_fail.headers = {'retry-after': '1'} + mock_response_success = Mock(spec=httpx.Response) + mock_response_success.status_code = 200 + + mock_transport.handle_async_request.side_effect = [mock_response_fail, mock_response_success] + + # Track validation calls + validation_calls: list[int] = [] + + def validate_response(response: httpx.Response): + validation_calls.append(response.status_code) + if response.status_code == 429: + raise httpx.HTTPStatusError('Rate limited', request=request, response=response) + + controller = AsyncRetrying( + retry=retry_if_exception_type(httpx.HTTPStatusError), + wait=wait_retry_after(max_wait=5), # Short max_wait for tests + stop=stop_after_attempt(3), + reraise=True, + ) + transport = AsyncTenacityTransport(controller, mock_transport, validate_response) + + request = httpx.Request('GET', 'https://example.com') + + # Time the request to ensure retry-after wait was respected + start_time = asyncio.get_event_loop().time() + result = await transport.handle_async_request(request) + end_time = asyncio.get_event_loop().time() + + assert result is mock_response_success + assert mock_transport.handle_async_request.call_count == 2 + assert validation_calls == [429, 200] # First call failed, second succeeded + # Should have waited approximately 1 second (allow some tolerance) + assert 0.8 <= (end_time - start_time) <= 2.0 + + def test_sync_transport_with_wait_retry_after(self): + """Test TenacityTransport with wait_retry_after strategy.""" + mock_transport = Mock(spec=httpx.BaseTransport) + mock_response_fail = Mock(spec=httpx.Response) + mock_response_fail.status_code = 429 + mock_response_fail.headers = {'retry-after': '30'} # 30 seconds, will be capped + mock_response_success = Mock(spec=httpx.Response) + mock_response_success.status_code = 200 + + mock_transport.handle_request.side_effect = [mock_response_fail, mock_response_success] + + def validate_response(response: httpx.Response): + if response.status_code == 429: + raise httpx.HTTPStatusError('Rate limited', request=request, response=response) + + controller = Retrying( + retry=retry_if_exception_type(httpx.HTTPStatusError), + wait=wait_retry_after(max_wait=2), # Cap at 2 seconds for tests + stop=stop_after_attempt(3), + reraise=True, + ) + transport = TenacityTransport(controller, mock_transport, validate_response) + + request = httpx.Request('GET', 'https://example.com') + + # Time the request to ensure max_wait was respected + start_time = time.time() + result = transport.handle_request(request) + end_time = time.time() + + assert result is mock_response_success + assert mock_transport.handle_request.call_count == 2 + # Should have waited approximately 2 seconds (capped by max_wait) + assert 1.8 <= (end_time - start_time) <= 3.0 diff --git a/uv.lock b/uv.lock index e4d40ca65c..73e1ef90c1 100644 --- a/uv.lock +++ b/uv.lock @@ -3021,7 +3021,7 @@ wheels = [ name = "pydantic-ai" source = { editable = "." 
} dependencies = [ - { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"] }, + { name = "pydantic-ai-slim", extra = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "retries", "vertexai"] }, ] [package.optional-dependencies] @@ -3060,7 +3060,7 @@ requires-dist = [ { name = "fasta2a", marker = "extra == 'a2a'", specifier = ">=0.4.1" }, { name = "logfire", marker = "extra == 'logfire'", specifier = ">=3.11.0" }, { name = "pydantic-ai-examples", marker = "extra == 'examples'", editable = "examples" }, - { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "vertexai"], editable = "pydantic_ai_slim" }, + { name = "pydantic-ai-slim", extras = ["ag-ui", "anthropic", "bedrock", "cli", "cohere", "evals", "google", "groq", "huggingface", "mcp", "mistral", "openai", "retries", "vertexai"], editable = "pydantic_ai_slim" }, ] provides-extras = ["a2a", "examples", "logfire"] @@ -3181,6 +3181,9 @@ mistral = [ openai = [ { name = "openai" }, ] +retries = [ + { name = "tenacity" }, +] tavily = [ { name = "tavily-python" }, ] @@ -3240,9 +3243,10 @@ requires-dist = [ { name = "rich", marker = "extra == 'cli'", specifier = ">=13" }, { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, + { name = "tenacity", marker = "extra == 'retries'", specifier = ">=8.2.3" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] -provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "tavily", "vertexai"] +provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "retries", "tavily", "vertexai"] [package.metadata.requires-dev] dev = [ From 6e60677836a931a399d89e2840c492ed39bc451c Mon Sep 17 00:00:00 2001 From: Aditya Vardhan <76904033+adtyavrdhn@users.noreply.github.com> Date: Fri, 25 Jul 2025 23:04:49 +0530 Subject: [PATCH 75/89] Adding thinkingpart to otel_events in ModelResponse (#2237) Co-authored-by: Alex Hall --- pydantic_ai_slim/pydantic_ai/messages.py | 15 ++++-- tests/models/test_instrumented.py | 64 ++++++++++++++++-------- 2 files changed, 54 insertions(+), 25 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 379d70efd7..51e63eea5b 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -815,11 +815,16 @@ def new_event_body(): }, } ) - elif isinstance(part, TextPart): - if body.get('content'): - body = new_event_body() - if settings.include_content: - body['content'] = part.content + elif isinstance(part, (TextPart, ThinkingPart)): + kind = part.part_kind + body.setdefault('content', []).append( + {'kind': kind, **({'text': part.content} if settings.include_content else {})} + ) + + if content := body.get('content'): + text_content = content[0].get('text') + if content == [{'kind': 'text', 'text': text_content}]: + body['content'] = text_content return result diff --git a/tests/models/test_instrumented.py b/tests/models/test_instrumented.py index c926e1c3aa..a156bb7fa8 100644 --- 
a/tests/models/test_instrumented.py +++ b/tests/models/test_instrumented.py @@ -26,6 +26,7 @@ SystemPromptPart, TextPart, TextPartDelta, + ThinkingPart, ToolCallPart, ToolReturnPart, UserPromptPart, @@ -150,7 +151,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'context': {'trace_id': 1, 'span_id': 1, 'is_remote': False}, 'parent': None, 'start_time': 1000000000, - 'end_time': 18000000000, + 'end_time': 16000000000, 'attributes': { 'gen_ai.operation.name': 'chat', 'gen_ai.system': 'my_system', @@ -284,7 +285,7 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'index': 0, 'message': { 'role': 'assistant', - 'content': 'text1', + 'content': [{'kind': 'text', 'text': 'text1'}, {'kind': 'text', 'text': 'text2'}], 'tool_calls': [ { 'id': 'tool_call_1', @@ -308,17 +309,6 @@ async def test_instrumented_model(capfire: CaptureLogfire): 'span_id': 1, 'trace_flags': 1, }, - { - 'body': {'index': 0, 'message': {'role': 'assistant', 'content': 'text2'}}, - 'severity_number': 9, - 'severity_text': None, - 'attributes': {'gen_ai.system': 'my_system', 'event.name': 'gen_ai.choice'}, - 'timestamp': 16000000000, - 'observed_timestamp': 17000000000, - 'trace_id': 1, - 'span_id': 1, - 'trace_flags': 1, - }, ] ) @@ -641,11 +631,13 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): 'gen_ai.system': 'my_system', }, { - 'event.name': 'gen_ai.choice', 'index': 0, 'message': { 'role': 'assistant', - 'content': 'text1', + 'content': [ + {'kind': 'text', 'text': 'text1'}, + {'kind': 'text', 'text': 'text2'}, + ], 'tool_calls': [ { 'id': 'tool_call_1', @@ -660,12 +652,7 @@ async def test_instrumented_model_attributes_mode(capfire: CaptureLogfire): ], }, 'gen_ai.system': 'my_system', - }, - { 'event.name': 'gen_ai.choice', - 'index': 0, - 'message': {'role': 'assistant', 'content': 'text2'}, - 'gen_ai.system': 'my_system', }, ] ) @@ -879,6 +866,7 @@ def test_messages_without_content(document_content: BinaryContent): }, { 'role': 'assistant', + 'content': [{'kind': 'text'}], 'gen_ai.message.index': 1, 'event.name': 'gen_ai.assistant.message', }, @@ -897,6 +885,7 @@ def test_messages_without_content(document_content: BinaryContent): }, { 'role': 'assistant', + 'content': [{'kind': 'text'}], 'tool_calls': [ { 'id': IsStr(), @@ -935,3 +924,38 @@ def test_messages_without_content(document_content: BinaryContent): }, ] ) + + +def test_message_with_thinking_parts(): + messages: list[ModelMessage] = [ + ModelResponse(parts=[TextPart('text1'), ThinkingPart('thinking1'), TextPart('text2')]), + ModelResponse(parts=[ThinkingPart('thinking2')]), + ModelResponse(parts=[ThinkingPart('thinking3'), TextPart('text3')]), + ] + settings = InstrumentationSettings() + assert [InstrumentedModel.event_to_dict(e) for e in settings.messages_to_otel_events(messages)] == snapshot( + [ + { + 'role': 'assistant', + 'content': [ + {'kind': 'text', 'text': 'text1'}, + {'kind': 'thinking', 'text': 'thinking1'}, + {'kind': 'text', 'text': 'text2'}, + ], + 'gen_ai.message.index': 0, + 'event.name': 'gen_ai.assistant.message', + }, + { + 'role': 'assistant', + 'content': [{'kind': 'thinking', 'text': 'thinking2'}], + 'gen_ai.message.index': 1, + 'event.name': 'gen_ai.assistant.message', + }, + { + 'role': 'assistant', + 'content': [{'kind': 'thinking', 'text': 'thinking3'}, {'kind': 'text', 'text': 'text3'}], + 'gen_ai.message.index': 2, + 'event.name': 'gen_ai.assistant.message', + }, + ] + ) From 6207ac6f0ef06b8ce4685173aa441c5298471829 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Fri, 25 Jul 
2025 13:35:50 -0600 Subject: [PATCH 76/89] Add Claude Code action so we can tag @claude on issues and PRs (#2315) --- .github/workflows/claude.yml | 36 ++++++++++++++++++++++++++++++++++++ 1 file changed, 36 insertions(+) create mode 100644 .github/workflows/claude.yml diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml new file mode 100644 index 0000000000..d4a716b7ff --- /dev/null +++ b/.github/workflows/claude.yml @@ -0,0 +1,36 @@ +name: Claude PR Assistant + +on: + issue_comment: + types: [created] + pull_request_review_comment: + types: [created] + issues: + types: [opened, assigned] + pull_request_review: + types: [submitted] + +jobs: + claude-code-action: + if: | + (github.event_name == 'issue_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review_comment' && contains(github.event.comment.body, '@claude')) || + (github.event_name == 'pull_request_review' && contains(github.event.review.body, '@claude')) || + (github.event_name == 'issues' && contains(github.event.issue.body, '@claude')) + runs-on: ubuntu-latest + permissions: + contents: read + pull-requests: read + issues: read + id-token: write + steps: + - name: Checkout repository + uses: actions/checkout@v4 + with: + fetch-depth: 1 + + - name: Run Claude PR Action + uses: anthropics/claude-code-action@beta + with: + anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} + timeout_minutes: "60" From 806d56db8f52d25ac8271a448c24591b5a57579d Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Niko=20D=C3=BCrr?= <48563891+tradeqvest@users.noreply.github.com> Date: Mon, 28 Jul 2025 15:21:21 +0200 Subject: [PATCH 77/89] Fix: TypeError in MCPServerSSE due to improper initialization (#2319) --- pydantic_ai_slim/pydantic_ai/mcp.py | 1 + 1 file changed, 1 insertion(+) diff --git a/pydantic_ai_slim/pydantic_ai/mcp.py b/pydantic_ai_slim/pydantic_ai/mcp.py index 77d53f0800..eb3e13ff5f 100644 --- a/pydantic_ai_slim/pydantic_ai/mcp.py +++ b/pydantic_ai_slim/pydantic_ai/mcp.py @@ -543,6 +543,7 @@ def __init__( self.max_retries = max_retries self.sampling_model = sampling_model self.read_timeout = read_timeout + self.__post_init__() @property @abstractmethod From 33aaef1e17078636078919aa4e360738de7b2abf Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 28 Jul 2025 08:30:42 -0600 Subject: [PATCH 78/89] Pin tokenizers <= 0.21.2 as later 0.21.4 doesn't have required wheels (#2335) --- pydantic_ai_slim/pyproject.toml | 7 ++++++- uv.lock | 2 ++ 2 files changed, 8 insertions(+), 1 deletion(-) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 7407e14e2a..45df89a2c6 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -63,7 +63,12 @@ dependencies = [ logfire = ["logfire>=3.11.0"] # Models openai = ["openai>=1.92.0"] -cohere = ["cohere>=5.16.0; platform_system != 'Emscripten'"] +cohere = [ + "cohere>=5.16.0; platform_system != 'Emscripten'", + # Remove once all wheels for 0.21.4+ are built successfully + # https://github.com/huggingface/tokenizers/actions/runs/16570140346/job/46860152621 + "tokenizers<=0.21.2", +] vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"] google = ["google-genai>=1.24.0"] anthropic = ["anthropic>=0.52.0"] diff --git a/uv.lock b/uv.lock index 73e1ef90c1..b472a153ba 100644 --- a/uv.lock +++ b/uv.lock @@ -3153,6 +3153,7 @@ cli = [ ] cohere = [ { name = "cohere", marker = "sys_platform != 'emscripten'" }, + { name = "tokenizers" }, ] duckduckgo = [ { name = "ddgs" }, @@ -3244,6 +3245,7 @@ 
requires-dist = [ { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, { name = "tenacity", marker = "extra == 'retries'", specifier = ">=8.2.3" }, + { name = "tokenizers", marker = "extra == 'cohere'", specifier = "<=0.21.2" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "retries", "tavily", "vertexai"] From 86d70b54586792ff66cf8674a0bffcc6a1dc5530 Mon Sep 17 00:00:00 2001 From: Charlie Jonas Date: Mon, 28 Jul 2025 08:40:52 -0600 Subject: [PATCH 79/89] Fix: AG-UI assistant text and tool call order (#2328) Co-authored-by: Douwe Maan --- pydantic_ai_slim/pydantic_ai/ag_ui.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 2dbe6faf3a..416fe627ed 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -486,6 +486,9 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]: if isinstance(msg, UserMessage): result.append(ModelRequest(parts=[UserPromptPart(content=msg.content)])) elif isinstance(msg, AssistantMessage): + if msg.content: + result.append(ModelResponse(parts=[TextPart(content=msg.content)])) + if msg.tool_calls: for tool_call in msg.tool_calls: tool_calls[tool_call.id] = tool_call.function.name @@ -502,9 +505,6 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]: ] ) ) - - if msg.content: - result.append(ModelResponse(parts=[TextPart(content=msg.content)])) elif isinstance(msg, SystemMessage): result.append(ModelRequest(parts=[SystemPromptPart(content=msg.content)])) elif isinstance(msg, ToolMessage): From 004d63bd3cb05b846c6fbe447176bbff8709f7b4 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 28 Jul 2025 11:37:45 -0600 Subject: [PATCH 80/89] Revert "Pin tokenizers <= 0.21.2 as later 0.21.4 doesn't have required wheels" (#2340) --- pydantic_ai_slim/pyproject.toml | 7 +------ uv.lock | 2 -- 2 files changed, 1 insertion(+), 8 deletions(-) diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 45df89a2c6..7407e14e2a 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -63,12 +63,7 @@ dependencies = [ logfire = ["logfire>=3.11.0"] # Models openai = ["openai>=1.92.0"] -cohere = [ - "cohere>=5.16.0; platform_system != 'Emscripten'", - # Remove once all wheels for 0.21.4+ are built successfully - # https://github.com/huggingface/tokenizers/actions/runs/16570140346/job/46860152621 - "tokenizers<=0.21.2", -] +cohere = ["cohere>=5.16.0; platform_system != 'Emscripten'"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"] google = ["google-genai>=1.24.0"] anthropic = ["anthropic>=0.52.0"] diff --git a/uv.lock b/uv.lock index b472a153ba..73e1ef90c1 100644 --- a/uv.lock +++ b/uv.lock @@ -3153,7 +3153,6 @@ cli = [ ] cohere = [ { name = "cohere", marker = "sys_platform != 'emscripten'" }, - { name = "tokenizers" }, ] duckduckgo = [ { name = "ddgs" }, @@ -3245,7 +3244,6 @@ requires-dist = [ { name = "starlette", marker = "extra == 'ag-ui'", specifier = ">=0.45.3" }, { name = "tavily-python", marker = "extra == 'tavily'", specifier = ">=0.5.0" }, { name = "tenacity", marker = "extra == 'retries'", specifier = ">=8.2.3" }, - { name = "tokenizers", marker = 
"extra == 'cohere'", specifier = "<=0.21.2" }, { name = "typing-inspection", specifier = ">=0.4.0" }, ] provides-extras = ["a2a", "ag-ui", "anthropic", "bedrock", "cli", "cohere", "duckduckgo", "evals", "google", "groq", "huggingface", "logfire", "mcp", "mistral", "openai", "retries", "tavily", "vertexai"] From 0260a31b729301f220f019e4084309e35fa1b471 Mon Sep 17 00:00:00 2001 From: William Easton Date: Mon, 28 Jul 2025 14:14:43 -0500 Subject: [PATCH 81/89] Allow `default` in tool schema with Gemini (#2309) --- .../pydantic_ai/profiles/google.py | 1 - tests/models/test_gemini.py | 32 ++++++++++++++----- 2 files changed, 24 insertions(+), 9 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/profiles/google.py b/pydantic_ai_slim/pydantic_ai/profiles/google.py index 9178d7dd43..3859d76347 100644 --- a/pydantic_ai_slim/pydantic_ai/profiles/google.py +++ b/pydantic_ai_slim/pydantic_ai/profiles/google.py @@ -49,7 +49,6 @@ def transform(self, schema: JsonSchema) -> JsonSchema: ) schema.pop('title', None) - schema.pop('default', None) schema.pop('$schema', None) if (const := schema.pop('const', None)) is not None: # Gemini doesn't support const, but it does support enum with a single value diff --git a/tests/models/test_gemini.py b/tests/models/test_gemini.py index 0b95bedeb5..7d628df31a 100644 --- a/tests/models/test_gemini.py +++ b/tests/models/test_gemini.py @@ -194,7 +194,7 @@ async def test_require_response_tool(allow_model_requests: None): async def test_json_def_replaced(allow_model_requests: None): class Axis(BaseModel): - label: str + label: str = Field(default='', description='The label of the axis') class Chart(BaseModel): x_axis: Axis @@ -213,8 +213,14 @@ class Locations(BaseModel): { '$defs': { 'Axis': { - 'properties': {'label': {'title': 'Label', 'type': 'string'}}, - 'required': ['label'], + 'properties': { + 'label': { + 'default': '', + 'description': 'The label of the axis', + 'title': 'Label', + 'type': 'string', + } + }, 'title': 'Axis', 'type': 'object', }, @@ -268,17 +274,27 @@ class Locations(BaseModel): 'items': { 'properties': { 'lat': {'type': 'number'}, - 'lng': {'type': 'number'}, + 'lng': {'default': 1.1, 'type': 'number'}, 'chart': { 'properties': { 'x_axis': { - 'properties': {'label': {'type': 'string'}}, - 'required': ['label'], + 'properties': { + 'label': { + 'default': '', + 'description': 'The label of the axis', + 'type': 'string', + } + }, 'type': 'object', }, 'y_axis': { - 'properties': {'label': {'type': 'string'}}, - 'required': ['label'], + 'properties': { + 'label': { + 'default': '', + 'description': 'The label of the axis', + 'type': 'string', + } + }, 'type': 'object', }, }, From 168680aed867e593d653d9c694b9015a37b3cc0d Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 28 Jul 2025 17:03:00 -0600 Subject: [PATCH 82/89] Fix AgentStream.stream_output and StreamedRunResult.stream_structured with output tools (#2314) Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: Douwe Maan --- pydantic_ai_slim/pydantic_ai/_tool_manager.py | 56 +++++++++++-------- pydantic_ai_slim/pydantic_ai/result.py | 8 ++- tests/test_streaming.py | 36 ++++++++++++ 3 files changed, 74 insertions(+), 26 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index bea4103896..657d32bde1 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -54,20 +54,25 @@ def get_tool_def(self, name: str) -> ToolDefinition | 
None: except KeyError: return None - async def handle_call(self, call: ToolCallPart, allow_partial: bool = False) -> Any: + async def handle_call( + self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True + ) -> Any: """Handle a tool call by validating the arguments, calling the tool, and handling retries. Args: call: The tool call part to handle. allow_partial: Whether to allow partial validation of the tool arguments. + wrap_validation_errors: Whether to wrap validation errors in a retry prompt part. """ if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output': # Output tool calls are not traced - return await self._call_tool(call, allow_partial) + return await self._call_tool(call, allow_partial, wrap_validation_errors) else: - return await self._call_tool_traced(call, allow_partial) + return await self._call_tool_traced(call, allow_partial, wrap_validation_errors) - async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> Any: + async def _call_tool( + self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True + ) -> Any: name = call.tool_name tool = self.tools.get(name) try: @@ -100,30 +105,35 @@ async def _call_tool(self, call: ToolCallPart, allow_partial: bool = False) -> A if current_retry == max_retries: raise UnexpectedModelBehavior(f'Tool {name!r} exceeded max retries count of {max_retries}') from e else: - if isinstance(e, ValidationError): - m = _messages.RetryPromptPart( - tool_name=name, - content=e.errors(include_url=False, include_context=False), - tool_call_id=call.tool_call_id, - ) - e = ToolRetryError(m) - elif isinstance(e, ModelRetry): - m = _messages.RetryPromptPart( - tool_name=name, - content=e.message, - tool_call_id=call.tool_call_id, - ) - e = ToolRetryError(m) - else: - assert_never(e) + if wrap_validation_errors: + if isinstance(e, ValidationError): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.errors(include_url=False, include_context=False), + tool_call_id=call.tool_call_id, + ) + e = ToolRetryError(m) + elif isinstance(e, ModelRetry): + m = _messages.RetryPromptPart( + tool_name=name, + content=e.message, + tool_call_id=call.tool_call_id, + ) + e = ToolRetryError(m) + else: + assert_never(e) + + if not allow_partial: + self.ctx.retries[name] = current_retry + 1 - self.ctx.retries[name] = current_retry + 1 raise e else: self.ctx.retries.pop(name, None) return output - async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = False) -> Any: + async def _call_tool_traced( + self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True + ) -> Any: """See .""" span_attributes = { 'gen_ai.tool.name': call.tool_name, @@ -152,7 +162,7 @@ async def _call_tool_traced(self, call: ToolCallPart, allow_partial: bool = Fals } with self.ctx.tracer.start_as_current_span('running tool', attributes=span_attributes) as span: try: - tool_result = await self._call_tool(call, allow_partial) + tool_result = await self._call_tool(call, allow_partial, wrap_validation_errors) except ToolRetryError as e: part = e.tool_retry if self.ctx.trace_include_content and span.is_recording(): diff --git a/pydantic_ai_slim/pydantic_ai/result.py b/pydantic_ai_slim/pydantic_ai/result.py index e640302b24..2dc3eb8259 100644 --- a/pydantic_ai_slim/pydantic_ai/result.py +++ b/pydantic_ai_slim/pydantic_ai/result.py @@ -67,7 +67,7 @@ async def stream_output(self, *, debounce_by: float | None = 0.1) -> AsyncIterat except 
ValidationError: pass if self._final_result_event is not None: # pragma: no branch - yield await self._validate_response(self._raw_stream_response.get(), allow_partial=False) + yield await self._validate_response(self._raw_stream_response.get()) async def stream_responses(self, *, debounce_by: float | None = 0.1) -> AsyncIterator[_messages.ModelResponse]: """Asynchronously stream the (unvalidated) model responses for the agent.""" @@ -128,7 +128,7 @@ async def get_output(self) -> OutputDataT: async for _ in self: pass - return await self._validate_response(self._raw_stream_response.get(), allow_partial=False) + return await self._validate_response(self._raw_stream_response.get()) async def _validate_response(self, message: _messages.ModelResponse, *, allow_partial: bool = False) -> OutputDataT: """Validate a structured result message.""" @@ -150,7 +150,9 @@ async def _validate_response(self, message: _messages.ModelResponse, *, allow_pa raise exceptions.UnexpectedModelBehavior( # pragma: no cover f'Invalid response, unable to find tool call for {output_tool_name!r}' ) - return await self._tool_manager.handle_call(tool_call, allow_partial=allow_partial) + return await self._tool_manager.handle_call( + tool_call, allow_partial=allow_partial, wrap_validation_errors=False + ) elif deferred_tool_calls := self._tool_manager.get_deferred_tool_calls(message.parts): if not self._output_schema.allows_deferred_tool_calls: raise exceptions.UserError( diff --git a/tests/test_streaming.py b/tests/test_streaming.py index dbdcd71f32..e8861a0e01 100644 --- a/tests/test_streaming.py +++ b/tests/test_streaming.py @@ -1108,6 +1108,42 @@ class CityLocation(BaseModel): ) +async def test_iter_stream_output_tool_dont_hit_retry_limit(): + class CityLocation(BaseModel): + city: str + country: str | None = None + + async def text_stream(_messages: list[ModelMessage], agent_info: AgentInfo) -> AsyncIterator[DeltaToolCalls]: + """Stream partial JSON data that will initially fail validation.""" + assert agent_info.output_tools is not None + assert len(agent_info.output_tools) == 1 + name = agent_info.output_tools[0].name + + yield {0: DeltaToolCall(name=name)} + yield {0: DeltaToolCall(json_args='{"c')} + yield {0: DeltaToolCall(json_args='ity":')} + yield {0: DeltaToolCall(json_args=' "Mex')} + yield {0: DeltaToolCall(json_args='ico City",')} + yield {0: DeltaToolCall(json_args=' "cou')} + yield {0: DeltaToolCall(json_args='ntry": "Mexico"}')} + + agent = Agent(FunctionModel(stream_function=text_stream), output_type=CityLocation) + + async with agent.iter('Generate city info') as run: + async for node in run: + if agent.is_model_request_node(node): + async with node.stream(run.ctx) as stream: + assert [c async for c in stream.stream_output(debounce_by=None)] == snapshot( + [ + CityLocation(city='Mex'), + CityLocation(city='Mexico City'), + CityLocation(city='Mexico City'), + CityLocation(city='Mexico City', country='Mexico'), + CityLocation(city='Mexico City', country='Mexico'), + ] + ) + + def test_function_tool_event_tool_call_id_properties(): """Ensure that the `tool_call_id` property on function tool events mirrors the underlying part's ID.""" # Prepare a ToolCallPart with a fixed ID From 2fca5061ed9e258462238f278ceab51cd1c5a2f8 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 28 Jul 2025 17:05:02 -0600 Subject: [PATCH 83/89] Ensure AG-UI state is isolated between requests. 
(#2343) --- docs/ag-ui.md | 13 ++- pydantic_ai_slim/pydantic_ai/ag_ui.py | 73 +++++++-------- tests/test_ag_ui.py | 122 +++++++++++++++++++++++--- 3 files changed, 152 insertions(+), 56 deletions(-) diff --git a/docs/ag-ui.md b/docs/ag-ui.md index 2ff90ae8f5..45bc27af87 100644 --- a/docs/ag-ui.md +++ b/docs/ag-ui.md @@ -82,10 +82,15 @@ The adapter provides full support for real-time synchronization between agents and frontend applications. In the example below we have document state which is shared between the UI and -server using the [`StateDeps`][pydantic_ai.ag_ui.StateDeps] which implements the -[`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol that can be used to automatically -decode state contained in [`RunAgentInput.state`](https://docs.ag-ui.com/sdk/js/core/types#runagentinput) -when processing requests. +server using the [`StateDeps`][pydantic_ai.ag_ui.StateDeps] [dependencies type](./dependencies.md) that can be used to automatically +validate state contained in [`RunAgentInput.state`](https://docs.ag-ui.com/sdk/js/core/types#runagentinput) using a Pydantic `BaseModel` specified as a generic parameter. + +!!! note "Custom dependencies type with AG-UI state" + If you want to use your own dependencies type to hold AG-UI state as well as other things, it needs to implements the + [`StateHandler`][pydantic_ai.ag_ui.StateHandler] protocol, meaning it needs to be a [dataclass](https://docs.python.org/3/library/dataclasses.html) with a non-optional `state` field. This lets Pydantic AI ensure that state is properly isolated between requests by building a new dependencies object each time. + + If the `state` field's type is a Pydantic `BaseModel` subclass, the raw state dictionary on the request is automatically validated. If not, you can validate the raw value yourself in your dependencies dataclass's `__post_init__` method. + ```python {title="ag_ui_state.py" py="3.10"} from pydantic import BaseModel diff --git a/pydantic_ai_slim/pydantic_ai/ag_ui.py b/pydantic_ai_slim/pydantic_ai/ag_ui.py index 416fe627ed..447a4ba60d 100644 --- a/pydantic_ai_slim/pydantic_ai/ag_ui.py +++ b/pydantic_ai_slim/pydantic_ai/ag_ui.py @@ -9,11 +9,13 @@ import json import uuid from collections.abc import Iterable, Mapping, Sequence -from dataclasses import dataclass, field +from dataclasses import Field, dataclass, field, replace from http import HTTPStatus from typing import ( + TYPE_CHECKING, Any, Callable, + ClassVar, Final, Generic, Protocol, @@ -21,6 +23,11 @@ runtime_checkable, ) +from pydantic_ai.exceptions import UserError + +if TYPE_CHECKING: + pass + try: from ag_ui.core import ( AssistantMessage, @@ -288,8 +295,24 @@ async def run( if not run_input.messages: raise _NoMessagesError + raw_state: dict[str, Any] = run_input.state or {} if isinstance(deps, StateHandler): - deps.state = run_input.state + if isinstance(deps.state, BaseModel): + try: + state = type(deps.state).model_validate(raw_state) + except ValidationError as e: # pragma: no cover + raise _InvalidStateError from e + else: + state = raw_state + + deps = replace(deps, state=state) + elif raw_state: + raise UserError( + f'AG-UI state is provided but `deps` of type `{type(deps).__name__}` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.' + ) + else: + # `deps` not being a `StateHandler` is OK if there is no state. 
+ pass messages = _messages_from_ag_ui(run_input.messages) @@ -311,7 +334,7 @@ async def run( yield encoder.encode( RunErrorEvent(message=e.message, code=e.code), ) - except Exception as e: # pragma: no cover + except Exception as e: yield encoder.encode( RunErrorEvent(message=str(e)), ) @@ -531,7 +554,11 @@ def _messages_from_ag_ui(messages: list[Message]) -> list[ModelMessage]: @runtime_checkable class StateHandler(Protocol): - """Protocol for state handlers in agent runs.""" + """Protocol for state handlers in agent runs. Requires the class to be a dataclass with a `state` field.""" + + # Has to be a dataclass so we can use `replace` to update the state. + # From https://github.com/python/typeshed/blob/9ab7fde0a0cd24ed7a72837fcb21093b811b80d8/stdlib/_typeshed/__init__.pyi#L352 + __dataclass_fields__: ClassVar[dict[str, Field[Any]]] @property def state(self) -> State: @@ -558,6 +585,7 @@ def state(self, state: State) -> None: """Type variable for the state type, which must be a subclass of `BaseModel`.""" +@dataclass class StateDeps(Generic[StateT]): """Provides AG-UI state management. @@ -570,42 +598,7 @@ class StateDeps(Generic[StateT]): Implements the `StateHandler` protocol. """ - def __init__(self, default: StateT) -> None: - """Initialize the state with the provided state type.""" - self._state = default - - @property - def state(self) -> StateT: - """Get the current state of the agent run. - - Returns: - The current run state. - """ - return self._state - - @state.setter - def state(self, state: State) -> None: - """Set the state of the agent run. - - This method is called to update the state of the agent run with the - provided state. - - Implements the `StateHandler` protocol. - - Args: - state: The run state, which must be `None` or model validate for the state type. - - Raises: - InvalidStateError: If `state` does not validate. - """ - if state is None: - # If state is None, we keep the current state, which will be the default state. 
- return - - try: - self._state = type(self._state).model_validate(state) - except ValidationError as e: # pragma: no cover - raise _InvalidStateError from e + state: StateT @dataclass(repr=False) diff --git a/tests/test_ag_ui.py b/tests/test_ag_ui.py index 423011b76d..0da58aa218 100644 --- a/tests/test_ag_ui.py +++ b/tests/test_ag_ui.py @@ -7,6 +7,7 @@ import json import uuid from collections.abc import AsyncIterator +from dataclasses import dataclass from http import HTTPStatus from typing import Any @@ -17,7 +18,9 @@ from inline_snapshot import snapshot from pydantic import BaseModel +from pydantic_ai._run_context import RunContext from pydantic_ai.agent import Agent +from pydantic_ai.exceptions import UserError from pydantic_ai.messages import ModelMessage from pydantic_ai.models.function import ( AgentInfo, @@ -27,8 +30,9 @@ DeltaToolCalls, FunctionModel, ) +from pydantic_ai.models.test import TestModel from pydantic_ai.output import OutputDataT -from pydantic_ai.tools import AgentDepsT +from pydantic_ai.tools import AgentDepsT, ToolDefinition from .conftest import IsSameStr @@ -180,7 +184,7 @@ def create_input( thread_id=thread_id, run_id=uuid_str(), messages=list(messages), - state=state, + state=dict(state) if state else {}, context=[], tools=tools or [], forwarded_props=None, @@ -1050,9 +1054,19 @@ async def stream_function( async def test_request_with_state() -> None: """Test request with state modification.""" + seen_states: list[int] = [] + + async def store_state( + ctx: RunContext[StateDeps[StateInt]], tool_defs: list[ToolDefinition] + ) -> list[ToolDefinition]: + seen_states.append(ctx.deps.state.value) + ctx.deps.state.value += 1 + return tool_defs + agent: Agent[StateDeps[StateInt], str] = Agent( model=FunctionModel(stream_function=simple_stream), deps_type=StateDeps[StateInt], # type: ignore[reportUnknownArgumentType] + prepare_tools=store_state, ) adapter = _Adapter(agent=agent) run_inputs = [ @@ -1074,32 +1088,101 @@ async def test_request_with_state() -> None: id='msg_3', content='Hello, how are you?', ), + ), + create_input( + UserMessage( + id='msg_4', + content='Hello, how are you?', + ), state=StateInt(value=42), ), ] - deps = StateDeps(StateInt()) + deps = StateDeps(StateInt(value=0)) - last_value = deps.state.value for run_input in run_inputs: events = list[dict[str, Any]]() async for event in adapter.run(run_input, deps=deps): events.append(json.loads(event.removeprefix('data: '))) assert events == simple_result() - assert deps.state.value == run_input.state.value if run_input.state is not None else last_value - last_value = deps.state.value + assert seen_states == snapshot( + [ + 41, # run msg_1, prepare_tools call 1 + 42, # run msg_1, prepare_tools call 2 + 0, # run msg_2, prepare_tools call 1 + 1, # run msg_2, prepare_tools call 2 + 0, # run msg_3, prepare_tools call 1 + 1, # run msg_3, prepare_tools call 2 + 42, # run msg_4, prepare_tools call 1 + 43, # run msg_4, prepare_tools call 2 + ] + ) + + +async def test_request_with_state_without_handler() -> None: + agent = Agent(model=FunctionModel(stream_function=simple_stream)) + adapter = _Adapter(agent=agent) + run_input = create_input( + UserMessage( + id='msg_1', + content='Hello, how are you?', + ), + state=StateInt(value=41), + ) + + with pytest.raises( + UserError, + match='AG-UI state is provided but `deps` of type `NoneType` does not implement the `StateHandler` protocol: it needs to be a dataclass with a non-optional `state` field.', + ): + async for _ in adapter.run(run_input): + pass + + +async def 
test_request_with_state_with_custom_handler() -> None: + @dataclass + class CustomStateDeps: + state: dict[str, Any] + + seen_states: list[dict[str, Any]] = [] + + async def store_state(ctx: RunContext[CustomStateDeps], tool_defs: list[ToolDefinition]) -> list[ToolDefinition]: + seen_states.append(ctx.deps.state) + return tool_defs + + agent: Agent[CustomStateDeps, str] = Agent( + model=FunctionModel(stream_function=simple_stream), + deps_type=CustomStateDeps, + prepare_tools=store_state, + ) + adapter = _Adapter(agent=agent) + run_input = create_input( + UserMessage( + id='msg_1', + content='Hello, how are you?', + ), + state={'value': 42}, + ) + + async for _ in adapter.run(run_input, deps=CustomStateDeps(state={'value': 0})): + pass - assert deps.state.value == 42 + assert seen_states[-1] == {'value': 42} async def test_concurrent_runs() -> None: """Test concurrent execution of multiple runs.""" import asyncio - agent = Agent( - model=FunctionModel(stream_function=simple_stream), + agent: Agent[StateDeps[StateInt], str] = Agent( + model=TestModel(), + deps_type=StateDeps[StateInt], # type: ignore[reportUnknownArgumentType] ) + + @agent.tool + async def get_state(ctx: RunContext[StateDeps[StateInt]]) -> int: + return ctx.deps.state.value + adapter = _Adapter(agent=agent) concurrent_tasks: list[asyncio.Task[list[dict[str, Any]]]] = [] @@ -1109,10 +1192,11 @@ async def test_concurrent_runs() -> None: id=f'msg_{i}', content=f'Message {i}', ), + state=StateInt(value=i), thread_id=f'test_thread_{i}', ) - task = asyncio.create_task(collect_events_from_adapter(adapter, run_input)) + task = asyncio.create_task(collect_events_from_adapter(adapter, run_input, deps=StateDeps(StateInt()))) concurrent_tasks.append(task) results = await asyncio.gather(*concurrent_tasks) @@ -1121,9 +1205,23 @@ async def test_concurrent_runs() -> None: for i, events in enumerate(results): assert events == [ {'type': 'RUN_STARTED', 'threadId': f'test_thread_{i}', 'runId': (run_id := IsSameStr())}, + { + 'type': 'TOOL_CALL_START', + 'toolCallId': (tool_call_id := IsSameStr()), + 'toolCallName': 'get_state', + 'parentMessageId': IsStr(), + }, + {'type': 'TOOL_CALL_END', 'toolCallId': tool_call_id}, + { + 'type': 'TOOL_CALL_RESULT', + 'messageId': IsStr(), + 'toolCallId': tool_call_id, + 'content': str(i), + 'role': 'tool', + }, {'type': 'TEXT_MESSAGE_START', 'messageId': (message_id := IsSameStr()), 'role': 'assistant'}, - {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'success '}, - {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': '(no tool calls)'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': '{"get_s'}, + {'type': 'TEXT_MESSAGE_CONTENT', 'messageId': message_id, 'delta': 'tate":' + str(i) + '}'}, {'type': 'TEXT_MESSAGE_END', 'messageId': message_id}, {'type': 'RUN_FINISHED', 'threadId': f'test_thread_{i}', 'runId': run_id}, ] From ab92e67856b1ec774012d8d8c74ebe40ba2ac5f5 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Mon, 28 Jul 2025 17:23:30 -0600 Subject: [PATCH 84/89] Refine retry logic for parallel tool calling (#2317) Co-authored-by: claude[bot] <209825114+claude[bot]@users.noreply.github.com> Co-authored-by: Douwe Maan --- pydantic_ai_slim/pydantic_ai/_tool_manager.py | 20 +-- tests/test_toolsets.py | 134 +++++++++++++++++- 2 files changed, 144 insertions(+), 10 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index 657d32bde1..612be4176a 100644 --- 
a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -2,18 +2,17 @@ import json from collections.abc import Iterable -from dataclasses import dataclass, replace +from dataclasses import dataclass, field, replace from typing import Any, Generic from pydantic import ValidationError from typing_extensions import assert_never -from pydantic_ai.output import DeferredToolCalls - from . import messages as _messages from ._run_context import AgentDepsT, RunContext from .exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior from .messages import ToolCallPart +from .output import DeferredToolCalls from .tools import ToolDefinition from .toolsets.abstract import AbstractToolset, ToolsetTool @@ -28,6 +27,8 @@ class ToolManager(Generic[AgentDepsT]): """The toolset that provides the tools for this run step.""" tools: dict[str, ToolsetTool[AgentDepsT]] """The cached tools for this run step.""" + failed_tools: set[str] = field(default_factory=set) + """Names of tools that failed in this run step.""" @classmethod async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]: @@ -40,7 +41,10 @@ async def build(cls, toolset: AbstractToolset[AgentDepsT], ctx: RunContext[Agent async def for_run_step(self, ctx: RunContext[AgentDepsT]) -> ToolManager[AgentDepsT]: """Build a new tool manager for the next run step, carrying over the retries from the current run step.""" - return await self.__class__.build(self.toolset, replace(ctx, retries=self.ctx.retries)) + retries = { + failed_tool_name: self.ctx.retries.get(failed_tool_name, 0) + 1 for failed_tool_name in self.failed_tools + } + return await self.__class__.build(self.toolset, replace(ctx, retries=retries)) @property def tool_defs(self) -> list[ToolDefinition]: @@ -97,7 +101,7 @@ async def _call_tool( else: args_dict = validator.validate_python(call.args or {}, allow_partial=pyd_allow_partial) - output = await self.toolset.call_tool(name, args_dict, ctx, tool) + return await self.toolset.call_tool(name, args_dict, ctx, tool) except (ValidationError, ModelRetry) as e: max_retries = tool.max_retries if tool is not None else 1 current_retry = self.ctx.retries.get(name, 0) @@ -124,12 +128,10 @@ async def _call_tool( assert_never(e) if not allow_partial: - self.ctx.retries[name] = current_retry + 1 + # If we're validating partial arguments, we don't want to count this as a failed tool as it may still succeed once the full arguments are received. 
+ self.failed_tools.add(name) raise e - else: - self.ctx.retries.pop(name, None) - return output async def _call_tool_traced( self, call: ToolCallPart, allow_partial: bool = False, wrap_validation_errors: bool = True diff --git a/tests/test_toolsets.py b/tests/test_toolsets.py index f188d3141a..f217b34f4e 100644 --- a/tests/test_toolsets.py +++ b/tests/test_toolsets.py @@ -1,6 +1,7 @@ from __future__ import annotations import re +from collections import defaultdict from dataclasses import dataclass, replace from typing import TypeVar from unittest.mock import AsyncMock @@ -10,7 +11,7 @@ from pydantic_ai._run_context import RunContext from pydantic_ai._tool_manager import ToolManager -from pydantic_ai.exceptions import UserError +from pydantic_ai.exceptions import ModelRetry, ToolRetryError, UnexpectedModelBehavior, UserError from pydantic_ai.messages import ToolCallPart from pydantic_ai.models.test import TestModel from pydantic_ai.tools import ToolDefinition @@ -494,3 +495,134 @@ async def test_context_manager_failed_initialization(): pass assert server1.is_running is False + + +async def test_tool_manager_retry_logic(): + """Test the retry logic with failed_tools and for_run_step method.""" + + @dataclass + class TestDeps: + pass + + # Create a toolset with tools that can fail + toolset = FunctionToolset[TestDeps](max_retries=2) + call_count: defaultdict[str, int] = defaultdict(int) + + @toolset.tool + def failing_tool(x: int) -> int: + """A tool that always fails""" + call_count['failing_tool'] += 1 + raise ModelRetry('This tool always fails') + + @toolset.tool + def other_tool(x: int) -> int: + """A tool that works""" + call_count['other_tool'] += 1 + return x * 2 + + # Create initial context and tool manager + initial_context = build_run_context(TestDeps()) + tool_manager = await ToolManager[TestDeps].build(toolset, initial_context) + + # Initially no failed tools + assert tool_manager.failed_tools == set() + assert initial_context.retries == {} + + # Call the failing tool - should add to failed_tools + with pytest.raises(ToolRetryError): + await tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1})) + + assert tool_manager.failed_tools == {'failing_tool'} + assert call_count['failing_tool'] == 1 + + # Call the working tool - should not add to failed_tools + result = await tool_manager.handle_call(ToolCallPart(tool_name='other_tool', args={'x': 3})) + assert result == 6 + assert tool_manager.failed_tools == {'failing_tool'} # unchanged + assert call_count['other_tool'] == 1 + + # Test for_run_step - should create new tool manager with updated retry counts + new_context = build_run_context(TestDeps()) + new_tool_manager = await tool_manager.for_run_step(new_context) + + # The new tool manager should have retry count for the failed tool + assert new_tool_manager.ctx.retries == {'failing_tool': 1} + assert new_tool_manager.failed_tools == set() # reset for new run step + + # Call the failing tool again in the new manager - should have retry=1 + with pytest.raises(ToolRetryError): + await new_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1})) + + # Call the failing tool another time in the new manager + with pytest.raises(ToolRetryError): + await new_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1})) + + # Call the failing tool a third time in the new manager + with pytest.raises(ToolRetryError): + await new_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1})) + + assert 
new_tool_manager.failed_tools == {'failing_tool'} + assert call_count['failing_tool'] == 4 + + # Create another run step + another_context = build_run_context(TestDeps()) + another_tool_manager = await new_tool_manager.for_run_step(another_context) + + # Should now have retry count of 2 for failing_tool + assert another_tool_manager.ctx.retries == {'failing_tool': 2} + assert another_tool_manager.failed_tools == set() + + # Call the failing tool _again_, now we should finally hit the limit + with pytest.raises(UnexpectedModelBehavior, match="Tool 'failing_tool' exceeded max retries count of 2"): + await another_tool_manager.handle_call(ToolCallPart(tool_name='failing_tool', args={'x': 1})) + + +async def test_tool_manager_multiple_failed_tools(): + """Test retry logic when multiple tools fail in the same run step.""" + + @dataclass + class TestDeps: + pass + + toolset = FunctionToolset[TestDeps]() + + @toolset.tool + def tool_a(x: int) -> int: + """Tool A that fails""" + raise ModelRetry('Tool A fails') + + @toolset.tool + def tool_b(x: int) -> int: + """Tool B that fails""" + raise ModelRetry('Tool B fails') + + @toolset.tool + def tool_c(x: int) -> int: + """Tool C that works""" + return x * 3 + + # Create tool manager + context = build_run_context(TestDeps()) + tool_manager = await ToolManager[TestDeps].build(toolset, context) + + # Call tool_a - should fail and be added to failed_tools + with pytest.raises(ToolRetryError): + await tool_manager.handle_call(ToolCallPart(tool_name='tool_a', args={'x': 1})) + assert tool_manager.failed_tools == {'tool_a'} + + # Call tool_b - should also fail and be added to failed_tools + with pytest.raises(ToolRetryError): + await tool_manager.handle_call(ToolCallPart(tool_name='tool_b', args={'x': 1})) + assert tool_manager.failed_tools == {'tool_a', 'tool_b'} + + # Call tool_c - should succeed and not be added to failed_tools + result = await tool_manager.handle_call(ToolCallPart(tool_name='tool_c', args={'x': 2})) + assert result == 6 + assert tool_manager.failed_tools == {'tool_a', 'tool_b'} # unchanged + + # Create next run step - should have retry counts for both failed tools + new_context = build_run_context(TestDeps()) + new_tool_manager = await tool_manager.for_run_step(new_context) + + assert new_tool_manager.ctx.retries == {'tool_a': 1, 'tool_b': 1} + assert new_tool_manager.failed_tools == set() # reset for new run step From 362dcbfb62fa7ac299f042e3672247a19d9b575b Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 29 Jul 2025 07:32:45 -0600 Subject: [PATCH 85/89] Set up environment for @claude, allow it make and uv, and reading CI logs (#2344) --- .github/workflows/claude.yml | 24 +++++++++++++++++++++++- 1 file changed, 23 insertions(+), 1 deletion(-) diff --git a/.github/workflows/claude.yml b/.github/workflows/claude.yml index d4a716b7ff..8a0075e883 100644 --- a/.github/workflows/claude.yml +++ b/.github/workflows/claude.yml @@ -1,4 +1,4 @@ -name: Claude PR Assistant +name: Claude Code on: issue_comment: @@ -10,6 +10,10 @@ on: pull_request_review: types: [submitted] +env: + UV_PYTHON: 3.13 + UV_FROZEN: "1" + jobs: claude-code-action: if: | @@ -23,14 +27,32 @@ jobs: pull-requests: read issues: read id-token: write + actions: read steps: - name: Checkout repository uses: actions/checkout@v4 with: fetch-depth: 1 + - uses: astral-sh/setup-uv@v5 + with: + enable-cache: true + + - uses: denoland/setup-deno@v2 + with: + deno-version: v2.x + + - run: uv tool install pre-commit + + - run: make install + - name: Run Claude PR Action uses: 
anthropics/claude-code-action@beta with: anthropic_api_key: ${{ secrets.ANTHROPIC_API_KEY }} timeout_minutes: "60" + additional_permissions: | + actions: read + allowed_tools: | + Bash(make:*) + Bash(uv:*) From 8f20f9b576b5d6f465e294d8a3a2ce9e73034f1b Mon Sep 17 00:00:00 2001 From: Mohamed Amine Zghal Date: Tue, 29 Jul 2025 15:44:07 +0100 Subject: [PATCH 86/89] Remove older deprecated models of Anthropic (#2345) --- pydantic_ai_slim/pydantic_ai/models/__init__.py | 6 ------ pydantic_ai_slim/pyproject.toml | 2 +- uv.lock | 8 ++++---- 3 files changed, 5 insertions(+), 11 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 6cdcbfbd64..67eef9eceb 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -34,8 +34,6 @@ KnownModelName = TypeAliasType( 'KnownModelName', Literal[ - 'anthropic:claude-2.0', - 'anthropic:claude-2.1', 'anthropic:claude-3-5-haiku-20241022', 'anthropic:claude-3-5-haiku-latest', 'anthropic:claude-3-5-sonnet-20240620', @@ -46,7 +44,6 @@ 'anthropic:claude-3-haiku-20240307', 'anthropic:claude-3-opus-20240229', 'anthropic:claude-3-opus-latest', - 'anthropic:claude-3-sonnet-20240229', 'anthropic:claude-4-opus-20250514', 'anthropic:claude-4-sonnet-20250514', 'anthropic:claude-opus-4-0', @@ -100,8 +97,6 @@ 'bedrock:mistral.mixtral-8x7b-instruct-v0:1', 'bedrock:mistral.mistral-large-2402-v1:0', 'bedrock:mistral.mistral-large-2407-v1:0', - 'claude-2.0', - 'claude-2.1', 'claude-3-5-haiku-20241022', 'claude-3-5-haiku-latest', 'claude-3-5-sonnet-20240620', @@ -112,7 +107,6 @@ 'claude-3-haiku-20240307', 'claude-3-opus-20240229', 'claude-3-opus-latest', - 'claude-3-sonnet-20240229', 'claude-4-opus-20250514', 'claude-4-sonnet-20250514', 'claude-opus-4-0', diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 7407e14e2a..1dfbf7ec80 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -66,7 +66,7 @@ openai = ["openai>=1.92.0"] cohere = ["cohere>=5.16.0; platform_system != 'Emscripten'"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"] google = ["google-genai>=1.24.0"] -anthropic = ["anthropic>=0.52.0"] +anthropic = ["anthropic>=0.59.0"] groq = ["groq>=0.19.0"] mistral = ["mistralai>=1.9.2"] bedrock = ["boto3>=1.37.24"] diff --git a/uv.lock b/uv.lock index 73e1ef90c1..0bf72671c1 100644 --- a/uv.lock +++ b/uv.lock @@ -196,7 +196,7 @@ wheels = [ [[package]] name = "anthropic" -version = "0.52.0" +version = "0.60.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -207,9 +207,9 @@ dependencies = [ { name = "sniffio" }, { name = "typing-extensions" }, ] -sdist = { url = "https://files.pythonhosted.org/packages/57/fd/8a9332f5baf352c272494a9d359863a53385a208954c1a7251a524071930/anthropic-0.52.0.tar.gz", hash = "sha256:f06bc924d7eb85f8a43fe587b875ff58b410d60251b7dc5f1387b322a35bd67b", size = 229372, upload-time = "2025-05-22T16:42:22.044Z" } +sdist = { url = "https://files.pythonhosted.org/packages/4e/03/3334921dc54ed822b3dd993ae72d823a7402588521bbba3e024b3333a1fd/anthropic-0.60.0.tar.gz", hash = "sha256:a22ba187c6f4fd5afecb2fc913b960feccf72bc0d25c1b7ce0345e87caede577", size = 425983, upload-time = "2025-07-28T19:53:47.685Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/a0/43/172c0031654908bbac2a87d356fff4de1b4947a9b14b9658540b69416417/anthropic-0.52.0-py3-none-any.whl", hash = 
"sha256:c026daa164f0e3bde36ce9cbdd27f5f1419fff03306be1e138726f42e6a7810f", size = 286076, upload-time = "2025-05-22T16:42:20Z" }, + { url = "https://files.pythonhosted.org/packages/da/bb/d84f287fb1c217b30c328af987cf8bbe3897edf0518dcc5fa39412f794ec/anthropic-0.60.0-py3-none-any.whl", hash = "sha256:65ad1f088a960217aaf82ba91ff743d6c89e9d811c6d64275b9a7c59ee9ac3c6", size = 293116, upload-time = "2025-07-28T19:53:45.944Z" }, ] [[package]] @@ -3216,7 +3216,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "ag-ui-protocol", marker = "extra == 'ag-ui'", specifier = ">=0.1.8" }, - { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.52.0" }, + { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.59.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.37.24" }, { name = "cohere", marker = "sys_platform != 'emscripten' and extra == 'cohere'", specifier = ">=5.16.0" }, From c7a3591dfb6d43678dbdb2e283d39975a81f6b75 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 29 Jul 2025 10:43:09 -0600 Subject: [PATCH 87/89] Revert "Remove older deprecated models of Anthropic" (#2358) --- pydantic_ai_slim/pydantic_ai/models/__init__.py | 6 ++++++ pydantic_ai_slim/pyproject.toml | 2 +- uv.lock | 8 ++++---- 3 files changed, 11 insertions(+), 5 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/models/__init__.py b/pydantic_ai_slim/pydantic_ai/models/__init__.py index 67eef9eceb..6cdcbfbd64 100644 --- a/pydantic_ai_slim/pydantic_ai/models/__init__.py +++ b/pydantic_ai_slim/pydantic_ai/models/__init__.py @@ -34,6 +34,8 @@ KnownModelName = TypeAliasType( 'KnownModelName', Literal[ + 'anthropic:claude-2.0', + 'anthropic:claude-2.1', 'anthropic:claude-3-5-haiku-20241022', 'anthropic:claude-3-5-haiku-latest', 'anthropic:claude-3-5-sonnet-20240620', @@ -44,6 +46,7 @@ 'anthropic:claude-3-haiku-20240307', 'anthropic:claude-3-opus-20240229', 'anthropic:claude-3-opus-latest', + 'anthropic:claude-3-sonnet-20240229', 'anthropic:claude-4-opus-20250514', 'anthropic:claude-4-sonnet-20250514', 'anthropic:claude-opus-4-0', @@ -97,6 +100,8 @@ 'bedrock:mistral.mixtral-8x7b-instruct-v0:1', 'bedrock:mistral.mistral-large-2402-v1:0', 'bedrock:mistral.mistral-large-2407-v1:0', + 'claude-2.0', + 'claude-2.1', 'claude-3-5-haiku-20241022', 'claude-3-5-haiku-latest', 'claude-3-5-sonnet-20240620', @@ -107,6 +112,7 @@ 'claude-3-haiku-20240307', 'claude-3-opus-20240229', 'claude-3-opus-latest', + 'claude-3-sonnet-20240229', 'claude-4-opus-20250514', 'claude-4-sonnet-20250514', 'claude-opus-4-0', diff --git a/pydantic_ai_slim/pyproject.toml b/pydantic_ai_slim/pyproject.toml index 1dfbf7ec80..7407e14e2a 100644 --- a/pydantic_ai_slim/pyproject.toml +++ b/pydantic_ai_slim/pyproject.toml @@ -66,7 +66,7 @@ openai = ["openai>=1.92.0"] cohere = ["cohere>=5.16.0; platform_system != 'Emscripten'"] vertexai = ["google-auth>=2.36.0", "requests>=2.32.2"] google = ["google-genai>=1.24.0"] -anthropic = ["anthropic>=0.59.0"] +anthropic = ["anthropic>=0.52.0"] groq = ["groq>=0.19.0"] mistral = ["mistralai>=1.9.2"] bedrock = ["boto3>=1.37.24"] diff --git a/uv.lock b/uv.lock index 0bf72671c1..73e1ef90c1 100644 --- a/uv.lock +++ b/uv.lock @@ -196,7 +196,7 @@ wheels = [ [[package]] name = "anthropic" -version = "0.60.0" +version = "0.52.0" source = { registry = "https://pypi.org/simple" } dependencies = [ { name = "anyio" }, @@ -207,9 +207,9 @@ dependencies = [ { name = "sniffio" }, { name = "typing-extensions" }, ] 
-sdist = { url = "https://files.pythonhosted.org/packages/4e/03/3334921dc54ed822b3dd993ae72d823a7402588521bbba3e024b3333a1fd/anthropic-0.60.0.tar.gz", hash = "sha256:a22ba187c6f4fd5afecb2fc913b960feccf72bc0d25c1b7ce0345e87caede577", size = 425983, upload-time = "2025-07-28T19:53:47.685Z" } +sdist = { url = "https://files.pythonhosted.org/packages/57/fd/8a9332f5baf352c272494a9d359863a53385a208954c1a7251a524071930/anthropic-0.52.0.tar.gz", hash = "sha256:f06bc924d7eb85f8a43fe587b875ff58b410d60251b7dc5f1387b322a35bd67b", size = 229372, upload-time = "2025-05-22T16:42:22.044Z" } wheels = [ - { url = "https://files.pythonhosted.org/packages/da/bb/d84f287fb1c217b30c328af987cf8bbe3897edf0518dcc5fa39412f794ec/anthropic-0.60.0-py3-none-any.whl", hash = "sha256:65ad1f088a960217aaf82ba91ff743d6c89e9d811c6d64275b9a7c59ee9ac3c6", size = 293116, upload-time = "2025-07-28T19:53:45.944Z" }, + { url = "https://files.pythonhosted.org/packages/a0/43/172c0031654908bbac2a87d356fff4de1b4947a9b14b9658540b69416417/anthropic-0.52.0-py3-none-any.whl", hash = "sha256:c026daa164f0e3bde36ce9cbdd27f5f1419fff03306be1e138726f42e6a7810f", size = 286076, upload-time = "2025-05-22T16:42:20Z" }, ] [[package]] @@ -3216,7 +3216,7 @@ dev = [ [package.metadata] requires-dist = [ { name = "ag-ui-protocol", marker = "extra == 'ag-ui'", specifier = ">=0.1.8" }, - { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.59.0" }, + { name = "anthropic", marker = "extra == 'anthropic'", specifier = ">=0.52.0" }, { name = "argcomplete", marker = "extra == 'cli'", specifier = ">=3.5.0" }, { name = "boto3", marker = "extra == 'bedrock'", specifier = ">=1.37.24" }, { name = "cohere", marker = "sys_platform != 'emscripten' and extra == 'cohere'", specifier = ">=5.16.0" }, From 3e1f6344089aacbcae930cd0304ae15d9ba6c619 Mon Sep 17 00:00:00 2001 From: Douwe Maan Date: Tue, 29 Jul 2025 13:47:08 -0600 Subject: [PATCH 88/89] Fix parallel tool calling with tools returning ToolReturn with content (#2365) --- pydantic_ai_slim/pydantic_ai/_agent_graph.py | 40 ++++---- pydantic_ai_slim/pydantic_ai/messages.py | 4 +- tests/test_tools.py | 102 ++++++++++++++++++- 3 files changed, 123 insertions(+), 23 deletions(-) diff --git a/pydantic_ai_slim/pydantic_ai/_agent_graph.py b/pydantic_ai_slim/pydantic_ai/_agent_graph.py index 312a8a2fca..12e6e07fe8 100644 --- a/pydantic_ai_slim/pydantic_ai/_agent_graph.py +++ b/pydantic_ai_slim/pydantic_ai/_agent_graph.py @@ -659,11 +659,11 @@ async def process_function_tools( # noqa: C901 for call in calls_to_run: yield _messages.FunctionToolCallEvent(call) - user_parts: list[_messages.UserPromptPart] = [] + user_parts_by_index: dict[int, list[_messages.UserPromptPart]] = defaultdict(list) if calls_to_run: # Run all tool tasks in parallel - parts_by_index: dict[int, list[_messages.ModelRequestPart]] = {} + tool_parts_by_index: dict[int, _messages.ModelRequestPart] = {} with ctx.deps.tracer.start_as_current_span( 'running tools', attributes={ @@ -681,15 +681,16 @@ async def process_function_tools( # noqa: C901 done, pending = await asyncio.wait(pending, return_when=asyncio.FIRST_COMPLETED) for task in done: index = tasks.index(task) - tool_result_part, extra_parts = task.result() - yield _messages.FunctionToolResultEvent(tool_result_part) + tool_part, tool_user_parts = task.result() + yield _messages.FunctionToolResultEvent(tool_part) - parts_by_index[index] = [tool_result_part, *extra_parts] + tool_parts_by_index[index] = tool_part + user_parts_by_index[index] = tool_user_parts # We append the results at 
the end, rather than as they are received, to retain a consistent ordering # This is mostly just to simplify testing - for k in sorted(parts_by_index): - output_parts.extend(parts_by_index[k]) + for k in sorted(tool_parts_by_index): + output_parts.append(tool_parts_by_index[k]) # Finally, we handle deferred tool calls for call in tool_calls_by_kind['deferred']: @@ -704,7 +705,8 @@ async def process_function_tools( # noqa: C901 else: yield _messages.FunctionToolCallEvent(call) - output_parts.extend(user_parts) + for k in sorted(user_parts_by_index): + output_parts.extend(user_parts_by_index[k]) if final_result: output_final_result.append(final_result) @@ -713,18 +715,18 @@ async def process_function_tools( # noqa: C901 async def _call_function_tool( tool_manager: ToolManager[DepsT], tool_call: _messages.ToolCallPart, -) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.ModelRequestPart]]: +) -> tuple[_messages.ToolReturnPart | _messages.RetryPromptPart, list[_messages.UserPromptPart]]: try: tool_result = await tool_manager.handle_call(tool_call) except ToolRetryError as e: return (e.tool_retry, []) - part = _messages.ToolReturnPart( + tool_part = _messages.ToolReturnPart( tool_name=tool_call.tool_name, content=tool_result, tool_call_id=tool_call.tool_call_id, ) - extra_parts: list[_messages.ModelRequestPart] = [] + user_parts: list[_messages.UserPromptPart] = [] if isinstance(tool_result, _messages.ToolReturn): if ( @@ -740,12 +742,12 @@ async def _call_function_tool( f'Please use `content` instead.' ) - part.content = tool_result.return_value # type: ignore - part.metadata = tool_result.metadata + tool_part.content = tool_result.return_value # type: ignore + tool_part.metadata = tool_result.metadata if tool_result.content: - extra_parts.append( + user_parts.append( _messages.UserPromptPart( - content=list(tool_result.content), + content=tool_result.content, part_kind='user-prompt', ) ) @@ -763,7 +765,7 @@ def process_content(content: Any) -> Any: else: identifier = multi_modal_content_identifier(content.url) - extra_parts.append( + user_parts.append( _messages.UserPromptPart( content=[f'This is file {identifier}:', content], part_kind='user-prompt', @@ -775,11 +777,11 @@ def process_content(content: Any) -> Any: if isinstance(tool_result, list): contents = cast(list[Any], tool_result) - part.content = [process_content(content) for content in contents] + tool_part.content = [process_content(content) for content in contents] else: - part.content = process_content(tool_result) + tool_part.content = process_content(tool_result) - return (part, extra_parts) + return (tool_part, user_parts) @dataclasses.dataclass diff --git a/pydantic_ai_slim/pydantic_ai/messages.py b/pydantic_ai_slim/pydantic_ai/messages.py index 51e63eea5b..b5d7be2857 100644 --- a/pydantic_ai_slim/pydantic_ai/messages.py +++ b/pydantic_ai_slim/pydantic_ai/messages.py @@ -412,8 +412,8 @@ class ToolReturn: return_value: Any """The return value to be used in the tool response.""" - content: Sequence[UserContent] | None = None - """The content sequence to be sent to the model as a UserPromptPart.""" + content: str | Sequence[UserContent] | None = None + """The content to be sent to the model as a UserPromptPart.""" metadata: Any = None """Additional data that can be accessed programmatically by the application but is not sent to the LLM.""" diff --git a/tests/test_tools.py b/tests/test_tools.py index e6a21a8915..7f4a45804b 100644 --- a/tests/test_tools.py +++ b/tests/test_tools.py @@ -13,7 +13,16 @@ 
from pydantic_ai import Agent, RunContext, Tool, UserError from pydantic_ai.exceptions import ModelRetry, UnexpectedModelBehavior -from pydantic_ai.messages import ModelMessage, ModelRequest, ModelResponse, TextPart, ToolCallPart, ToolReturnPart +from pydantic_ai.messages import ( + ModelMessage, + ModelRequest, + ModelResponse, + TextPart, + ToolCallPart, + ToolReturn, + ToolReturnPart, + UserPromptPart, +) from pydantic_ai.models.function import AgentInfo, FunctionModel from pydantic_ai.models.test import TestModel from pydantic_ai.output import DeferredToolCalls, ToolOutput @@ -21,8 +30,9 @@ from pydantic_ai.toolsets.deferred import DeferredToolset from pydantic_ai.toolsets.function import FunctionToolset from pydantic_ai.toolsets.prefixed import PrefixedToolset +from pydantic_ai.usage import Usage -from .conftest import IsStr +from .conftest import IsDatetime, IsStr def test_tool_no_ctx(): @@ -1321,3 +1331,91 @@ def test_output_type_deferred_tool_calls_by_itself(): def test_output_type_empty(): with pytest.raises(UserError, match='At least one output type must be provided.'): Agent(TestModel(), output_type=[]) + + +def test_parallel_tool_return(): + def llm(messages: list[ModelMessage], info: AgentInfo) -> ModelResponse: + if len(messages) == 1: + return ModelResponse( + parts=[ToolCallPart('get_price', {'fruit': 'apple'}), ToolCallPart('get_price', {'fruit': 'banana'})] + ) + else: + return ModelResponse( + parts=[ + TextPart('Done!'), + ] + ) + + agent = Agent(FunctionModel(llm)) + + @agent.tool_plain + def get_price(fruit: str) -> ToolReturn: + return ToolReturn( + return_value=10.0, + content=f'The price of {fruit} is 10.0', + metadata={'foo': 'bar'}, + ) + + result = agent.run_sync('What do an apple and a banana cost?') + + assert result.all_messages() == snapshot( + [ + ModelRequest( + parts=[ + UserPromptPart( + content='What do an apple and a banana cost?', + timestamp=IsDatetime(), + ) + ] + ), + ModelResponse( + parts=[ + ToolCallPart( + tool_name='get_price', + args={'fruit': 'apple'}, + tool_call_id=IsStr(), + ), + ToolCallPart( + tool_name='get_price', + args={'fruit': 'banana'}, + tool_call_id=IsStr(), + ), + ], + usage=Usage(requests=1, request_tokens=58, response_tokens=10, total_tokens=68), + model_name='function:llm:', + timestamp=IsDatetime(), + ), + ModelRequest( + parts=[ + ToolReturnPart( + tool_name='get_price', + content=10.0, + tool_call_id=IsStr(), + metadata={'foo': 'bar'}, + timestamp=IsDatetime(), + ), + ToolReturnPart( + tool_name='get_price', + content=10.0, + tool_call_id=IsStr(), + metadata={'foo': 'bar'}, + timestamp=IsDatetime(), + ), + UserPromptPart( + content='The price of apple is 10.0', + timestamp=IsDatetime(), + ), + UserPromptPart( + content='The price of banana is 10.0', + timestamp=IsDatetime(), + ), + ] + ), + ModelResponse( + parts=[TextPart(content='Done!')], + usage=Usage(requests=1, request_tokens=76, response_tokens=11, total_tokens=87), + model_name='function:llm:', + timestamp=IsDatetime(), + ), + ] + ) From a145a85e5bcd84ff3553d524dcbb46f0a012b968 Mon Sep 17 00:00:00 2001 From: Saurabh Misra Date: Tue, 29 Jul 2025 17:15:44 -0700 Subject: [PATCH 89/89] revert changes Signed-off-by: Saurabh Misra --- .github/workflows/codeflash.yaml | 37 ------- pyproject.toml | 178 ++++++++++++++++--------------- 2 files changed, 91 insertions(+), 124 deletions(-) delete mode 100644 .github/workflows/codeflash.yaml diff --git a/.github/workflows/codeflash.yaml b/.github/workflows/codeflash.yaml deleted file mode 100644 index e836bce06a..0000000000 
--- a/.github/workflows/codeflash.yaml +++ /dev/null @@ -1,37 +0,0 @@ -name: Codeflash - -on: - pull_request: - paths: - # So that this workflow only runs when code within the target module is modified - - 'pydantic_ai_slim/pydantic_ai/**' - workflow_dispatch: - -concurrency: - # Any new push to the PR will cancel the previous run, so that only the latest code is optimized - group: ${{ github.workflow }}-${{ github.ref }} - cancel-in-progress: true - - -jobs: - optimize: - name: Optimize new Python code - # Don't run codeflash on codeflash-ai[bot] commits, prevent duplicate optimizations - if: ${{ github.actor != 'codeflash-ai[bot]' }} - runs-on: ubuntu-latest - env: - CODEFLASH_API_KEY: ${{ secrets.CODEFLASH_API_KEY }} - - steps: - - name: 🛎️ Checkout - uses: actions/checkout@v4 - with: - fetch-depth: 0 - - name: 🐍 Setup UV - uses: astral-sh/setup-uv@v6 - with: - enable-cache: true - - name: 📦 Install Dependencies - run: uv sync --all-extras - - name: ⚡️Codeflash Optimization - run: uv run codeflash diff --git a/pyproject.toml b/pyproject.toml index e02d6fe91c..c3d4ecc60e 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -5,9 +5,49 @@ build-backend = "hatchling.build" [tool.hatch.version] source = "uv-dynamic-versioning" +[tool.uv-dynamic-versioning] +vcs = "git" +style = "pep440" +bump = true + +[project] +name = "pydantic-ai" +dynamic = ["version", "dependencies", "optional-dependencies"] +description = "Agent Framework / shim to use Pydantic with LLMs" +authors = [ + { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, + { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, + { name = "David Montague", email = "david@pydantic.dev" }, + { name = "Alex Hall", email = "alex@pydantic.dev" }, + { name = "Douwe Maan", email = "douwe@pydantic.dev" }, +] +license = "MIT" +readme = "README.md" +classifiers = [ + "Development Status :: 4 - Beta", + "Programming Language :: Python", + "Programming Language :: Python :: 3", + "Programming Language :: Python :: 3 :: Only", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3.11", + "Programming Language :: Python :: 3.12", + "Programming Language :: Python :: 3.13", + "Intended Audience :: Developers", + "Intended Audience :: Information Technology", + "License :: OSI Approved :: MIT License", + "Operating System :: OS Independent", + "Topic :: Internet", + "Topic :: Scientific/Engineering :: Artificial Intelligence", + "Topic :: Software Development :: Libraries :: Python Modules", + "Framework :: Pydantic", + "Framework :: Pydantic :: 2", +] +requires-python = ">=3.9" + [tool.hatch.metadata.hooks.uv-dynamic-versioning] dependencies = [ - "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,cli,mcp,evals]=={{ version }}", + "pydantic-ai-slim[openai,vertexai,google,groq,anthropic,mistral,cohere,bedrock,huggingface,cli,mcp,evals]=={{ version }}", ] [tool.hatch.metadata.hooks.uv-dynamic-versioning.optional-dependencies] @@ -15,20 +55,17 @@ examples = ["pydantic-ai-examples=={{ version }}"] logfire = ["logfire>=3.11.0"] a2a = ["fasta2a>=0.4.1"] +[project.urls] +Homepage = "https://ai.pydantic.dev" +Source = "https://github.com/pydantic/pydantic-ai" +Documentation = "https://ai.pydantic.dev" +Changelog = "https://github.com/pydantic/pydantic-ai/releases" -[tool.hatch.build.targets.wheel] -only-include = ["/README.md"] - -[tool.hatch.build.targets.sdist] -include = ["/README.md", "/Makefile", "/tests"] - - 
-[tool.uv-dynamic-versioning] -vcs = "git" -style = "pep440" -bump = true +[project.scripts] +pai = "pydantic_ai._cli:cli_exit" # TODO remove this when clai has been out for a while [tool.uv.sources] +pydantic-ai = { workspace = true } pydantic-ai-slim = { workspace = true } pydantic-evals = { workspace = true } pydantic-graph = { workspace = true } @@ -44,6 +81,25 @@ members = [ "examples", ] +[dependency-groups] +# dev dependencies are defined in `pydantic-ai-slim/pyproject.toml` to allow for minimal testing +lint = ["mypy>=1.11.2", "pyright>=1.1.390", "ruff>=0.6.9"] +docs = [ + "pydantic-ai[a2a]", + "black>=24.10.0", + "mkdocs>=1.6.1", + "mkdocs-glightbox>=0.4.0", + "mkdocs-llmstxt>=0.2.0", + "mkdocs-material[imaging]>=9.5.45", + "mkdocstrings-python>=1.12.2", +] +docs-upload = ["algoliasearch>=4.12.0", "pydantic>=2.10.1"] + +[tool.hatch.build.targets.wheel] +only-include = ["/README.md"] + +[tool.hatch.build.targets.sdist] +include = ["/README.md", "/Makefile", "/tests"] [tool.ruff] line-length = 120 @@ -71,7 +127,6 @@ extend-select = [ "TID251", ] flake8-quotes = { inline-quotes = "single", multiline-quotes = "double" } -isort = { combine-as-imports = true, known-first-party = ["pydantic_ai"] } mccabe = { max-complexity = 15 } ignore = [ "D100", # ignore missing docstring in module @@ -81,6 +136,12 @@ ignore = [ "D107", # ignore missing docstring in __init__ methods ] +[tool.ruff.lint.isort] +combine-as-imports = true +known-first-party = ["pydantic_ai"] +# weird issue with ruff thinking fasta2a is still editable +known-third-party = ["fasta2a"] + [tool.ruff.lint.pydocstyle] convention = "google" @@ -98,7 +159,6 @@ quote-style = "single" "tests/**/*.py" = ["D"] "docs/**/*.py" = ["D"] - [tool.pyright] pythonVersion = "3.12" typeCheckingMode = "strict" @@ -134,15 +194,14 @@ files = "tests/typed_agent.py,tests/typed_graph.py" strict = true [tool.pytest.ini_options] -testpaths = [ - "tests", - "docs/.hooks" -] +testpaths = ["tests", "docs/.hooks"] xfail_strict = true filterwarnings = [ "error", # Issue with python-multipart - we don't want to bump the minimum version of starlette. "ignore::PendingDeprecationWarning:starlette", + # mistralai accesses model_fields on the instance, which is deprecated in Pydantic 2.11. + "ignore:Accessing the 'model_fields' attribute", # boto3 "ignore::DeprecationWarning:botocore.*", "ignore::RuntimeWarning:pydantic_ai.mcp", @@ -155,7 +214,6 @@ filterwarnings = [ ] # https://coverage.readthedocs.io/en/latest/config.html#run - [tool.coverage.run] # required to avoid warnings about files created by create_module fixture include = [ @@ -170,6 +228,20 @@ omit = [ "pydantic_ai_slim/pydantic_ai/ext/aci.py", # aci-sdk requires Python 3.10+ so cannot be added as an (optional) dependency ] branch = true +# Disable include-ignored warnings as --source is enabled automatically causing a self conflict as per: +# https://github.com/pytest-dev/pytest-cov/issues/532 +# https://github.com/pytest-dev/pytest-cov/issues/369 +# This prevents coverage being generated by pytest-cov which has direct editor support in VS Code, +# making it super useful to check coverage while writing tests. +disable_warnings = ["include-ignored"] + +[tool.coverage.paths] +# Allow CI run assets to be downloaded an replicated locally. 
+source = [ + ".", + "/home/runner/work/pydantic-ai/pydantic-ai", + "/System/Volumes/Data/home/runner/work/pydantic-ai/pydantic-ai" +] # https://coverage.readthedocs.io/en/latest/config.html#report [tool.coverage.report] @@ -198,7 +270,6 @@ exclude_lines = [ 'assert False', ] - [tool.logfire] ignore_no_config = true @@ -209,7 +280,6 @@ format-command = "ruff format --stdin-filename {filename}" snap-fix = ["create", "fix"] snap = ["create"] - [tool.codespell] # Ref: https://github.com/codespell-project/codespell#using-a-config-file skip = '.git*,*.svg,*.lock,*.css,*.yaml' @@ -217,69 +287,3 @@ check-hidden = true # Ignore "formatting" like **L**anguage ignore-regex = '\*\*[A-Z]\*\*[a-z]+\b' ignore-words-list = 'asend,aci' - -[tool.codeflash] -# All paths are relative to this pyproject.toml's directory. -module-root = "pydantic_ai_slim/pydantic_ai" -tests-root = "tests" -test-framework = "pytest" -ignore-paths = [] -formatter-cmds = ["disabled"] - -[project] -name = "pydantic-ai" -dynamic = ["version", "dependencies", "optional-dependencies"] -description = "Agent Framework / shim to use Pydantic with LLMs" -authors = [ - { name = "Samuel Colvin", email = "samuel@pydantic.dev" }, - { name = "Marcelo Trylesinski", email = "marcelotryle@gmail.com" }, - { name = "David Montague", email = "david@pydantic.dev" }, - { name = "Alex Hall", email = "alex@pydantic.dev" }, -] -license = "MIT" -readme = "README.md" -classifiers = [ - "Development Status :: 4 - Beta", - "Programming Language :: Python", - "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3 :: Only", - "Programming Language :: Python :: 3.9", - "Programming Language :: Python :: 3.10", - "Programming Language :: Python :: 3.11", - "Programming Language :: Python :: 3.12", - "Programming Language :: Python :: 3.13", - "Intended Audience :: Developers", - "Intended Audience :: Information Technology", - "License :: OSI Approved :: MIT License", - "Operating System :: OS Independent", - "Topic :: Internet", - "Topic :: Scientific/Engineering :: Artificial Intelligence", - "Topic :: Software Development :: Libraries :: Python Modules", - "Framework :: Pydantic", - "Framework :: Pydantic :: 2", -] -requires-python = ">=3.9" - -[project.urls] -Homepage = "https://ai.pydantic.dev" -Source = "https://github.com/pydantic/pydantic-ai" -Documentation = "https://ai.pydantic.dev" -Changelog = "https://github.com/pydantic/pydantic-ai/releases" - -[project.scripts] -pai = "pydantic_ai._cli:cli_exit" # TODO remove this when clai has been out for a while - -[dependency-groups] -# dev dependencies are defined in `pydantic-ai-slim/pyproject.toml` to allow for minimal testing -lint = ["mypy>=1.11.2", "pyright>=1.1.390", "ruff>=0.6.9"] -docs = [ - "pydantic-ai[a2a]", - "black>=24.10.0", - "mkdocs>=1.6.1", - "mkdocs-glightbox>=0.4.0", - "mkdocs-llmstxt>=0.2.0", - "mkdocs-material[imaging]>=9.5.45", - "mkdocstrings-python>=1.12.2", -] -docs-upload = ["algoliasearch>=4.12.0", "pydantic>=2.10.1"] -
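A minimal usage sketch of the behaviour covered by PATCH 88/89 above, assuming any configured model (the model string below is illustrative, not taken from the patch): a plain tool returns ToolReturn with both a return_value and extra content. After the fix, parallel calls to such a tool each contribute a ToolReturnPart, with the corresponding UserPromptParts grouped after them, and ToolReturn.content may now be a plain str as well as a sequence of UserContent.

from pydantic_ai import Agent
from pydantic_ai.messages import ToolReturn

# Illustrative sketch only: the model name is an assumption, any model supported by pydantic-ai works.
agent = Agent('openai:gpt-4o')

@agent.tool_plain
def get_price(fruit: str) -> ToolReturn:
    # return_value becomes the ToolReturnPart content sent back for this tool call;
    # content (now also accepted as a plain str) is sent to the model as a separate
    # UserPromptPart; metadata stays available to the application but is not sent to the LLM.
    return ToolReturn(
        return_value=10.0,
        content=f'The price of {fruit} is 10.0',
        metadata={'source': 'price-db'},
    )

result = agent.run_sync('What do an apple and a banana cost?')
print(result.output)

With two parallel get_price calls, the follow-up ModelRequest contains both ToolReturnParts first and then the two UserPromptParts, matching the ordering asserted by test_parallel_tool_return in the patch.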