Skip to content

Commit e0bb27c

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
Completes OPEN-5038 Reference datasets should have label 'reference'
1 parent 63ff1fe commit e0bb27c

File tree

4 files changed

+128
-12
lines changed

4 files changed

+128
-12
lines changed

openlayer/__init__.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1723,11 +1723,13 @@ def upload_reference_dataset(
17231723
"Either `dataset_config` or `dataset_config_file_path` must be"
17241724
" provided."
17251725
)
1726+
if dataset_config_file_path is not None:
1727+
dataset_config = utils.read_yaml(dataset_config_file_path)
1728+
dataset_config["label"] = "reference"
17261729

17271730
# Validate dataset
17281731
dataset_validator = dataset_validators.get_validator(
17291732
task_type=task_type,
1730-
dataset_config_file_path=dataset_config_file_path,
17311733
dataset_config=dataset_config,
17321734
dataset_file_path=file_path,
17331735
)
@@ -1740,8 +1742,6 @@ def upload_reference_dataset(
17401742
) from None
17411743

17421744
# Load dataset config and augment with defaults
1743-
if dataset_config_file_path is not None:
1744-
dataset_config = utils.read_yaml(dataset_config_file_path)
17451745
dataset_data = DatasetSchema().load(
17461746
{"task_type": task_type.value, **dataset_config}
17471747
)

openlayer/datasets.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,8 @@ class DatasetType(Enum):
2828
Training = "training"
2929
#: For production data.
3030
Production = "production"
31+
#: For reference datasets.
32+
Reference = "reference"
3133

3234

3335
class Dataset:

openlayer/schemas.py

Lines changed: 77 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -210,7 +210,7 @@ class RegressionOutputSchema(BaseDatasetSchema):
210210
class LLMDatasetSchema(LLMInputSchema, LLMOutputSchema):
211211
"""LLM dataset schema."""
212212

213-
# Override the label to allow for a 'fine-tuning' label instead
213+
# Overwrite the label to allow for a 'fine-tuning' label instead
214214
# of the 'training' label
215215
label = ma.fields.Str(
216216
validate=ma.validate.OneOf(
@@ -490,3 +490,79 @@ class ProjectSchema(ma.Schema):
490490
+ " https://reference.openlayer.com/reference/api/openlayer.TaskType.html.\n ",
491491
),
492492
)
493+
494+
495+
# ---------------------------- Reference datasets ---------------------------- #
496+
class LLMReferenceDatasetSchema(LLMDatasetSchema):
497+
"""LLM reference dataset schema."""
498+
499+
# Overwrite the label to allow for a 'reference' label instead
500+
label = ma.fields.Str(
501+
validate=ma.validate.OneOf(
502+
[DatasetType.Reference.value],
503+
error="`label` not supported." + "The supported `labels` are 'reference'.",
504+
),
505+
required=True,
506+
)
507+
508+
509+
class TabularClassificationReferenceDatasetSchema(TabularClassificationDatasetSchema):
510+
"""Tabular classification reference dataset schema."""
511+
512+
# Overwrite the label to allow for a 'reference' label instead
513+
label = ma.fields.Str(
514+
validate=ma.validate.OneOf(
515+
[DatasetType.Reference.value],
516+
error="`label` not supported." + "The supported `labels` are 'reference'.",
517+
),
518+
required=True,
519+
)
520+
521+
522+
class TabularRegressionReferenceDatasetSchema(TabularRegressionDatasetSchema):
523+
"""Tabular regression reference dataset schema."""
524+
525+
# Overwrite the label to allow for a 'reference' label instead
526+
label = ma.fields.Str(
527+
validate=ma.validate.OneOf(
528+
[DatasetType.Reference.value],
529+
error="`label` not supported." + "The supported `labels` are 'reference'.",
530+
),
531+
required=True,
532+
)
533+
534+
535+
class TextClassificationReferenceDatasetSchema(TextClassificationDatasetSchema):
536+
"""Text classification reference dataset schema."""
537+
538+
# Overwrite the label to allow for a 'reference' label instead
539+
label = ma.fields.Str(
540+
validate=ma.validate.OneOf(
541+
[DatasetType.Reference.value],
542+
error="`label` not supported." + "The supported `labels` are 'reference'.",
543+
),
544+
required=True,
545+
)
546+
547+
548+
class ReferenceDatasetSchema(maos.OneOfSchema):
549+
"""One of schema for reference datasets.
550+
Returns the correct schema based on the task type."""
551+
552+
type_field = "task_type"
553+
# pylint: ignore=line-too-long
554+
type_schemas = {
555+
TaskType.TabularClassification.value: TabularClassificationReferenceDatasetSchema,
556+
TaskType.TabularRegression.value: TabularRegressionReferenceDatasetSchema,
557+
TaskType.TextClassification.value: TextClassificationReferenceDatasetSchema,
558+
TaskType.LLM.value: LLMReferenceDatasetSchema,
559+
TaskType.LLMNER.value: LLMReferenceDatasetSchema,
560+
TaskType.LLMQuestionAnswering.value: LLMReferenceDatasetSchema,
561+
TaskType.LLMSummarization.value: LLMReferenceDatasetSchema,
562+
TaskType.LLMTranslation.value: LLMReferenceDatasetSchema,
563+
}
564+
565+
def get_obj_type(self, obj):
566+
if obj not in [task_type.value for task_type in TaskType]:
567+
raise ma.ValidationError(f"Unknown object type: {obj.__class__.__name__}")
568+
return obj

openlayer/validators/dataset_validators.py

Lines changed: 46 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -102,28 +102,66 @@ def _validate_dataset_config(self):
102102
103103
Beware of the order of the validations, as it is important.
104104
"""
105+
self._validate_file_existence()
106+
self._load_dataset_config()
107+
self._validate_dataset_label()
108+
self._validate_dataset_schema()
109+
110+
def _validate_file_existence(self):
111+
"""Checks whether the dataset_config_file_path exists."""
105112
# File existence check
106113
if self.dataset_config_file_path:
107114
if not os.path.isfile(os.path.expanduser(self.dataset_config_file_path)):
108115
self.failed_validations.append(
109116
f"File `{self.dataset_config_file_path}` does not exist."
110117
)
111-
else:
118+
119+
def _load_dataset_config(self):
120+
"""Loads the dataset_config_file_path into the `self.dataset_config`
121+
attribute."""
122+
if self.dataset_config_file_path:
123+
try:
112124
with open(
113125
self.dataset_config_file_path, "r", encoding="UTF-8"
114126
) as stream:
115127
self.dataset_config = yaml.safe_load(stream)
128+
except:
129+
self.failed_validations.append(
130+
f"File `{self.dataset_config_file_path}` is not a valid .yaml file."
131+
)
116132

133+
def _validate_dataset_label(self):
134+
"""Checks whether the dataset label is valid."""
117135
if self.dataset_config:
118-
dataset_schema = schemas.DatasetSchema()
119-
try:
120-
dataset_schema.load(
121-
{"task_type": self.task_type.value, **self.dataset_config}
136+
if self.dataset_config.get("label") is None:
137+
self.failed_validations.append(
138+
"Missing value for required property `label` in the dataset config."
122139
)
123-
except ma.ValidationError as err:
124-
self.failed_validations.extend(
125-
self._format_marshmallow_error_message(err)
140+
else:
141+
label = self.dataset_config["label"]
142+
if not isinstance(label, str):
143+
self.failed_validations.append(
144+
"The value of `label` in the dataset config must be a string."
145+
)
146+
147+
def _validate_dataset_schema(self):
148+
"""Checks whether the dataset schema is valid."""
149+
if self.dataset_config:
150+
label = self.dataset_config.get("label")
151+
if label:
152+
dataset_schema = (
153+
schemas.ReferenceDatasetSchema()
154+
if label == "reference"
155+
else schemas.DatasetSchema()
126156
)
157+
try:
158+
dataset_schema.load(
159+
{"task_type": self.task_type.value, **self.dataset_config}
160+
)
161+
except ma.ValidationError as err:
162+
self.failed_validations.extend(
163+
self._format_marshmallow_error_message(err)
164+
)
127165

128166
def _validate_dataset_file(self):
129167
"""Checks whether the dataset file exists and is valid.

0 commit comments

Comments
 (0)