Skip to content

Commit 8ec1410

Browse files
committed
Closes OPEN-5836 Tracing improvements to get Openlayer Assistant traced
1 parent 15b871d commit 8ec1410

18 files changed

+116
-32
lines changed

openlayer/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
project.status()
2222
project.push()
2323
"""
24+
2425
import os
2526
import shutil
2627
import tarfile

openlayer/api.py

Lines changed: 7 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
project_data = self.api.post_request(endpoint, body=payload)
2020
2121
"""
22+
2223
import os
2324
import shutil
2425
from enum import Enum
@@ -174,9 +175,9 @@ def post_request(
174175
return self._api_request(
175176
method="POST",
176177
endpoint=endpoint,
177-
headers=self._headers
178-
if files is None
179-
else self._headers_multipart_form_data,
178+
headers=(
179+
self._headers if files is None else self._headers_multipart_form_data
180+
),
180181
body=body,
181182
files=files,
182183
data=data,
@@ -188,9 +189,9 @@ def put_request(self, endpoint: str, body=None, files=None, data=None):
188189
return self._api_request(
189190
"PUT",
190191
endpoint,
191-
headers=self._headers
192-
if files is None
193-
else self._headers_multipart_form_data,
192+
headers=(
193+
self._headers if files is None else self._headers_multipart_form_data
194+
),
194195
body=body,
195196
files=files,
196197
data=data,

openlayer/constants.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Module for storing constants used throughout the OpenLayer Python Client.
22
"""
3+
34
import os
45

56
import marshmallow as ma

openlayer/exceptions.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
if project is None:
66
raise errors.OpenlayerResourceNotFound(f"Project {project_id} does not exist")
77
"""
8+
89
from typing import Dict
910

1011

openlayer/llm_monitors.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
"""Module with classes for monitoring calls to LLMs."""
22

3+
import json
34
import logging
45
import time
56
import warnings
@@ -149,17 +150,18 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
149150
output_data = output_content.strip()
150151
elif output_function_call or output_tool_calls:
151152
if output_function_call:
152-
function_call = dict(output_function_call)
153+
function_call = {
154+
"name": output_function_call.name,
155+
"arguments": json.loads(output_function_call.arguments),
156+
}
153157
else:
154-
function_call = dict(output_tool_calls[0].function)
155-
metadata = {
156-
"function_call_name": function_call.get("name"),
157-
"function_call_arguments": function_call.get("arguments"),
158-
}
159-
output_data = str(function_call)
158+
function_call = {
159+
"name": output_tool_calls[0].name,
160+
"arguments": json.loads(output_function_call.arguments),
161+
}
162+
output_data = function_call
160163
else:
161164
output_data = None
162-
output_data = response.choices[0].message.content
163165
cost = self.get_cost_estimate(
164166
model=kwargs.get("model"),
165167
num_input_tokens=response.usage.prompt_tokens,
@@ -181,7 +183,6 @@ def modified_create_chat_completion(*args, **kwargs) -> str:
181183
model_parameters=kwargs.get("model_parameters"),
182184
raw_output=response.model_dump(),
183185
provider="OpenAI",
184-
metadata=metadata,
185186
)
186187
# pylint: disable=broad-except
187188
except Exception as e:

openlayer/model_runners/prediction_jobs/classification_prediction_job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Example usage:
99
python classification_prediction_job.py --input /path/to/input.csv --output /path/to/output.csv
1010
"""
11+
1112
import argparse
1213
import logging
1314

openlayer/model_runners/prediction_jobs/regression_prediction_job.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
Example usage:
99
python regression_prediction_job.py --input /path/to/input.csv --output /path/to/output.csv
1010
"""
11+
1112
import argparse
1213
import logging
1314

openlayer/model_runners/tests/test_llm_runners.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@
44
55
pytest test_llm_runners.py
66
"""
7+
78
from typing import Dict
89

910
import anthropic

openlayer/services/data_streamer.py

Lines changed: 46 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -29,28 +29,67 @@ def __init__(
2929
openlayer_inference_pipeline_id: Optional[str] = None,
3030
publish: bool = False,
3131
) -> None:
32-
self.openlayer_api_key = openlayer_api_key or utils.get_env_variable(
32+
self._openlayer_api_key = openlayer_api_key or utils.get_env_variable(
3333
"OPENLAYER_API_KEY"
3434
)
35-
self.openlayer_project_name = openlayer_project_name or utils.get_env_variable(
35+
self._openlayer_project_name = openlayer_project_name or utils.get_env_variable(
3636
"OPENLAYER_PROJECT_NAME"
3737
)
38-
self.openlayer_inference_pipeline_name = (
38+
self._openlayer_inference_pipeline_name = (
3939
openlayer_inference_pipeline_name
4040
or utils.get_env_variable("OPENLAYER_INFERENCE_PIPELINE_NAME")
4141
or "production"
4242
)
43-
self.openlayer_inference_pipeline_id = (
43+
self._openlayer_inference_pipeline_id = (
4444
openlayer_inference_pipeline_id
4545
or utils.get_env_variable("OPENLAYER_INFERENCE_PIPELINE_ID")
4646
)
4747
self.publish = publish
4848

49-
self._validate_attributes()
50-
5149
# Lazy load the inference pipeline
5250
self.inference_pipeline = None
5351

52+
@property
53+
def openlayer_api_key(self) -> Optional[str]:
54+
"""The Openlayer API key."""
55+
return self._get_openlayer_attribute("_openlayer_api_key", "OPENLAYER_API_KEY")
56+
57+
@property
58+
def openlayer_project_name(self) -> Optional[str]:
59+
"""The name of the project on Openlayer."""
60+
return self._get_openlayer_attribute(
61+
"_openlayer_project_name", "OPENLAYER_PROJECT_NAME"
62+
)
63+
64+
@property
65+
def openlayer_inference_pipeline_name(self) -> Optional[str]:
66+
"""The name of the inference pipeline on Openlayer."""
67+
return self._get_openlayer_attribute(
68+
"_openlayer_inference_pipeline_name", "OPENLAYER_INFERENCE_PIPELINE_NAME"
69+
)
70+
71+
@property
72+
def openlayer_inference_pipeline_id(self) -> Optional[str]:
73+
"""The id of the inference pipeline on Openlayer."""
74+
return self._get_openlayer_attribute(
75+
"_openlayer_inference_pipeline_id", "OPENLAYER_INFERENCE_PIPELINE_ID"
76+
)
77+
78+
def _get_openlayer_attribute(
79+
self, attribute_name: str, env_variable: str
80+
) -> Optional[str]:
81+
"""A helper method to fetch an Openlayer attribute value.
82+
83+
Args:
84+
attribute_name: The name of the attribute in this class.
85+
env_variable: The name of the environment variable to fetch.
86+
"""
87+
attribute_value = getattr(self, attribute_name, None)
88+
if not attribute_value:
89+
attribute_value = utils.get_env_variable(env_variable)
90+
setattr(self, attribute_name, attribute_value)
91+
return attribute_value
92+
5493
def _validate_attributes(self) -> None:
5594
"""Granular validation of the arguments."""
5695
if self.publish:
@@ -97,6 +136,7 @@ def stream_data(self, data: Dict[str, any], config: Dict[str, any]) -> None:
97136
config: The configuration for the data stream.
98137
"""
99138

139+
self._validate_attributes()
100140
self._check_inference_pipeline_ready()
101141
self.inference_pipeline.stream_data(stream_data=data, stream_config=config)
102142
logger.info("Data streamed to Openlayer.")

openlayer/tracing/steps.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import time
44
from typing import Any, Dict, Optional
55

6+
from .. import utils
67
from . import enums
78

89

@@ -40,6 +41,7 @@ def add_nested_step(self, nested_step: "Step") -> None:
4041

4142
def log(self, **kwargs: Any) -> None:
4243
"""Logs step data."""
44+
kwargs = utils.json_serialize(kwargs)
4345
for key, value in kwargs.items():
4446
if hasattr(self, key):
4547
setattr(self, key, value)

0 commit comments

Comments
 (0)