2323
2424_current_step = contextvars .ContextVar ("current_step" )
2525_current_trace = contextvars .ContextVar ("current_trace" )
26+ _rag_context = contextvars .ContextVar ("rag_context" )
2627
2728
2829def get_current_trace () -> Optional [traces .Trace ]:
@@ -35,6 +36,11 @@ def get_current_step() -> Optional[steps.Step]:
3536 return _current_step .get (None )
3637
3738
39+ def get_rag_context () -> Optional [Dict [str , Any ]]:
40+ """Returns the current context."""
41+ return _rag_context .get (None )
42+
43+
3844@contextmanager
3945def create_step (
4046 name : str ,
@@ -57,6 +63,7 @@ def create_step(
5763 logger .debug ("Starting a new trace..." )
5864 current_trace = traces .Trace ()
5965 _current_trace .set (current_trace ) # Set the current trace in context
66+ _rag_context .set (None ) # Reset the context
6067 current_trace .add_step (new_step )
6168 else :
6269 logger .debug ("Adding step %s to parent step %s" , name , parent_step .name )
@@ -91,6 +98,9 @@ def create_step(
9198 )
9299 )
93100
101+ if "context" in trace_data :
102+ config .update ({"context_column_name" : "context" })
103+
94104 if isinstance (new_step , steps .ChatCompletionStep ):
95105 config .update (
96106 {
@@ -121,7 +131,7 @@ def add_chat_completion_step_to_trace(**kwargs) -> None:
121131
122132
123133# ----------------------------- Tracing decorator ---------------------------- #
124- def trace (* step_args , inference_pipeline_id : Optional [str ] = None , ** step_kwargs ):
134+ def trace (* step_args , inference_pipeline_id : Optional [str ] = None , context_kwarg : Optional [ str ] = None , ** step_kwargs ):
125135 """Decorator to trace a function.
126136
127137 Examples
@@ -182,6 +192,12 @@ def wrapper(*func_args, **func_kwargs):
182192 inputs .pop ("self" , None )
183193 inputs .pop ("cls" , None )
184194
195+ if context_kwarg :
196+ if context_kwarg in inputs :
197+ log_context (inputs .get (context_kwarg ))
198+ else :
199+ logger .warning ("Context kwarg `%s` not found in inputs of the current function." , context_kwarg )
200+
185201 step .log (
186202 inputs = inputs ,
187203 output = output ,
@@ -198,7 +214,9 @@ def wrapper(*func_args, **func_kwargs):
198214 return decorator
199215
200216
201- def trace_async (* step_args , inference_pipeline_id : Optional [str ] = None , ** step_kwargs ):
217+ def trace_async (
218+ * step_args , inference_pipeline_id : Optional [str ] = None , context_kwarg : Optional [str ] = None , ** step_kwargs
219+ ):
202220 """Decorator to trace a function.
203221
204222 Examples
@@ -259,6 +277,12 @@ async def wrapper(*func_args, **func_kwargs):
259277 inputs .pop ("self" , None )
260278 inputs .pop ("cls" , None )
261279
280+ if context_kwarg :
281+ if context_kwarg in inputs :
282+ log_context (inputs .get (context_kwarg ))
283+ else :
284+ logger .warning ("Context kwarg `%s` not found in inputs of the current function." , context_kwarg )
285+
262286 step .log (
263287 inputs = inputs ,
264288 output = output ,
@@ -292,6 +316,19 @@ def run_async_func(coroutine: Awaitable[Any]) -> Any:
292316 return result
293317
294318
319+ def log_context (context : List [str ]) -> None :
320+ """Logs context information to the current step of the trace.
321+
322+ The `context` parameter should be a list of strings representing the
323+ context chunks retrieved by the context retriever."""
324+ current_step = get_current_step ()
325+ if current_step :
326+ _rag_context .set (context )
327+ current_step .log (metadata = {"context" : context })
328+ else :
329+ logger .warning ("No current step found to log context." )
330+
331+
295332# --------------------- Helper post-processing functions --------------------- #
296333def post_process_trace (
297334 trace_obj : traces .Trace ,
@@ -323,4 +360,8 @@ def post_process_trace(
323360 if input_variables :
324361 trace_data .update (input_variables )
325362
363+ context = get_rag_context ()
364+ if context :
365+ trace_data ["context" ] = context
366+
326367 return trace_data , input_variable_names
0 commit comments