@@ -35,7 +35,9 @@ def __init__(self, **kwargs: Any) -> None:
3535 self .metatada : Dict [str , Any ] = kwargs or {}
3636
def on_llm_start(
    self, serialized: Dict[str, Any], prompts: List[str], **kwargs: Any
) -> Any:  # noqa: ARG002
    """Callback fired when an LLM call begins.

    No-op for this handler; all arguments are ignored.
    """
    return None
4143
def on_llm_new_token(self, token: str, **kwargs: Any) -> Any:
    """Callback fired for each streamed token (streaming mode only).

    No-op for this handler; the token is ignored.
    """
    return None
8183
def on_llm_end(
    self, response: langchain_schema.LLMResult, **kwargs: Any  # noqa: ARG002, E501
) -> Any:
    """Record timing, token usage, cost, and output when the LLM finishes.

    Computes latency in milliseconds from ``self.start_time``, reads token
    counts out of ``response.llm_output["token_usage"]`` when present,
    estimates the dollar cost from those counts, appends all generated text
    to ``self.output``, and finally pushes everything onto the trace.
    """
    self.end_time = time.time()
    # time.time() is in seconds; expose latency in milliseconds.
    self.latency = (self.end_time - self.start_time) * 1000

    llm_output = response.llm_output
    if llm_output and "token_usage" in llm_output:
        usage = llm_output["token_usage"]
        self.prompt_tokens = usage.get("prompt_tokens", 0)
        self.completion_tokens = usage.get("completion_tokens", 0)
        self.cost = self._get_cost_estimate(
            num_input_tokens=self.prompt_tokens,
            num_output_tokens=self.completion_tokens,
        )
        self.total_tokens = usage.get("total_tokens", 0)

    # NOTE(review): only "\n " (newline + space) sequences are collapsed to
    # a single space; bare newlines pass through — confirm this is intended.
    self.output += "".join(
        generation.text.replace("\n ", " ")
        for generations in response.generations
        for generation in generations
    )

    self._add_to_trace()
def _get_cost_estimate(
    self, num_input_tokens: int, num_output_tokens: int
) -> Union[float, None]:
    """Return the estimated USD cost for this call's token usage.

    Args:
        num_input_tokens: Prompt tokens, billed at the model's input rate.
        num_output_tokens: Completion tokens, billed at the output rate.

    Returns:
        The cost estimate, or ``None`` when ``self.model`` has no entry in
        ``constants.OPENAI_COST_PER_TOKEN``. (The previous ``-> float``
        annotation was wrong: the unknown-model path returns ``None``.)
    """
    # Single .get() lookup instead of a membership test plus an index.
    cost_per_token = constants.OPENAI_COST_PER_TOKEN.get(self.model)
    if cost_per_token is None:
        return None
    return (
        cost_per_token["input"] * num_input_tokens
        + cost_per_token["output"] * num_output_tokens
    )
108123
109124 def _add_to_trace (self ) -> None :
110125 """Adds to the trace."""
@@ -126,42 +141,56 @@ def _add_to_trace(self) -> None:
126141 metadata = self .metatada ,
127142 )
128143
def on_llm_error(
    self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
    """Callback fired when an LLM call raises.

    No-op for this handler; the error is not recorded.
    """
    return None
132149
def on_chain_start(
    self, serialized: Dict[str, Any], inputs: Dict[str, Any], **kwargs: Any
) -> Any:
    """Callback fired when a chain begins.

    No-op for this handler; all arguments are ignored.
    """
    return None
136155
def on_chain_end(self, outputs: Dict[str, Any], **kwargs: Any) -> Any:
    """Callback fired when a chain finishes.

    No-op for this handler; the outputs are ignored.
    """
    return None
140159
def on_chain_error(
    self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
    """Callback fired when a chain raises.

    No-op for this handler; the error is not recorded.
    """
    return None
144165
def on_tool_start(
    self, serialized: Dict[str, Any], input_str: str, **kwargs: Any
) -> Any:
    """Callback fired when a tool begins.

    No-op for this handler; all arguments are ignored.
    """
    return None
148171
def on_tool_end(self, output: str, **kwargs: Any) -> Any:
    """Callback fired when a tool finishes.

    No-op for this handler; the output is ignored.
    """
    return None
152175
def on_tool_error(
    self, error: Union[Exception, KeyboardInterrupt], **kwargs: Any
) -> Any:
    """Callback fired when a tool raises.

    No-op for this handler; the error is not recorded.
    """
    return None
156181
def on_text(self, text: str, **kwargs: Any) -> Any:
    """Callback fired for arbitrary intermediate text.

    No-op for this handler; the text is ignored.
    """
    return None
160185
def on_agent_action(
    self, action: langchain_schema.AgentAction, **kwargs: Any
) -> Any:
    """Callback fired when the agent takes an action.

    No-op for this handler; the action is ignored.
    """
    return None
164191
def on_agent_finish(
    self, finish: langchain_schema.AgentFinish, **kwargs: Any
) -> Any:
    """Callback fired when the agent run completes.

    No-op for this handler; the finish payload is ignored.
    """
    return None
0 commit comments