22
33# pylint: disable=unused-argument
44import time
5- from typing import Any , Dict , List , Union , Optional
5+ from typing import Any , Dict , List , Optional , Union
66
77from langchain import schema as langchain_schema
88from langchain .callbacks .base import BaseCallbackHandler
@@ -35,9 +35,7 @@ def __init__(self, **kwargs: Any) -> None:
3535 self .metatada : Dict [str , Any ] = kwargs or {}
3636
3737 # noqa arg002
38- def on_llm_start (
39- self , serialized : Dict [str , Any ], prompts : List [str ], ** kwargs : Any
40- ) -> Any :
38+ def on_llm_start (self , serialized : Dict [str , Any ], prompts : List [str ], ** kwargs : Any ) -> Any :
4139 """Run when LLM starts running."""
4240 pass
4341
@@ -81,45 +79,32 @@ def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
8179 """Run on new LLM token. Only available when streaming is enabled."""
8280 pass
8381
84- def on_llm_end (
85- self , response : langchain_schema .LLMResult , ** kwargs : Any # noqa: ARG002, E501
86- ) -> Any :
82+ def on_llm_end (self , response : langchain_schema .LLMResult , ** kwargs : Any ) -> Any : # noqa: ARG002, E501
8783 """Run when LLM ends running."""
8884 self .end_time = time .time ()
8985 self .latency = (self .end_time - self .start_time ) * 1000
9086
9187 if response .llm_output and "token_usage" in response .llm_output :
92- self .prompt_tokens = response .llm_output ["token_usage" ].get (
93- "prompt_tokens" , 0
94- )
95- self .completion_tokens = response .llm_output ["token_usage" ].get (
96- "completion_tokens" , 0
97- )
88+ self .prompt_tokens = response .llm_output ["token_usage" ].get ("prompt_tokens" , 0 )
89+ self .completion_tokens = response .llm_output ["token_usage" ].get ("completion_tokens" , 0 )
9890 self .cost = self ._get_cost_estimate (
9991 num_input_tokens = self .prompt_tokens ,
10092 num_output_tokens = self .completion_tokens ,
10193 )
102- self .total_tokens = response .llm_output ["token_usage" ].get (
103- "total_tokens" , 0
104- )
94+ self .total_tokens = response .llm_output ["token_usage" ].get ("total_tokens" , 0 )
10595
10696 for generations in response .generations :
10797 for generation in generations :
10898 self .output += generation .text .replace ("\n " , " " )
10999
110100 self ._add_to_trace ()
111101
112- def _get_cost_estimate (
113- self , num_input_tokens : int , num_output_tokens : int
114- ) -> float :
102+ def _get_cost_estimate (self , num_input_tokens : int , num_output_tokens : int ) -> float :
115103 """Returns the cost estimate for a given model and number of tokens."""
116104 if self .model not in constants .OPENAI_COST_PER_TOKEN :
117105 return None
118106 cost_per_token = constants .OPENAI_COST_PER_TOKEN [self .model ]
119- return (
120- cost_per_token ["input" ] * num_input_tokens
121- + cost_per_token ["output" ] * num_output_tokens
122- )
107+ return cost_per_token ["input" ] * num_input_tokens + cost_per_token ["output" ] * num_output_tokens
123108
124109 def _add_to_trace (self ) -> None :
125110 """Adds to the trace."""
@@ -141,56 +126,42 @@ def _add_to_trace(self) -> None:
141126 metadata = self .metatada ,
142127 )
143128
144- def on_llm_error (
145- self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any
146- ) -> Any :
129+ def on_llm_error (self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any ) -> Any :
147130 """Run when LLM errors."""
148131 pass
149132
150- def on_chain_start (
151- self , serialized : Dict [str , Any ], inputs : Dict [str , Any ], ** kwargs : Any
152- ) -> Any :
133+ def on_chain_start (self , serialized : Dict [str , Any ], inputs : Dict [str , Any ], ** kwargs : Any ) -> Any :
153134 """Run when chain starts running."""
154135 pass
155136
156137 def on_chain_end (self , outputs : Dict [str , Any ], ** kwargs : Any ) -> Any :
157138 """Run when chain ends running."""
158139 pass
159140
160- def on_chain_error (
161- self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any
162- ) -> Any :
141+ def on_chain_error (self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any ) -> Any :
163142 """Run when chain errors."""
164143 pass
165144
166- def on_tool_start (
167- self , serialized : Dict [str , Any ], input_str : str , ** kwargs : Any
168- ) -> Any :
145+ def on_tool_start (self , serialized : Dict [str , Any ], input_str : str , ** kwargs : Any ) -> Any :
169146 """Run when tool starts running."""
170147 pass
171148
172149 def on_tool_end (self , output : str , ** kwargs : Any ) -> Any :
173150 """Run when tool ends running."""
174151 pass
175152
176- def on_tool_error (
177- self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any
178- ) -> Any :
153+ def on_tool_error (self , error : Union [Exception , KeyboardInterrupt ], ** kwargs : Any ) -> Any :
179154 """Run when tool errors."""
180155 pass
181156
182157 def on_text (self , text : str , ** kwargs : Any ) -> Any :
183158 """Run on arbitrary text."""
184159 pass
185160
186- def on_agent_action (
187- self , action : langchain_schema .AgentAction , ** kwargs : Any
188- ) -> Any :
161+ def on_agent_action (self , action : langchain_schema .AgentAction , ** kwargs : Any ) -> Any :
189162 """Run on agent action."""
190163 pass
191164
192- def on_agent_finish (
193- self , finish : langchain_schema .AgentFinish , ** kwargs : Any
194- ) -> Any :
165+ def on_agent_finish (self , finish : langchain_schema .AgentFinish , ** kwargs : Any ) -> Any :
195166 """Run on agent end."""
196167 pass
0 commit comments