6969DEFAULT_OUTPUT_TOOL_DESCRIPTION = 'The final response which ends this conversation'
7070
7171
72- async def execute_output_function_with_span (
72+ async def execute_traced_output_function (
7373 function_schema : _function_schema .FunctionSchema ,
7474 run_context : RunContext [AgentDepsT ],
7575 args : dict [str , Any ] | Any ,
76+ wrap_validation_errors : bool = True ,
7677) -> Any :
77- """Execute a function call within a traced span, automatically recording the response."""
78+ """Execute an output function within a traced span with error handling.
79+
80+ This function executes the output function within an OpenTelemetry span for observability,
81+ automatically records the function response, and handles ModelRetry exceptions by converting
82+ them to ToolRetryError when wrap_validation_errors is True.
83+
84+ Args:
85+ function_schema: The function schema containing the function to execute
86+ run_context: The current run context containing tracing and tool information
87+ args: Arguments to pass to the function
88+ wrap_validation_errors: If True, wrap ModelRetry exceptions in ToolRetryError
89+
90+ Returns:
91+ The result of the function execution
92+
93+ Raises:
94+ ToolRetryError: When wrap_validation_errors is True and a ModelRetry is caught
95+ ModelRetry: When wrap_validation_errors is False and a ModelRetry occurs
96+ """
7897 # Set up span attributes
7998 tool_name = run_context .tool_name or getattr (function_schema .function , '__name__' , 'output_function' )
8099 attributes = {
@@ -96,7 +115,19 @@ async def execute_output_function_with_span(
96115 )
97116
98117 with run_context .tracer .start_as_current_span ('running output function' , attributes = attributes ) as span :
99- output = await function_schema .call (args , run_context )
118+ try :
119+ output = await function_schema .call (args , run_context )
120+ except ModelRetry as r :
121+ if wrap_validation_errors :
122+ m = _messages .RetryPromptPart (
123+ content = r .message ,
124+ tool_name = run_context .tool_name ,
125+ )
126+ if run_context .tool_call_id :
127+ m .tool_call_id = run_context .tool_call_id # pragma: no cover
128+ raise ToolRetryError (m ) from r
129+ else :
130+ raise
100131
101132 # Record response if content inclusion is enabled
102133 if run_context .trace_include_content and span .is_recording ():
@@ -663,16 +694,7 @@ async def process(
663694 else :
664695 raise
665696
666- try :
667- output = await self .call (output , run_context )
668- except ModelRetry as r :
669- if wrap_validation_errors :
670- m = _messages .RetryPromptPart (
671- content = r .message ,
672- )
673- raise ToolRetryError (m ) from r
674- else :
675- raise # pragma: no cover
697+ output = await self .call (output , run_context , wrap_validation_errors )
676698
677699 return output
678700
@@ -691,12 +713,15 @@ async def call(
691713 self ,
692714 output : Any ,
693715 run_context : RunContext [AgentDepsT ],
716+ wrap_validation_errors : bool = True ,
694717 ):
695718 if k := self .outer_typed_dict_key :
696719 output = output [k ]
697720
698721 if self ._function_schema :
699- output = await execute_output_function_with_span (self ._function_schema , run_context , output )
722+ output = await execute_traced_output_function (
723+ self ._function_schema , run_context , output , wrap_validation_errors
724+ )
700725
701726 return output
702727
@@ -856,16 +881,7 @@ async def process(
856881 wrap_validation_errors : bool = True ,
857882 ) -> OutputDataT :
858883 args = {self ._str_argument_name : data }
859- try :
860- output = await execute_output_function_with_span (self ._function_schema , run_context , args )
861- except ModelRetry as r :
862- if wrap_validation_errors :
863- m = _messages .RetryPromptPart (
864- content = r .message ,
865- )
866- raise ToolRetryError (m ) from r
867- else :
868- raise # pragma: no cover
884+ output = await execute_traced_output_function (self ._function_schema , run_context , args , wrap_validation_errors )
869885
870886 return cast (OutputDataT , output )
871887
@@ -975,7 +991,7 @@ async def get_tools(self, ctx: RunContext[AgentDepsT]) -> dict[str, ToolsetTool[
975991 async def call_tool (
976992 self , name : str , tool_args : dict [str , Any ], ctx : RunContext [AgentDepsT ], tool : ToolsetTool [AgentDepsT ]
977993 ) -> Any :
978- output = await self .processors [name ].call (tool_args , ctx )
994+ output = await self .processors [name ].call (tool_args , ctx , wrap_validation_errors = False )
979995 for validator in self .output_validators :
980996 output = await validator .validate (output , ctx , wrap_validation_errors = False )
981997 return output
0 commit comments