22
33# pylint: disable=unused-argument
44import time
5- from typing import Any , Dict , List , Optional , Union , TYPE_CHECKING
5+ from typing import Any , Dict , List , Optional , Union , TYPE_CHECKING , Callable
66from uuid import UUID
77
88try :
@@ -52,6 +52,8 @@ def __init__(self, **kwargs: Any) -> None:
5252 self .root_steps : set [UUID ] = set () # Track which steps are root
5353 # Extract inference_id from kwargs if provided
5454 self ._inference_id = kwargs .get ("inference_id" )
55+ # Extract metadata_transformer from kwargs if provided
56+ self ._metadata_transformer = kwargs .get ("metadata_transformer" )
5557
5658 def _start_step (
5759 self ,
@@ -207,6 +209,25 @@ def _process_and_upload_trace(self, root_step: steps.Step) -> None:
207209 # Reset trace context only for standalone traces
208210 tracer ._current_trace .set (None )
209211
212+ def _process_metadata (self , metadata : Dict [str , Any ]) -> Dict [str , Any ]:
213+ """Apply user-defined metadata transformation if provided."""
214+ if not metadata :
215+ return {}
216+
217+ # First convert LangChain objects to JSON-serializable format
218+ converted_metadata = self ._convert_langchain_objects (metadata )
219+
220+ # Then apply custom transformer if provided
221+ if self ._metadata_transformer :
222+ try :
223+ return self ._metadata_transformer (converted_metadata )
224+ except Exception as e :
225+ # Log warning but continue with unconverted metadata
226+ tracer .logger .warning (f"Metadata transformer failed: { e } " )
227+ return converted_metadata
228+
229+ return converted_metadata
230+
210231 def _convert_step_objects_recursively (self , step : steps .Step ) -> None :
211232 """Convert all LangChain objects in a step and its nested steps."""
212233 # Convert step attributes
@@ -217,7 +238,7 @@ def _convert_step_objects_recursively(self, step: steps.Step) -> None:
217238 converted_output = self ._convert_langchain_objects (step .output )
218239 step .output = utils .json_serialize (converted_output )
219240 if step .metadata is not None :
220- step .metadata = self ._convert_langchain_objects (step .metadata )
241+ step .metadata = self ._process_metadata (step .metadata )
221242
222243 # Convert nested steps recursively
223244 for nested_step in step .steps :
@@ -754,11 +775,16 @@ def __init__(
754775 ignore_retriever = False ,
755776 ignore_agent = False ,
756777 inference_id : Optional [Any ] = None ,
778+ metadata_transformer : Optional [
779+ Callable [[Dict [str , Any ]], Dict [str , Any ]]
780+ ] = None ,
757781 ** kwargs : Any ,
758782 ) -> None :
759- # Add inference_id to kwargs so it gets passed to mixin
783+ # Add both inference_id and metadata_transformer to kwargs so they get passed to mixin
760784 if inference_id is not None :
761785 kwargs ["inference_id" ] = inference_id
786+ if metadata_transformer is not None :
787+ kwargs ["metadata_transformer" ] = metadata_transformer
762788 super ().__init__ (** kwargs )
763789 # Store the ignore flags as instance variables
764790 self ._ignore_llm = ignore_llm
@@ -900,11 +926,16 @@ def __init__(
900926 ignore_retriever = False ,
901927 ignore_agent = False ,
902928 inference_id : Optional [Any ] = None ,
929+ metadata_transformer : Optional [
930+ Callable [[Dict [str , Any ]], Dict [str , Any ]]
931+ ] = None ,
903932 ** kwargs : Any ,
904933 ) -> None :
905- # Add inference_id to kwargs so it gets passed to mixin
934+ # Add both inference_id and metadata_transformer to kwargs so they get passed to mixin
906935 if inference_id is not None :
907936 kwargs ["inference_id" ] = inference_id
937+ if metadata_transformer is not None :
938+ kwargs ["metadata_transformer" ] = metadata_transformer
908939 super ().__init__ (** kwargs )
909940 # Store the ignore flags as instance variables
910941 self ._ignore_llm = ignore_llm
0 commit comments