@@ -131,7 +131,12 @@ def add_chat_completion_step_to_trace(**kwargs) -> None:
131131
132132
133133# ----------------------------- Tracing decorator ---------------------------- #
134- def trace (* step_args , inference_pipeline_id : Optional [str ] = None , context_kwarg : Optional [str ] = None , ** step_kwargs ):
134+ def trace (
135+ * step_args ,
136+ inference_pipeline_id : Optional [str ] = None ,
137+ context_kwarg : Optional [str ] = None ,
138+ ** step_kwargs ,
139+ ):
135140 """Decorator to trace a function.
136141
137142 Examples
@@ -175,7 +180,9 @@ def decorator(func):
175180 def wrapper (* func_args , ** func_kwargs ):
176181 if step_kwargs .get ("name" ) is None :
177182 step_kwargs ["name" ] = func .__name__
178- with create_step (* step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs ) as step :
183+ with create_step (
184+ * step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs
185+ ) as step :
179186 output = exception = None
180187 try :
181188 output = func (* func_args , ** func_kwargs )
@@ -196,7 +203,10 @@ def wrapper(*func_args, **func_kwargs):
196203 if context_kwarg in inputs :
197204 log_context (inputs .get (context_kwarg ))
198205 else :
199- logger .warning ("Context kwarg `%s` not found in inputs of the current function." , context_kwarg )
206+ logger .warning (
207+ "Context kwarg `%s` not found in inputs of the current function." ,
208+ context_kwarg ,
209+ )
200210
201211 step .log (
202212 inputs = inputs ,
@@ -215,7 +225,10 @@ def wrapper(*func_args, **func_kwargs):
215225
216226
217227def trace_async (
218- * step_args , inference_pipeline_id : Optional [str ] = None , context_kwarg : Optional [str ] = None , ** step_kwargs
228+ * step_args ,
229+ inference_pipeline_id : Optional [str ] = None ,
230+ context_kwarg : Optional [str ] = None ,
231+ ** step_kwargs ,
219232):
220233 """Decorator to trace a function.
221234
@@ -260,7 +273,9 @@ def decorator(func):
260273 async def wrapper (* func_args , ** func_kwargs ):
261274 if step_kwargs .get ("name" ) is None :
262275 step_kwargs ["name" ] = func .__name__
263- with create_step (* step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs ) as step :
276+ with create_step (
277+ * step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs
278+ ) as step :
264279 output = exception = None
265280 try :
266281 output = await func (* func_args , ** func_kwargs )
@@ -281,7 +296,10 @@ async def wrapper(*func_args, **func_kwargs):
281296 if context_kwarg in inputs :
282297 log_context (inputs .get (context_kwarg ))
283298 else :
284- logger .warning ("Context kwarg `%s` not found in inputs of the current function." , context_kwarg )
299+ logger .warning (
300+ "Context kwarg `%s` not found in inputs of the current function." ,
301+ context_kwarg ,
302+ )
285303
286304 step .log (
287305 inputs = inputs ,
@@ -299,7 +317,9 @@ async def wrapper(*func_args, **func_kwargs):
299317 return decorator
300318
301319
302- async def _invoke_with_context (coroutine : Awaitable [Any ]) -> Tuple [contextvars .Context , Any ]:
320+ async def _invoke_with_context (
321+ coroutine : Awaitable [Any ],
322+ ) -> Tuple [contextvars .Context , Any ]:
303323 """Runs a coroutine and preserves the context variables set within it."""
304324 result = await coroutine
305325 context = contextvars .copy_context ()
@@ -356,6 +376,7 @@ def post_process_trace(
356376 "cost" : processed_steps [0 ].get ("cost" , 0 ),
357377 "tokens" : processed_steps [0 ].get ("tokens" , 0 ),
358378 "steps" : processed_steps ,
379+ ** root_step .metadata ,
359380 }
360381 if input_variables :
361382 trace_data .update (input_variables )
0 commit comments