
Commit 2cbd1eb

gustavocidornelas authored and whoseoyster committed
Add support for OpenAI assistants
1 parent f939d97 commit 2cbd1eb

File tree

CHANGELOG.md
openlayer/llm_monitors.py
openlayer/services/data_streamer.py

3 files changed: +102 -1 lines changed

CHANGELOG.md

Lines changed: 1 addition & 0 deletions
@@ -8,6 +8,7 @@ and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.
 ## Unreleased
 
 ### Added
+* Added support for OpenAI assistants. The `llm_monitor` now supports monitoring OpenAI assistant runs with the function `monitor_thread_run`.
 * Added the ability to use the `llm_monitor.OpenAIMonitor` as a context manager.
 * Added `openlayer_inference_pipeline_id` as an optional parameter to the `OpenAIMonitor`. This is an alternative to the `openlayer_inference_pipeline_name` and `openlayer_inference_project_name` parameters for identifying the inference pipeline on the platform.
 * Added `monitor_output_only` as an argument to the OpenAI `llm_monitor`. If set to `True`, the monitor will only record the output of the model, and not the input.
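
Taken together, the entries above suggest the following usage pattern. A minimal sketch, assuming an `OpenAIMonitor` constructor that accepts the parameters named in this changelog (the `client` keyword and the placeholder id are assumptions, not verbatim from the library):

```python
import openai

from openlayer import llm_monitors

openai_client = openai.OpenAI()

# Identify the inference pipeline directly by id -- the new alternative to the
# openlayer_inference_pipeline_name / openlayer_inference_project_name pair.
monitor = llm_monitors.OpenAIMonitor(
    client=openai_client,  # assumed parameter name
    openlayer_inference_pipeline_id="<pipeline-id>",  # placeholder
    monitor_output_only=True,  # record only model outputs, not inputs
)

# New: the monitor works as a context manager, so monitoring is switched on
# only for the duration of the block.
with monitor:
    openai_client.chat.completions.create(
        model="gpt-3.5-turbo",
        messages=[{"role": "user", "content": "Hello!"}],
    )
```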

openlayer/llm_monitors.py

Lines changed: 98 additions & 0 deletions
@@ -477,3 +477,101 @@ def data_config(self) -> Dict[str, any]:
     def data(self) -> pd.DataFrame:
         """Dataframe accumulated after monitoring was switched on."""
         return self.df
+
+    def monitor_thread_run(self, run: openai.types.beta.threads.run.Run) -> None:
+        """Monitor a run from an OpenAI assistant.
+
+        Once the run is completed, the thread data is published to Openlayer,
+        along with the latency, cost, and number of tokens used."""
+        self._type_check_run(run)
+
+        # Do nothing if the run is not completed
+        if run.status != "completed":
+            return
+
+        try:
+            # Extract vars
+            run_vars = self._extract_run_vars(run)
+
+            # Convert thread to prompt
+            messages = self.openai_client.beta.threads.messages.list(
+                thread_id=run_vars["openai_thread_id"], order="asc"
+            )
+            populated_prompt = self.thread_messages_to_prompt(messages)
+            prompt, input_variables = self.format_input(populated_prompt)
+
+            # Data
+            input_data = {
+                **input_variables,
+                **{
+                    "output": prompt[-1]["content"],
+                    "tokens": run_vars["total_num_tokens"],
+                    "latency": run_vars["latency"],
+                    "cost": run_vars["cost"],
+                    "thread_id": run_vars["openai_thread_id"],
+                    "assistant_id": run_vars["openai_assistant_id"],
+                    "timestamp": run_vars["timestamp"],
+                },
+            }
+
+            # Config
+            config = self.data_config.copy()
+            config["inputVariableNames"] = list(input_variables.keys())
+            config["prompt"] = prompt[:-1]  # Remove the last message (the output)
+            config["timestampColumnName"] = "timestamp"
+
+            self.data_streamer.stream_data(data=input_data, config=config)
+            print("Data published to Openlayer.")
+        # pylint: disable=broad-except
+        except Exception as e:
+            print(f"Failed to monitor run. {e}")
+
+    def _type_check_run(self, run: openai.types.beta.threads.run.Run) -> None:
+        """Validate the run object."""
+        if not isinstance(run, openai.types.beta.threads.run.Run):
+            raise ValueError(f"Expected a Run object, but got {type(run)}.")
+
+    def _extract_run_vars(
+        self, run: openai.types.beta.threads.run.Run
+    ) -> Dict[str, any]:
+        """Extract the variables from the run object."""
+        return {
+            "openai_thread_id": run.thread_id,
+            "openai_assistant_id": run.assistant_id,
+            "latency": (run.completed_at - run.created_at) * 1000,  # Convert to ms
+            "timestamp": run.created_at,  # Unix timestamp, in seconds
+            "num_input_tokens": run.usage["prompt_tokens"],
+            "num_output_tokens": run.usage["completion_tokens"],
+            "total_num_tokens": run.usage["total_tokens"],
+            "cost": self.get_cost_estimate(
+                model=run.model,
+                num_input_tokens=run.usage["prompt_tokens"],
+                num_output_tokens=run.usage["completion_tokens"],
+            ),
+        }
+
+    @staticmethod
+    def thread_messages_to_prompt(
+        messages: List[openai.types.beta.threads.thread_message.ThreadMessage],
+    ) -> List[Dict[str, str]]:
+        """Given a list of ThreadMessages, return their contents in the `prompt`
+        format, i.e., a list of dicts with 'role' and 'content' keys."""
+        prompt = []
+        for message in list(messages):
+            role = message.role
+            contents = message.content
+
+            for content in contents:
+                content_type = content.type
+                if content_type == "text":
+                    text_content = content.text.value
+                if content_type == "image_file":
+                    text_content = content.image_file.file_id
+
+                prompt.append(
+                    {
+                        "role": role,
+                        "content": text_content,
+                    }
+                )
+        return prompt
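
A hedged usage sketch for the new `monitor_thread_run` method: create a run through the OpenAI assistants beta API, poll until it reaches a terminal state, then hand the run to the monitor. The OpenAI calls follow the public beta API; the `OpenAIMonitor` constructor arguments are assumptions:

```python
import time

import openai

from openlayer import llm_monitors

client = openai.OpenAI()
monitor = llm_monitors.OpenAIMonitor(client=client)  # constructor args assumed

# Standard assistants beta flow: create a thread, then start a run on it.
thread = client.beta.threads.create(
    messages=[{"role": "user", "content": "What is the capital of France?"}]
)
run = client.beta.threads.runs.create(
    thread_id=thread.id,
    assistant_id="<assistant-id>",  # placeholder for a real assistant id
)

# Poll until the run leaves its non-terminal states. monitor_thread_run
# returns without publishing anything unless run.status == "completed".
while run.status in ("queued", "in_progress", "cancelling"):
    time.sleep(1)
    run = client.beta.threads.runs.retrieve(thread_id=thread.id, run_id=run.id)

monitor.monitor_thread_run(run)  # publishes thread data, latency, cost, tokens
```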

openlayer/services/data_streamer.py

Lines changed: 3 additions & 1 deletion
@@ -129,7 +129,9 @@ def _load_inference_pipeline(self) -> None:
             name=self.openlayer_project_name,
             task_type=tasks.TaskType.LLM,
         )
-        inference_pipeline = project.create_inference_pipeline()
+        inference_pipeline = project.create_inference_pipeline(
+            name=self.openlayer_inference_pipeline_name
+        )
 
         self.inference_pipeline = inference_pipeline
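
The fix above matters when the caller supplies a pipeline name: previously `create_inference_pipeline()` was called with no arguments, so the name was silently dropped. A sketch of the call path, assuming the class and parameter names suggested by the attributes in this hunk:

```python
from openlayer.services.data_streamer import DataStreamer  # class name assumed

# Parameter names inferred from the self.* attributes in the hunk above.
streamer = DataStreamer(
    openlayer_project_name="my-llm-project",
    openlayer_inference_pipeline_name="production",
)
# If the project has no pipeline yet, _load_inference_pipeline now creates one
# named "production" instead of an unnamed default.
```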
