diff --git a/pydantic_ai_slim/pydantic_ai/_tool_manager.py b/pydantic_ai_slim/pydantic_ai/_tool_manager.py index fb7039e2cc..dde27df1e6 100644 --- a/pydantic_ai_slim/pydantic_ai/_tool_manager.py +++ b/pydantic_ai_slim/pydantic_ai/_tool_manager.py @@ -108,24 +108,21 @@ async def handle_call( raise ValueError('ToolManager has not been prepared for a run step yet') # pragma: no cover if (tool := self.tools.get(call.tool_name)) and tool.tool_def.kind == 'output': - # Output tool calls are not traced and not counted - return await self._call_tool( - call, - allow_partial=allow_partial, - wrap_validation_errors=wrap_validation_errors, - approved=approved, - ) + output_tool_flag = True else: - return await self._call_function_tool( - call, - allow_partial=allow_partial, - wrap_validation_errors=wrap_validation_errors, - approved=approved, - tracer=self.ctx.tracer, - include_content=self.ctx.trace_include_content, - instrumentation_version=self.ctx.instrumentation_version, - usage=self.ctx.usage, - ) + output_tool_flag = False + + return await self._call_function_tool( + call, + allow_partial=allow_partial, + wrap_validation_errors=wrap_validation_errors, + approved=approved, + tracer=self.ctx.tracer, + include_content=self.ctx.trace_include_content, + instrumentation_version=self.ctx.instrumentation_version, + usage=self.ctx.usage, + output_tool_flag=output_tool_flag, + ) async def _call_tool( self, @@ -213,16 +210,22 @@ async def _call_function_tool( include_content: bool, instrumentation_version: int, usage: RunUsage, + output_tool_flag: bool = False, ) -> Any: """See .""" instrumentation_names = InstrumentationNames.for_version(instrumentation_version) + if output_tool_flag: + tool_name = 'output tool' + else: + tool_name = call.tool_name + span_attributes = { - 'gen_ai.tool.name': call.tool_name, + 'gen_ai.tool.name': tool_name, # NOTE: this means `gen_ai.tool.call.id` will be included even if it was generated by pydantic-ai 'gen_ai.tool.call.id': call.tool_call_id, **({instrumentation_names.tool_arguments_attr: call.args_as_json_str()} if include_content else {}), - 'logfire.msg': f'running tool: {call.tool_name}', + 'logfire.msg': f'running tool: {tool_name}', # add the JSON schema so these attributes are formatted nicely in Logfire 'logfire.json_schema': json.dumps( { @@ -243,7 +246,7 @@ async def _call_function_tool( ), } with tracer.start_as_current_span( - instrumentation_names.get_tool_span_name(call.tool_name), + instrumentation_names.get_tool_span_name(tool_name), attributes=span_attributes, ) as span: try: @@ -253,7 +256,9 @@ async def _call_function_tool( wrap_validation_errors=wrap_validation_errors, approved=approved, ) - usage.tool_calls += 1 + if not output_tool_flag: + # Output tool calls are not counted + usage.tool_calls += 1 except ToolRetryError as e: part = e.tool_retry diff --git a/tests/test_dbos.py b/tests/test_dbos.py index 256aba83fb..d1ea9f8706 100644 --- a/tests/test_dbos.py +++ b/tests/test_dbos.py @@ -594,6 +594,7 @@ async def test_complex_agent_run_in_workflow(allow_model_requests: None, dbos: D ) ], ), + BasicSpan(content='running tool: output tool'), ], ) ], diff --git a/tests/test_logfire.py b/tests/test_logfire.py index dadb930dd0..333d91e81e 100644 --- a/tests/test_logfire.py +++ b/tests/test_logfire.py @@ -692,7 +692,10 @@ class MyOutput: 'id': 0, 'name': 'agent run', 'message': 'my_agent run', - 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + {'id': 2, 'name': 'running tool', 'message': 'running tool: output tool'}, + ], } ] ) @@ -703,7 +706,10 @@ class MyOutput: 'id': 0, 'name': 'invoke_agent my_agent', 'message': 'my_agent run', - 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + {'id': 2, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'}, + ], } ] ) @@ -900,7 +906,10 @@ class MyOutput: 'id': 0, 'name': 'agent run', 'message': 'my_agent run', - 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + {'id': 2, 'name': 'running tool', 'message': 'running tool: output tool'}, + ], } ] ) @@ -911,7 +920,10 @@ class MyOutput: 'id': 0, 'name': 'invoke_agent my_agent', 'message': 'my_agent run', - 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + {'id': 2, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'}, + ], } ] ) @@ -1381,8 +1393,15 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: {'id': 1, 'name': 'chat function:call_tool:', 'message': 'chat function:call_tool:'}, { 'id': 2, - 'name': 'running output function', - 'message': 'running output function: final_result', + 'name': 'running tool', + 'message': 'running tool: output tool', + 'children': [ + { + 'id': 3, + 'name': 'running output function', + 'message': 'running output function: final_result', + } + ], }, ], } @@ -1428,8 +1447,15 @@ def call_tool(_: list[ModelMessage], info: AgentInfo) -> ModelResponse: {'id': 1, 'name': 'chat function:call_tool:', 'message': 'chat function:call_tool:'}, { 'id': 2, - 'name': 'execute_tool final_result', - 'message': 'running output function: final_result', + 'name': 'execute_tool output tool', + 'message': 'running tool: output tool', + 'children': [ + { + 'id': 3, + 'name': 'execute_tool final_result', + 'message': 'running output function: final_result', + } + ], }, ], } @@ -2336,7 +2362,10 @@ def instructions(): 'id': 0, 'name': 'agent run', 'message': 'my_agent run', - 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + {'id': 2, 'name': 'running tool', 'message': 'running tool: output tool'}, + ], } ] ) @@ -2347,7 +2376,10 @@ def instructions(): 'id': 0, 'name': 'invoke_agent my_agent', 'message': 'my_agent run', - 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + {'id': 2, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'}, + ], } ] ) @@ -2589,6 +2621,7 @@ def my_tool() -> str: 'children': [{'id': 3, 'name': 'running tool', 'message': 'running tool: my_tool'}], }, {'id': 4, 'name': 'chat test', 'message': 'chat test'}, + {'id': 5, 'name': 'running tool', 'message': 'running tool: output tool'}, ], } ] @@ -2611,6 +2644,7 @@ def my_tool() -> str: ], }, {'id': 4, 'name': 'chat test', 'message': 'chat test'}, + {'id': 5, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'}, ], } ] @@ -2877,7 +2911,10 @@ def instructions(ctx: RunContext[None]): 'id': 0, 'name': 'agent run', 'message': 'my_agent run', - 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + {'id': 2, 'name': 'running tool', 'message': 'running tool: output tool'}, + ], } ] ) @@ -2888,7 +2925,10 @@ def instructions(ctx: RunContext[None]): 'id': 0, 'name': 'invoke_agent my_agent', 'message': 'my_agent run', - 'children': [{'id': 1, 'name': 'chat test', 'message': 'chat test'}], + 'children': [ + {'id': 1, 'name': 'chat test', 'message': 'chat test'}, + {'id': 2, 'name': 'execute_tool output tool', 'message': 'running tool: output tool'}, + ], } ] ) diff --git a/tests/test_prefect.py b/tests/test_prefect.py index 3f2059ab6f..3ab9422364 100644 --- a/tests/test_prefect.py +++ b/tests/test_prefect.py @@ -553,6 +553,7 @@ async def run_complex_agent() -> Response: ) ], ), + BasicSpan(content='running tool: output tool'), ], ) ], diff --git a/tests/test_temporal.py b/tests/test_temporal.py index b3fe75d911..5d0c49ea44 100644 --- a/tests/test_temporal.py +++ b/tests/test_temporal.py @@ -718,6 +718,7 @@ async def test_complex_agent_run_in_workflow( ) ], ), + BasicSpan(content='running tool: output tool'), ], ), BasicSpan(content='CompleteWorkflow:ComplexAgentWorkflow'),