Skip to content

Commit 4ac7687

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
Completes OPEN-5122 Support latencyColumnName and numOfTokensColumnName
1 parent b3d43af commit 4ac7687

File tree

2 files changed

+63
-1
lines changed

2 files changed

+63
-1
lines changed

openlayer/schemas.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -96,6 +96,11 @@ class BaseDatasetSchema(ma.Schema):
9696
load_default="en",
9797
validate=LANGUAGE_CODE_REGEX,
9898
)
99+
latencyColumnName = ma.fields.Str(
100+
validate=COLUMN_NAME_VALIDATION_LIST,
101+
allow_none=True,
102+
load_default=None,
103+
)
99104
metadata = ma.fields.Dict(allow_none=True, load_default={})
100105
sep = ma.fields.Str(load_default=",")
101106
timestampColumnName = ma.fields.Str(
@@ -185,6 +190,11 @@ class LLMOutputSchema(BaseDatasetSchema):
185190
groundTruthColumnName = ma.fields.Str(
186191
validate=COLUMN_NAME_VALIDATION_LIST, allow_none=True, load_default=None
187192
)
193+
numOfTokenColumnName = ma.fields.Str(
194+
validate=COLUMN_NAME_VALIDATION_LIST,
195+
allow_none=True,
196+
load_default=None,
197+
)
188198
outputColumnName = ma.fields.Str(
189199
validate=COLUMN_NAME_VALIDATION_LIST,
190200
allow_none=True,

openlayer/validators/dataset_validators.py

Lines changed: 53 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -198,11 +198,13 @@ def _validate_dataset_and_config_consistency(self):
198198
# Dataset-wide validations
199199
self._validate_dataset_dtypes()
200200

201-
# Timestamps and id validations
201+
# Timestamps, id, and latency validations
202202
if self.dataset_config.get("timestampColumnName"):
203203
self._validate_timestamps()
204204
if self.dataset_config.get("inferenceIdColumnName"):
205205
self._validate_inference_ids()
206+
if self.dataset_config.get("latencyColumnName"):
207+
self._validate_latencies()
206208

207209
self._validate_inputs()
208210
self._validate_outputs()
@@ -297,6 +299,35 @@ def _validate_inference_ids(self):
297299
"Please make sure that the inference ids are unique."
298300
)
299301

302+
def _validate_latencies(self):
303+
"""Checks if the latencies are in the correct format."""
304+
latency_column_name = self.dataset_config.get("latencyColumnName")
305+
if latency_column_name not in self.dataset_df.columns:
306+
self.failed_validations.append(
307+
f"The latency column `{latency_column_name}` specified as "
308+
"`latencyColumnName` is not in the dataset."
309+
)
310+
else:
311+
# Validate if values in the latency column are numbers (ints or floats)
312+
if not self._values_are_numbers(self.dataset_df, latency_column_name):
313+
self.failed_validations.append(
314+
f"The latencies in the column `{latency_column_name}` specified"
315+
" as `latencyColumnName` are not in the correct format. "
316+
"Please make sure that the dtype of the column with the latencies "
317+
"is one of int32, int64, float32, or float64."
318+
)
319+
320+
def _values_are_numbers(self, dataset_df: pd.DataFrame, column_name: str) -> bool:
321+
"""Checks whether the values in the column are numbers (ints or floats)."""
322+
if dataset_df[column_name].dtype.name in (
323+
"int64",
324+
"int32",
325+
"float32",
326+
"float64",
327+
):
328+
return True
329+
return False
330+
300331
@abstractmethod
301332
def _validate_inputs(self):
302333
"""To be implemented by InputValidator child classes."""
@@ -717,6 +748,7 @@ def _validate_outputs(self):
717748
"""Validates the LLM outputs (i.e., ground truth and output)."""
718749
self.ground_truth_column_name = self.dataset_config.get("groundTruthColumnName")
719750
self.output_column_name = self.dataset_config.get("outputColumnName")
751+
self.num_of_token_column_name = self.dataset_config.get("numOfTokenColumnName")
720752

721753
if self.ground_truth_column_name:
722754
self._validate_ground_truth()
@@ -727,6 +759,9 @@ def _validate_outputs(self):
727759
if self.ground_truth_column_name and self.output_column_name:
728760
self._validate_ground_truth_and_output_columns_different()
729761

762+
if self.num_of_token_column_name:
763+
self._validate_num_of_token()
764+
730765
def _validate_ground_truth(self):
731766
"""Validations on the ground truth column."""
732767
if self.ground_truth_column_name not in self.dataset_df.columns:
@@ -773,6 +808,23 @@ def _validate_ground_truth_and_output_columns_different(self):
773808
"Please specify different columns for the output and the ground truths."
774809
)
775810

811+
def _validate_num_of_token(self):
812+
"""Validates the number of tokens column."""
813+
if self.num_of_token_column_name not in self.dataset_df.columns:
814+
self.failed_validations.append(
815+
f"The number of tokens column `{self.num_of_token_column_name}` "
816+
"specified as `numOfTokenColumnName` is not in the dataset."
817+
)
818+
elif not self._values_are_numbers(
819+
self.dataset_df, self.num_of_token_column_name
820+
):
821+
self.failed_validations.append(
822+
f"The number of tokens in the column `{self.num_of_token_column_name}`"
823+
" specified as `numOfTokenColumnName` are not in the correct format. "
824+
"Please make sure that the dtype of the column with the number of"
825+
" tokens is one of int32, int64, float32, or float64."
826+
)
827+
776828

777829
class RegressionOutputValidator(BaseDatasetValidator):
778830
"""Validates regression outputs.

0 commit comments

Comments
 (0)