22
33import logging
44import time
5- from typing import Dict , List , Optional
5+ from typing import Dict , List , Optional , Tuple
66
77import openai
88import pandas as pd
@@ -204,9 +204,12 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
204204 latency = (time .time () - start_time ) * 1000
205205
206206 try :
207- input_data = self ._format_user_messages (kwargs ["messages" ])
207+ prompt , input_data = self .format_input (kwargs ["messages" ])
208208 output_data = response .choices [0 ].message .content .strip ()
209209 num_of_tokens = response .usage .total_tokens
210+ config = self .data_config .copy ()
211+ config ["prompt" ] = prompt
212+ config .update ({"inputVariableNames" : list (input_data .keys ())})
210213
211214 self ._append_row_to_df (
212215 input_data = input_data ,
@@ -215,10 +218,10 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
215218 latency = latency ,
216219 )
217220
218- self ._handle_data_publishing ()
221+ self ._handle_data_publishing (config = config )
219222 # pylint: disable=broad-except
220223 except Exception as e :
221- logger .error ("Failed to track chat request. %s" , e )
224+ logger .error ("Failed to monitor chat request. %s" , e )
222225
223226 return response
224227
@@ -242,7 +245,7 @@ def modified_create_completion(*args, **kwargs):
242245 num_of_tokens = int (response .usage .total_tokens / len (prompts ))
243246
244247 self ._append_row_to_df (
245- input_data = input_data ,
248+ input_data = { "message" : input_data } ,
246249 output_data = output_data ,
247250 num_of_tokens = num_of_tokens ,
248251 latency = latency ,
@@ -251,19 +254,52 @@ def modified_create_completion(*args, **kwargs):
251254 self ._handle_data_publishing ()
252255 # pylint: disable=broad-except
253256 except Exception as e :
254- logger .error ("Failed to track completion request. %s" , e )
257+ logger .error ("Failed to monitor completion request. %s" , e )
255258
256259 return response
257260
258261 return modified_create_completion
259262
260263 @staticmethod
261- def _format_user_messages (conversation_list : List [Dict [str , str ]]) -> str :
262- """Extracts the 'user' messages from the conversation list and returns them
263- as a single string."""
264- return "\n " .join (
265- item ["content" ] for item in conversation_list if item ["role" ] == "user"
266- ).strip ()
264+ def format_input (
265+ messages : List [Dict [str , str ]]
266+ ) -> Tuple [List [Dict [str , str ]], Dict [str , str ]]:
267+ """Formats the input messages.
268+
269+ Returns messages (prompt) replacing the user messages with input variables
270+ in brackets (e.g., ``{{ message_0 }}``) and a dictionary mapping the input variable
271+ names to the original user messages.
272+
273+ Parameters
274+ ----------
275+ messages : List[Dict[str, str]]
276+ List of messages that were sent to the chat completion model. Each message
277+ is a dictionary with the following keys:
278+
279+ - ``role``: The role of the message. Can be either ``"user"`` or ``"system"``.
280+ - ``content``: The content of the message.
281+
282+ Returns
283+ -------
284+ Tuple(List[Dict[str, str]], Dict[str, str])
285+ The formatted messages and the mapping from input variable names to the
286+ original user messages.
287+ """
288+ input_messages = []
289+ input_variables = {}
290+ for i , message in enumerate (messages ):
291+ if message ["role" ] == "user" :
292+ input_variable_name = f"message_{ i } "
293+ input_messages .append (
294+ {
295+ "role" : message ["role" ],
296+ "content" : f"{{{{ { input_variable_name } }}}}" ,
297+ }
298+ )
299+ input_variables [input_variable_name ] = message ["content" ]
300+ else :
301+ input_messages .append (message )
302+ return input_messages , input_variables
267303
268304 @staticmethod
269305 def _split_list (lst : List , n_parts : int ) -> List [List ]:
@@ -288,37 +324,46 @@ def _split_list(lst: List, n_parts: int) -> List[List]:
288324 return result
289325
290326 def _append_row_to_df (
291- self , input_data : str , output_data : str , num_of_tokens : int , latency : float
327+ self ,
328+ input_data : Dict [str , str ],
329+ output_data : str ,
330+ num_of_tokens : int ,
331+ latency : float ,
292332 ) -> None :
293333 """Appends a row with input/output, number of tokens, and latency to the
294334 df."""
295335 row = pd .DataFrame (
296336 [
297337 {
298- "input" : input_data ,
299- "output" : output_data ,
300- "tokens" : num_of_tokens ,
301- "latency" : latency ,
338+ ** input_data ,
339+ ** {
340+ "output" : output_data ,
341+ "tokens" : num_of_tokens ,
342+ "latency" : latency ,
343+ },
302344 }
303345 ]
304346 )
305347 if self .accumulate_data :
306348 self .df = pd .concat ([self .df , row ], ignore_index = True )
307349 else :
308350 self .df = row
309- self .df = self .df .astype (
310- {"input" : object , "output" : object , "tokens" : int , "latency" : float }
311- )
312351
313- def _handle_data_publishing (self ) -> None :
352+ # Perform casting
353+ input_columns = [col for col in self .df .columns if col .startswith ("message" )]
354+ casting_dict = {col : object for col in input_columns }
355+ casting_dict .update ({"output" : object , "tokens" : int , "latency" : float })
356+ self .df = self .df .astype (casting_dict )
357+
358+ def _handle_data_publishing (self , config : Optional [Dict [str , any ]] = None ) -> None :
314359 """Handle data publishing.
315360
316361 If `publish` is set to True, publish the latest row to Openlayer.
317362 """
318363 if self .publish :
319364 self .inference_pipeline .stream_data (
320365 stream_data = self .df .tail (1 ).to_dict (orient = "records" ),
321- stream_config = self .data_config ,
366+ stream_config = config or self .data_config ,
322367 )
323368
324369 def start_monitoring (self ) -> None :
@@ -411,7 +456,7 @@ def publish_batch_data(self):
411456 def data_config (self ) -> Dict [str , any ]:
412457 """Data config for the df. Used for publishing data to Openlayer."""
413458 return {
414- "inputVariableNames" : ["input " ],
459+ "inputVariableNames" : ["message " ],
415460 "label" : "production" ,
416461 "outputColumnName" : "output" ,
417462 "numOfTokenColumnName" : "tokens" ,
0 commit comments