Skip to content

Commit 865b3a3

Browse files
gustavocidornelas authored and whoseoyster committed
Include prompt in config for OpenAI ChatCompletion monitor
1 parent 71edc0f commit 865b3a3

File tree

1 file changed

+68
-23
lines changed

1 file changed

+68
-23
lines changed

openlayer/llm_monitors.py

Lines changed: 68 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22

33
import logging
44
import time
5-
from typing import Dict, List, Optional
5+
from typing import Dict, List, Optional, Tuple
66

77
import openai
88
import pandas as pd
@@ -204,9 +204,12 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
204204
latency = (time.time() - start_time) * 1000
205205

206206
try:
207-
input_data = self._format_user_messages(kwargs["messages"])
207+
prompt, input_data = self.format_input(kwargs["messages"])
208208
output_data = response.choices[0].message.content.strip()
209209
num_of_tokens = response.usage.total_tokens
210+
config = self.data_config.copy()
211+
config["prompt"] = prompt
212+
config.update({"inputVariableNames": list(input_data.keys())})
210213

211214
self._append_row_to_df(
212215
input_data=input_data,
@@ -215,10 +218,10 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
215218
latency=latency,
216219
)
217220

218-
self._handle_data_publishing()
221+
self._handle_data_publishing(config=config)
219222
# pylint: disable=broad-except
220223
except Exception as e:
221-
logger.error("Failed to track chat request. %s", e)
224+
logger.error("Failed to monitor chat request. %s", e)
222225

223226
return response
224227

@@ -242,7 +245,7 @@ def modified_create_completion(*args, **kwargs):
242245
num_of_tokens = int(response.usage.total_tokens / len(prompts))
243246

244247
self._append_row_to_df(
245-
input_data=input_data,
248+
input_data={"message": input_data},
246249
output_data=output_data,
247250
num_of_tokens=num_of_tokens,
248251
latency=latency,
@@ -251,19 +254,52 @@ def modified_create_completion(*args, **kwargs):
251254
self._handle_data_publishing()
252255
# pylint: disable=broad-except
253256
except Exception as e:
254-
logger.error("Failed to track completion request. %s", e)
257+
logger.error("Failed to monitor completion request. %s", e)
255258

256259
return response
257260

258261
return modified_create_completion
259262

260263
@staticmethod
261-
def _format_user_messages(conversation_list: List[Dict[str, str]]) -> str:
262-
"""Extracts the 'user' messages from the conversation list and returns them
263-
as a single string."""
264-
return "\n".join(
265-
item["content"] for item in conversation_list if item["role"] == "user"
266-
).strip()
264+
def format_input(
265+
messages: List[Dict[str, str]]
266+
) -> Tuple[List[Dict[str, str]], Dict[str, str]]:
267+
"""Formats the input messages.
268+
269+
Returns messages (prompt) replacing the user messages with input variables
270+
in brackets (e.g., ``{{ message_0 }}``) and a dictionary mapping the input variable
271+
names to the original user messages.
272+
273+
Parameters
274+
----------
275+
messages : List[Dict[str, str]]
276+
List of messages that were sent to the chat completion model. Each message
277+
is a dictionary with the following keys:
278+
279+
- ``role``: The role of the message. Can be either ``"user"`` or ``"system"``.
280+
- ``content``: The content of the message.
281+
282+
Returns
283+
-------
284+
Tuple(List[Dict[str, str]], Dict[str, str])
285+
The formatted messages and the mapping from input variable names to the
286+
original user messages.
287+
"""
288+
input_messages = []
289+
input_variables = {}
290+
for i, message in enumerate(messages):
291+
if message["role"] == "user":
292+
input_variable_name = f"message_{i}"
293+
input_messages.append(
294+
{
295+
"role": message["role"],
296+
"content": f"{{{{ {input_variable_name} }}}}",
297+
}
298+
)
299+
input_variables[input_variable_name] = message["content"]
300+
else:
301+
input_messages.append(message)
302+
return input_messages, input_variables
267303

268304
@staticmethod
269305
def _split_list(lst: List, n_parts: int) -> List[List]:
@@ -288,37 +324,46 @@ def _split_list(lst: List, n_parts: int) -> List[List]:
288324
return result
289325

290326
def _append_row_to_df(
291-
self, input_data: str, output_data: str, num_of_tokens: int, latency: float
327+
self,
328+
input_data: Dict[str, str],
329+
output_data: str,
330+
num_of_tokens: int,
331+
latency: float,
292332
) -> None:
293333
"""Appends a row with input/output, number of tokens, and latency to the
294334
df."""
295335
row = pd.DataFrame(
296336
[
297337
{
298-
"input": input_data,
299-
"output": output_data,
300-
"tokens": num_of_tokens,
301-
"latency": latency,
338+
**input_data,
339+
**{
340+
"output": output_data,
341+
"tokens": num_of_tokens,
342+
"latency": latency,
343+
},
302344
}
303345
]
304346
)
305347
if self.accumulate_data:
306348
self.df = pd.concat([self.df, row], ignore_index=True)
307349
else:
308350
self.df = row
309-
self.df = self.df.astype(
310-
{"input": object, "output": object, "tokens": int, "latency": float}
311-
)
312351

313-
def _handle_data_publishing(self) -> None:
352+
# Perform casting
353+
input_columns = [col for col in self.df.columns if col.startswith("message")]
354+
casting_dict = {col: object for col in input_columns}
355+
casting_dict.update({"output": object, "tokens": int, "latency": float})
356+
self.df = self.df.astype(casting_dict)
357+
358+
def _handle_data_publishing(self, config: Optional[Dict[str, any]] = None) -> None:
314359
"""Handle data publishing.
315360
316361
If `publish` is set to True, publish the latest row to Openlayer.
317362
"""
318363
if self.publish:
319364
self.inference_pipeline.stream_data(
320365
stream_data=self.df.tail(1).to_dict(orient="records"),
321-
stream_config=self.data_config,
366+
stream_config=config or self.data_config,
322367
)
323368

324369
def start_monitoring(self) -> None:
@@ -411,7 +456,7 @@ def publish_batch_data(self):
411456
def data_config(self) -> Dict[str, any]:
412457
"""Data config for the df. Used for publishing data to Openlayer."""
413458
return {
414-
"inputVariableNames": ["input"],
459+
"inputVariableNames": ["message"],
415460
"label": "production",
416461
"outputColumnName": "output",
417462
"numOfTokenColumnName": "tokens",

0 commit comments

Comments (0)