@@ -492,7 +492,16 @@ def _handle_llm_end(
             return
 
         output = self._extract_output(response)
-        token_info = self._extract_token_info(response)
+
+        # Only extract token info if it hasn't been set during streaming
+        step = self.steps[run_id]
+        token_info = {}
+        if not (
+            hasattr(step, "prompt_tokens")
+            and step.prompt_tokens is not None
+            and step.prompt_tokens > 0
+        ):
+            token_info = self._extract_token_info(response)
 
         self._end_step(
             run_id=run_id,
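
The guard above keeps `_handle_llm_end` from clobbering token counts that the new streaming path (added below) has already logged on the step: streamed responses often carry no aggregate usage on the final `LLMResult`, so re-extracting at the end would reset the counts to zero. For context, a minimal sketch of what an end-of-run extractor such as `_extract_token_info` typically reads; the helper's real body is not part of this diff, so treat it as an assumption:

    from typing import Any, Dict

    def extract_token_info_sketch(response: Any) -> Dict[str, int]:
        # Assumption: mirrors _extract_token_info, which this diff does not show.
        # Many chat models report aggregate usage on LLMResult.llm_output["token_usage"];
        # streamed runs frequently leave it empty, hence the guard above.
        usage = (getattr(response, "llm_output", None) or {}).get("token_usage", {}) or {}
        return {
            "prompt_tokens": usage.get("prompt_tokens", 0),
            "completion_tokens": usage.get("completion_tokens", 0),
            "tokens": usage.get("total_tokens", 0),
        }
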
@@ -763,6 +772,35 @@ def _handle_retriever_error(
         """Common logic for retriever error."""
         self._end_step(run_id=run_id, parent_run_id=parent_run_id, error=str(error))
 
+    def _handle_llm_new_token(self, token: str, **kwargs: Any) -> Any:
+        """Common logic for LLM new token."""
+        # Safely check for chunk and usage_metadata
+        chunk = kwargs.get("chunk")
+        if (
+            chunk
+            and hasattr(chunk, "message")
+            and hasattr(chunk.message, "usage_metadata")
+        ):
+            usage = chunk.message.usage_metadata
+
+            # Only proceed if usage is not None
+            if usage:
+                # Extract run_id from kwargs (should be provided by LangChain)
+                run_id = kwargs.get("run_id")
+                if run_id and run_id in self.steps:
+                    # Convert usage to the expected format like _extract_token_info does
+                    token_info = {
+                        "prompt_tokens": usage.get("input_tokens", 0),
+                        "completion_tokens": usage.get("output_tokens", 0),
+                        "tokens": usage.get("total_tokens", 0),
+                    }
+
+                    # Update the step with token usage information
+                    step = self.steps[run_id]
+                    if isinstance(step, steps.ChatCompletionStep):
+                        step.log(**token_info)
+        return
+
 
 class OpenlayerHandler(OpenlayerHandlerMixin, BaseCallbackHandlerClass):  # type: ignore[misc]
     """LangChain callback handler that logs to Openlayer."""
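
`_handle_llm_new_token` relies on LangChain passing each streamed chunk to the callback via the `chunk` keyword. A minimal sketch of the shape it inspects, assuming langchain-core >= 0.2 (where `AIMessageChunk` carries `usage_metadata`); `handler` stands in for an `OpenlayerHandler` that already has an active step for `run_id`:

    from uuid import uuid4
    from langchain_core.messages import AIMessageChunk
    from langchain_core.outputs import ChatGenerationChunk

    run_id = uuid4()
    # Providers typically attach usage_metadata only to the final streamed chunk;
    # earlier chunks leave it as None, which the `if usage:` check skips.
    final_chunk = ChatGenerationChunk(
        message=AIMessageChunk(
            content="",
            usage_metadata={"input_tokens": 12, "output_tokens": 48, "total_tokens": 60},
        )
    )
    handler.on_llm_new_token("", chunk=final_chunk, run_id=run_id)
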
@@ -848,7 +886,7 @@ def on_llm_error(
 
     def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
         """Run on new LLM token. Only available when streaming is enabled."""
-        pass
+        return self._handle_llm_new_token(token, **kwargs)
 
     def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
@@ -1137,7 +1175,7 @@ async def on_llm_error(
         return self._handle_llm_error(error, **kwargs)
 
     async def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
-        pass
+        return self._handle_llm_new_token(token, **kwargs)
 
     async def on_chain_start(
         self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
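
With both the sync and async handlers wired up, enabling streaming is enough for token usage to flow through `on_llm_new_token`. A usage sketch; the import path is assumed from the Openlayer SDK layout, and the model settings are placeholders:

    from langchain_openai import ChatOpenAI
    from openlayer.lib.integrations.langchain_callback import OpenlayerHandler

    handler = OpenlayerHandler()
    # stream_usage=True asks the provider to include usage_metadata in the
    # streamed chunks that _handle_llm_new_token reads.
    llm = ChatOpenAI(model="gpt-4o-mini", streaming=True, stream_usage=True)
    llm.invoke("Hello!", config={"callbacks": [handler]})
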