@@ -208,34 +208,113 @@ def trace(
208208 def decorator (func ):
209209 func_signature = inspect .signature (func )
210210
211- @wraps (func )
212- def wrapper (* func_args , ** func_kwargs ):
213- if step_kwargs .get ("name" ) is None :
214- step_kwargs ["name" ] = func .__name__
215-
216- with create_step (* step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs ) as step :
217- output = exception = None
218- try :
219- output = func (* func_args , ** func_kwargs )
220- except Exception as exc :
221- _log_step_exception (step , exc )
222- exception = exc
223-
224- # Extract inputs and finalize logging using optimized helper
225- _process_wrapper_inputs_and_outputs (
226- step = step ,
227- func_signature = func_signature ,
228- func_args = func_args ,
229- func_kwargs = func_kwargs ,
230- context_kwarg = context_kwarg ,
231- output = output ,
232- )
211+ if step_kwargs .get ("name" ) is None :
212+ step_kwargs ["name" ] = func .__name__
213+ step_name = step_kwargs ["name" ]
214+
215+ # Check if it's a generator function
216+ if inspect .isgeneratorfunction (func ):
217+ # For sync generators, use class-based approach to delay trace creation
218+ # until actual iteration begins (not when generator object is created)
219+ @wraps (func )
220+ def sync_generator_wrapper (* func_args , ** func_kwargs ):
221+ class TracedSyncGenerator :
222+ def __init__ (self ):
223+ self ._original_gen = None
224+ self ._step = None
225+ self ._is_root_step = False
226+ self ._token = None
227+ self ._output_chunks = []
228+ self ._trace_initialized = False
229+
230+ def __iter__ (self ):
231+ return self
232+
233+ def __next__ (self ):
234+ # Initialize tracing on first iteration only
235+ if not self ._trace_initialized :
236+ self ._original_gen = func (* func_args , ** func_kwargs )
237+ self ._step , self ._is_root_step , self ._token = _create_and_initialize_step (
238+ step_name = step_name ,
239+ step_type = enums .StepType .USER_CALL ,
240+ inputs = None ,
241+ output = None ,
242+ metadata = None ,
243+ )
244+ self ._inputs = _extract_function_inputs (
245+ func_signature = func_signature ,
246+ func_args = func_args ,
247+ func_kwargs = func_kwargs ,
248+ context_kwarg = context_kwarg ,
249+ )
250+ self ._trace_initialized = True
251+
252+ try :
253+ chunk = next (self ._original_gen )
254+ self ._output_chunks .append (chunk )
255+ return chunk
256+ except StopIteration :
257+ # Finalize trace when generator is exhausted
258+ output = _join_output_chunks (self ._output_chunks )
259+ _finalize_sync_generator_step (
260+ step = self ._step ,
261+ token = self ._token ,
262+ is_root_step = self ._is_root_step ,
263+ step_name = step_name ,
264+ inputs = self ._inputs ,
265+ output = output ,
266+ inference_pipeline_id = inference_pipeline_id ,
267+ )
268+ raise
269+ except Exception as exc :
270+ # Handle exceptions
271+ if self ._step :
272+ _log_step_exception (self ._step , exc )
273+ output = _join_output_chunks (self ._output_chunks )
274+ _finalize_sync_generator_step (
275+ step = self ._step ,
276+ token = self ._token ,
277+ is_root_step = self ._is_root_step ,
278+ step_name = step_name ,
279+ inputs = self ._inputs ,
280+ output = output ,
281+ inference_pipeline_id = inference_pipeline_id ,
282+ )
283+ raise
284+
285+ return TracedSyncGenerator ()
286+
287+ return sync_generator_wrapper
288+ else :
289+ # Handle regular functions
290+ @wraps (func )
291+ def wrapper (* func_args , ** func_kwargs ):
292+ if step_kwargs .get ("name" ) is None :
293+ step_kwargs ["name" ] = func .__name__
294+
295+ with create_step (* step_args , inference_pipeline_id = inference_pipeline_id , ** step_kwargs ) as step :
296+ output = exception = None
297+ try :
298+ output = func (* func_args , ** func_kwargs )
299+ except Exception as exc :
300+ _log_step_exception (step , exc )
301+ exception = exc
233302
234- if exception is not None :
235- raise exception
236- return output
303+ # Extract inputs and finalize logging using optimized helper
304+ _process_wrapper_inputs_and_outputs (
305+ step = step ,
306+ func_signature = func_signature ,
307+ func_args = func_args ,
308+ func_kwargs = func_kwargs ,
309+ context_kwarg = context_kwarg ,
310+ output = output ,
311+ )
237312
238- return wrapper
313+ if exception is not None :
314+ raise exception
315+ return output
316+
317+ return wrapper
239318
240319 return decorator
241320
@@ -637,7 +716,26 @@ def _finalize_step_logging(
637716 )
638717
639718
640- # ----------------------------- Async generator specific functions ----------------------------- #
719+ # ----------------------------- Generator specific functions ----------------------------- #
720+
721+
722+ def _finalize_sync_generator_step (
723+ step : steps .Step ,
724+ token : Any ,
725+ is_root_step : bool ,
726+ step_name : str ,
727+ inputs : dict ,
728+ output : Any ,
729+ inference_pipeline_id : Optional [str ] = None ,
730+ ) -> None :
731+ """Finalize sync generator step - called when generator is consumed."""
732+ _current_step .reset (token )
733+ _finalize_step_logging (step = step , inputs = inputs , output = output , start_time = step .start_time )
734+ _handle_trace_completion (
735+ is_root_step = is_root_step ,
736+ step_name = step_name ,
737+ inference_pipeline_id = inference_pipeline_id ,
738+ )
641739
642740
643741def _finalize_async_generator_step (
0 commit comments