99
1010from ..tracing import tracer
1111
# Maps LangChain's internal LLM type identifier (the `_type` invocation
# parameter) to the provider name Openlayer reports in traces.
LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP = {
    "openai-chat": "OpenAI",
    "chat-ollama": "Ollama",
    "vertexai": "Google",
}

# Human-readable trace step names, keyed by Openlayer provider name.
# Keys must stay in sync with the values of LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP.
PROVIDER_TO_STEP_NAME = {
    "OpenAI": "OpenAI Chat Completion",
    "Ollama": "Ollama Chat Completion",
    "Google": "Google Vertex AI Chat Completion",
}
1418
1519
1620class OpenlayerHandler (BaseCallbackHandler ):
@@ -29,13 +33,28 @@ def __init__(self, **kwargs: Any) -> None:
2933 self .prompt_tokens : int = None
3034 self .completion_tokens : int = None
3135 self .total_tokens : int = None
32- self .output : str = None
33- self .metatada : Dict [str , Any ] = kwargs or {}
36+ self .output : str = ""
37+ self .metadata : Dict [str , Any ] = kwargs or {}
3438
3539 # noqa arg002
3640 def on_llm_start (self , serialized : Dict [str , Any ], prompts : List [str ], ** kwargs : Any ) -> Any :
3741 """Run when LLM starts running."""
38- pass
42+ self ._initialize_run (kwargs )
43+ self .prompt = [{"role" : "user" , "content" : text } for text in prompts ]
44+ self .start_time = time .time ()
45+
46+ def _initialize_run (self , kwargs : Dict [str , Any ]) -> None :
47+ """Initializes an LLM (or Chat) run, extracting the provider, model name,
48+ and other metadata."""
49+ self .model_parameters = kwargs .get ("invocation_params" , {})
50+ metadata = kwargs .get ("metadata" , {})
51+
52+ provider = self .model_parameters .pop ("_type" , None )
53+ if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP :
54+ self .provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP [provider ]
55+
56+ self .model = self .model_parameters .get ("model_name" , None ) or metadata .get ("ls_model_name" , None )
57+ self .output = ""
3958
4059 def on_chat_model_start (
4160 self ,
@@ -44,18 +63,7 @@ def on_chat_model_start(
4463 ** kwargs : Any ,
4564 ) -> Any :
4665 """Run when Chat Model starts running."""
47- self .model_parameters = kwargs .get ("invocation_params" , {})
48- self .metadata = kwargs .get ("metadata" , {})
49-
50- provider = self .model_parameters .get ("_type" , None )
51- if provider in LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP :
52- self .provider = LANGCHAIN_TO_OPENLAYER_PROVIDER_MAP [provider ]
53- self .model_parameters .pop ("_type" )
54- self .metadata .pop ("ls_provider" , None )
55- self .metadata .pop ("ls_model_type" , None )
56-
57- self .model = self .model_parameters .get ("model_name" , None ) or self .metadata .pop ("ls_model_name" , None )
58- self .output = ""
66+ self ._initialize_run (kwargs )
5967 self .prompt = self ._langchain_messages_to_prompt (messages )
6068 self .start_time = time .time ()
6169
@@ -83,18 +91,20 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
8391 def on_llm_end (self , response : langchain_schema .LLMResult , ** kwargs : Any ) -> Any : # noqa: ARG002, E501
8492 """Run when LLM ends running."""
8593 self .end_time = time .time ()
86- self .latency = (self .end_time - self .start_time ) * 1000
94+ self .latency = (self .end_time - self .start_time ) * 1000 # in milliseconds
95+
96+ self ._extract_token_information (response = response )
97+ self ._extract_output (response = response )
98+ self ._add_to_trace ()
8799
100+ def _extract_token_information (self , response : langchain_schema .LLMResult ) -> None :
101+ """Extract token information based on provider."""
88102 if self .provider == "OpenAI" :
89103 self ._openai_token_information (response )
90104 elif self .provider == "Ollama" :
91105 self ._ollama_token_information (response )
92-
93- for generations in response .generations :
94- for generation in generations :
95- self .output += generation .text .replace ("\n " , " " )
96-
97- self ._add_to_trace ()
106+ elif self .provider == "Google" :
107+ self ._google_token_information (response )
98108
99109 def _openai_token_information (self , response : langchain_schema .LLMResult ) -> None :
100110 """Extracts OpenAI's token information."""
@@ -111,6 +121,20 @@ def _ollama_token_information(self, response: langchain_schema.LLMResult) -> Non
111121 self .completion_tokens = generation_info .get ("eval_count" , 0 )
112122 self .total_tokens = self .prompt_tokens + self .completion_tokens
113123
124+ def _google_token_information (self , response : langchain_schema .LLMResult ) -> None :
125+ """Extracts Google Vertex AI token information."""
126+ usage_metadata = response .generations [0 ][0 ].generation_info ["usage_metadata" ]
127+ if usage_metadata :
128+ self .prompt_tokens = usage_metadata .get ("prompt_token_count" , 0 )
129+ self .completion_tokens = usage_metadata .get ("candidates_token_count" , 0 )
130+ self .total_tokens = usage_metadata .get ("total_token_count" , 0 )
131+
132+ def _extract_output (self , response : langchain_schema .LLMResult ) -> None :
133+ """Extracts the output from the response."""
134+ for generations in response .generations :
135+ for generation in generations :
136+ self .output += generation .text .replace ("\n " , " " )
137+
114138 def _add_to_trace (self ) -> None :
115139 """Adds to the trace."""
116140 name = PROVIDER_TO_STEP_NAME .get (self .provider , "Chat Completion Model" )
0 commit comments