@@ -9,8 +9,8 @@
 
 from ..tracing import tracer
 
-LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP = {"openai-chat": "OpenAI"}
-PROVIDER_TO_STEP_NAME = {"OpenAI": "OpenAI Chat Completion"}
+LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP = {"openai-chat": "OpenAI", "chat-ollama": "Ollama"}
+PROVIDER_TO_STEP_NAME = {"OpenAI": "OpenAI Chat Completion", "Ollama": "Ollama Chat Completion"}
 
 
 class OpenlayerHandler(BaseCallbackHandler):
@@ -45,13 +45,16 @@ def on_chat_model_start(
     ) -> Any:
         """Run when Chat Model starts running."""
         self.model_parameters = kwargs.get("invocation_params", {})
+        self.metadata = kwargs.get("metadata", {})
 
         provider = self.model_parameters.get("_type", None)
         if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP:
             self.provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP[provider]
             self.model_parameters.pop("_type")
+            self.metadata.pop("ls_provider", None)
+            self.metadata.pop("ls_model_type", None)
 
-        self.model = self.model_parameters.get("model_name", None)
+        self.model = self.model_parameters.get("model_name", None) or self.metadata.pop("ls_model_name", None)
         self.output = ""
         self.prompt = self._langchain_messages_to_prompt(messages)
         self.start_time = time.time()
@@ -82,17 +85,32 @@ def on_llm_end(self, response: langchain_schema.LLMResult, **kwargs: Any) -> Any
         self.end_time = time.time()
         self.latency = (self.end_time - self.start_time) * 1000
 
-        if response.llm_output and "token_usage" in response.llm_output:
-            self.prompt_tokens = response.llm_output["token_usage"].get("prompt_tokens", 0)
-            self.completion_tokens = response.llm_output["token_usage"].get("completion_tokens", 0)
-            self.total_tokens = response.llm_output["token_usage"].get("total_tokens", 0)
+        if self.provider == "OpenAI":
+            self._openai_token_information(response)
+        elif self.provider == "Ollama":
+            self._ollama_token_information(response)
 
         for generations in response.generations:
             for generation in generations:
                 self.output += generation.text.replace("\n", " ")
 
         self._add_to_trace()
 
+    def _openai_token_information(self, response: langchain_schema.LLMResult) -> None:
+        """Extracts OpenAI's token information."""
+        if response.llm_output and "token_usage" in response.llm_output:
+            self.prompt_tokens = response.llm_output["token_usage"].get("prompt_tokens", 0)
+            self.completion_tokens = response.llm_output["token_usage"].get("completion_tokens", 0)
+            self.total_tokens = response.llm_output["token_usage"].get("total_tokens", 0)
+
+    def _ollama_token_information(self, response: langchain_schema.LLMResult) -> None:
+        """Extracts Ollama's token information."""
+        generation_info = response.generations[0][0].generation_info
+        if generation_info:
+            self.prompt_tokens = generation_info.get("prompt_eval_count", 0)
+            self.completion_tokens = generation_info.get("eval_count", 0)
+            self.total_tokens = self.prompt_tokens + self.completion_tokens
+
     def _add_to_trace(self) -> None:
         """Adds to the trace."""
         name = PROVIDER_TO_STEP_NAME.get(self.provider, "Chat Completion Model")
@@ -109,7 +127,7 @@ def _add_to_trace(self) -> None:
             model_parameters=self.model_parameters,
             prompt_tokens=self.prompt_tokens,
             completion_tokens=self.completion_tokens,
-            metadata=self.metatada,
+            metadata=self.metadata,
         )
 
     def on_llm_error(self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any) -> Any:
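For context, a minimal usage sketch of the new Ollama path. The import paths are assumptions (the diff doesn't show where `OpenlayerHandler` is exported from, and `ChatOllama` is taken from `langchain_community` here), so treat this as illustrative rather than canonical:

```python
from langchain_community.chat_models import ChatOllama

# Assumed import path for the handler; adjust to wherever OpenlayerHandler
# is actually exported from in this package.
from openlayer.lib.integrations.langchain_callback import OpenlayerHandler

handler = OpenlayerHandler()

# Per the new LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP entry, the handler maps
# the "chat-ollama" _type from the invocation params to provider "Ollama",
# so the step is traced as "Ollama Chat Completion".
chat = ChatOllama(model="llama3")
chat.invoke("What is the capital of France?", config={"callbacks": [handler]})
```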
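And a quick sanity check of `_ollama_token_information` with a hand-built `LLMResult` (a sketch assuming `OpenlayerHandler()` takes no required constructor arguments; `Generation`/`LLMResult` are imported from `langchain_core.outputs` here, and the `prompt_eval_count`/`eval_count` fields mirror what the new method reads out of `generation_info`):

```python
from langchain_core.outputs import Generation, LLMResult

# Hand-built result mimicking an Ollama generation: token counts live in
# generation_info under prompt_eval_count / eval_count.
result = LLMResult(
    generations=[[
        Generation(
            text="Paris.",
            generation_info={"prompt_eval_count": 26, "eval_count": 4},
        )
    ]]
)

handler = OpenlayerHandler()  # assumes a no-arg constructor
handler._ollama_token_information(result)

assert handler.prompt_tokens == 26
assert handler.completion_tokens == 4
assert handler.total_tokens == 30
```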