@@ -177,44 +177,118 @@ def _get_modified_create_chat_completion(self) -> callable:
177177 """Returns a modified version of the create method for openai.ChatCompletion."""
178178
179179 def modified_create_chat_completion (* args , ** kwargs ) -> str :
180- start_time = time .time ()
181- response = self .create_chat_completion (* args , ** kwargs )
182- latency = (time .time () - start_time ) * 1000
180+ stream = kwargs .get ("stream" , False )
183181
184- try :
185- # Extract data
186- prompt , input_data = self .format_input (kwargs ["messages" ])
187- output_data = response .choices [0 ].message .content .strip ()
188- num_of_tokens = response .usage .total_tokens
189- cost = self .get_cost_estimate (
190- model = kwargs .get ("model" ),
191- num_input_tokens = response .usage .prompt_tokens ,
192- num_output_tokens = response .usage .completion_tokens ,
193- )
182+ if not stream :
183+ start_time = time .time ()
184+ response = self .create_chat_completion (* args , ** kwargs )
185+ latency = (time .time () - start_time ) * 1000
194186
195- # Prepare config
196- config = self .data_config .copy ()
197- config ["prompt" ] = prompt
198- if not self .monitor_output_only :
199- config .update ({"inputVariableNames" : list (input_data .keys ())})
200-
201- self ._append_row_to_df (
202- input_data = input_data ,
203- output_data = output_data ,
204- num_of_tokens = num_of_tokens ,
205- latency = latency ,
206- cost = cost ,
207- )
187+ try :
188+ # Extract data
189+ prompt , input_data = self .format_input (kwargs ["messages" ])
190+ output_data = response .choices [0 ].message .content .strip ()
191+ num_of_tokens = response .usage .total_tokens
192+ cost = self .get_cost_estimate (
193+ model = kwargs .get ("model" ),
194+ num_input_tokens = response .usage .prompt_tokens ,
195+ num_output_tokens = response .usage .completion_tokens ,
196+ )
208197
209- self .data_streamer .stream_data (
210- data = self .df .tail (1 ).to_dict (orient = "records" ),
211- config = config ,
212- )
213- # pylint: disable=broad-except
214- except Exception as e :
215- logger .error ("Failed to monitor chat request. %s" , e )
198+ # Prepare config
199+ config = self .data_config .copy ()
200+ config ["prompt" ] = prompt
201+ if not self .monitor_output_only :
202+ config .update ({"inputVariableNames" : list (input_data .keys ())})
216203
217- return response
204+ self ._append_row_to_df (
205+ input_data = input_data ,
206+ output_data = output_data ,
207+ num_of_tokens = num_of_tokens ,
208+ latency = latency ,
209+ cost = cost ,
210+ )
211+
212+ self .data_streamer .stream_data (
213+ data = self .df .tail (1 ).to_dict (orient = "records" ),
214+ config = config ,
215+ )
216+ # pylint: disable=broad-except
217+ except Exception as e :
218+ logger .error ("Failed to monitor chat request. %s" , e )
219+
220+ return response
221+ else :
222+ chunks = self .create_chat_completion (* args , ** kwargs )
223+
224+ def stream_chunks ():
225+ collected_messages = []
226+ start_time = time .time ()
227+ first_token_time = None
228+ num_of_completion_tokens = None
229+ try :
230+ i = 0
231+ for i , chunk in enumerate (chunks ):
232+ if i == 0 :
233+ first_token_time = time .time ()
234+ collected_messages .append (chunk .choices [0 ].delta .content )
235+ yield chunk
236+ if i > 0 :
237+ num_of_completion_tokens = i + 1
238+ # pylint: disable=broad-except
239+ except Exception as e :
240+ logger .error ("Failed to monitor chat request. %s" , e )
241+ finally :
242+ try :
243+ # Extract data
244+ prompt , input_data = self .format_input (kwargs ["messages" ])
245+ collected_messages = [
246+ m for m in collected_messages if m is not None
247+ ]
248+ output_data = "" .join (collected_messages )
249+ completion_cost = self .get_cost_estimate (
250+ model = kwargs .get ("model" ),
251+ num_input_tokens = 0 ,
252+ num_output_tokens = (
253+ num_of_completion_tokens
254+ if num_of_completion_tokens
255+ else 0
256+ ),
257+ )
258+ latency = (time .time () - start_time ) * 1000
259+
260+ # Prepare config
261+ config = self .data_config .copy ()
262+ config ["prompt" ] = prompt
263+ if not self .monitor_output_only :
264+ config .update (
265+ {"inputVariableNames" : list (input_data .keys ())}
266+ )
267+
268+ self ._append_row_to_df (
269+ input_data = input_data ,
270+ output_data = output_data ,
271+ num_of_tokens = num_of_completion_tokens ,
272+ latency = latency ,
273+ cost = completion_cost ,
274+ time_to_first_token = (
275+ (first_token_time - start_time ) * 1000
276+ if first_token_time
277+ else None
278+ ),
279+ completion_tokens = num_of_completion_tokens ,
280+ completion_cost = completion_cost ,
281+ )
282+
283+ self .data_streamer .stream_data (
284+ data = self .df .tail (1 ).to_dict (orient = "records" ),
285+ config = config ,
286+ )
287+ # pylint: disable=broad-except
288+ except Exception as e :
289+ logger .error ("Failed to monitor chat request. %s" , e )
290+
291+ return stream_chunks ()
218292
219293 return modified_create_chat_completion
220294
@@ -348,9 +422,10 @@ def _append_row_to_df(
348422 self ,
349423 input_data : Dict [str , str ],
350424 output_data : str ,
351- num_of_tokens : int ,
352425 latency : float ,
426+ num_of_tokens : int ,
353427 cost : float ,
428+ ** kwargs ,
354429 ) -> None :
355430 """Appends a row with input/output, number of tokens, and latency to the
356431 df."""
@@ -367,6 +442,7 @@ def _append_row_to_df(
367442 "latency" : latency ,
368443 "cost" : cost ,
369444 },
445+ ** kwargs ,
370446 }
371447 ]
372448 )
0 commit comments