
Commit 71edc0f

gustavocidornelas authored and whoseoyster committed
Refactor marshmallow schemas and add prompt to LLM production data config
1 parent 2fa8af8 commit 71edc0f

16 files changed, with 745 additions and 640 deletions.

openlayer/__init__.py

Lines changed: 15 additions & 15 deletions
@@ -39,7 +39,7 @@
 from .inference_pipelines import InferencePipeline
 from .project_versions import ProjectVersion
 from .projects import Project
-from .schemas import BaselineModelSchema, DatasetSchema, ModelSchema
+from .schemas import dataset_schemas, model_schemas
 from .tasks import TaskType
 from .validators import (
     baseline_model_validators,
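
The import swap is the core of the refactor: `openlayer.schemas` appears to go from a single module exporting schema classes to a package with per-resource modules. A minimal sketch of the assumed layout follows; only the class names visible in this diff are confirmed, the file split itself is inferred:

    # Assumed package layout (inferred from the new import; paths are hypothetical):
    #   openlayer/schemas/dataset_schemas.py -> DatasetSchema, ReferenceDatasetSchema, ProductionDataSchema
    #   openlayer/schemas/model_schemas.py   -> ModelSchema, BaselineModelSchema
    from openlayer.schemas import dataset_schemas, model_schemas

    schema = model_schemas.ModelSchema()  # classes now referenced through the module namespace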
@@ -334,7 +334,9 @@ def add_model(
         # Load model config and augment with defaults
         if model_config_file_path is not None:
             model_config = utils.read_yaml(model_config_file_path)
-            model_data = ModelSchema().load({"task_type": task_type.value, **model_config})
+            model_data = model_schemas.ModelSchema().load(
+                {"task_type": task_type.value, **model_config}
+            )

         # Copy relevant resources to temp directory
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -432,7 +434,7 @@ def add_baseline_model(
         if model_config_file_path is not None:
             model_config = utils.read_yaml(model_config_file_path)
             model_config["modelType"] = "baseline"
-            model_data = BaselineModelSchema().load(
+            model_data = model_schemas.BaselineModelSchema().load(
                 {"task_type": task_type.value, **model_config}
             )

@@ -481,7 +483,7 @@ def add_dataset(
         # Load dataset config and augment with defaults
         if dataset_config_file_path is not None:
             dataset_config = utils.read_yaml(dataset_config_file_path)
-            dataset_data = DatasetSchema().load(
+            dataset_data = dataset_schemas.DatasetSchema().load(
                 {"task_type": task_type.value, **dataset_config}
             )
             if dataset_data.get("columnNames") is None:
@@ -930,7 +932,7 @@ def create_inference_pipeline(
                 " upload.",
             ) from None

-        reference_dataset_data = DatasetSchema().load(
+        reference_dataset_data = dataset_schemas.ReferenceDatasetSchema().load(
             {"task_type": task_type.value, **reference_dataset_config}
         )

@@ -1034,7 +1036,7 @@ def upload_reference_dataset(
             ) from None

         # Load dataset config and augment with defaults
-        dataset_data = DatasetSchema().load(
+        dataset_data = dataset_schemas.ReferenceDatasetSchema().load(
             {"task_type": task_type.value, **dataset_config}
         )

@@ -1116,7 +1118,10 @@ def stream_data(
         stream_config, stream_df = self._add_default_columns(
             config=stream_config, df=stream_df
         )
-        stream_config = self._strip_read_only_fields(stream_config)
+
+        # Remove the `label` for the upload
+        stream_config.pop("label", None)
+
         body = {
             "config": stream_config,
             "rows": stream_df.to_dict(orient="records"),
@@ -1129,13 +1134,6 @@ def stream_data(
         if self.verbose:
             print("Stream published!")

-    def _strip_read_only_fields(self, config: Dict[str, any]) -> Dict[str, any]:
-        """Strips read-only fields from the config."""
-        stripped_config = copy.deepcopy(config)
-        for field in ["columnNames", "label"]:
-            stripped_config.pop(field, None)
-        return stripped_config
-
     def publish_batch_data(
         self,
         inference_pipeline_id: str,
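
Taken together with the previous hunk, this changes behavior as well as structure: the removed helper worked on a deep copy and stripped both `columnNames` and `label`, while the inline replacement mutates `stream_config` and pops only `label`, so `columnNames` now survives into the request body. A standalone sketch of the difference, using a made-up config:

    import copy

    config = {"columnNames": ["text"], "label": "production"}

    # Old: deep-copy, then strip both read-only fields.
    old = copy.deepcopy(config)
    for field in ["columnNames", "label"]:
        old.pop(field, None)
    assert old == {}

    # New: mutate in place, dropping only `label`.
    config.pop("label", None)
    assert config == {"columnNames": ["text"]}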
@@ -1245,7 +1243,9 @@ def _validate_production_data_and_load_config(
             "Make sure to fix all of the issues listed above before the upload.",
         ) from None

-        config = DatasetSchema().load({"task_type": task_type.value, **config})
+        config = dataset_schemas.ProductionDataSchema().load(
+            {"task_type": task_type.value, **config}
+        )

         return config
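
The pattern across the hunks in this file: call sites that all funneled through `DatasetSchema` now load through a schema matched to the data's role (reference vs. production vs. training/validation). A hedged sketch of the resulting dispatch; this helper is illustrative, not part of the codebase:

    from openlayer.schemas import dataset_schemas

    # Mapping inferred from the call sites above (hypothetical helper).
    _SCHEMAS = {
        "dataset": dataset_schemas.DatasetSchema,             # add_dataset
        "reference": dataset_schemas.ReferenceDatasetSchema,  # reference uploads
        "production": dataset_schemas.ProductionDataSchema,   # streamed/published data
    }

    def load_config(role: str, task_type_value: str, config: dict) -> dict:
        # marshmallow's load() validates and deserializes, raising a
        # ValidationError when the config doesn't match the schema.
        return _SCHEMAS[role]().load({"task_type": task_type_value, **config})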

openlayer/constants.py

Lines changed: 22 additions & 0 deletions
@@ -2,6 +2,8 @@
 """
 import os

+import marshmallow as ma
+
 # ---------------------------- Commit/staging flow --------------------------- #
 VALID_RESOURCE_NAMES = {"model", "training", "validation", "fine-tuning"}
 OPENLAYER_DIR = os.path.join(os.path.expanduser("~"), ".openlayer")
@@ -12,3 +14,23 @@

 # ----------------------------------- APIs ----------------------------------- #
 REQUESTS_TIMEOUT = 60 * 60 * 3  # 3 hours
+
+# ---------------------------- Validation patterns --------------------------- #
+COLUMN_NAME_REGEX = validate = ma.validate.Regexp(
+    r"^(?!openlayer)[a-zA-Z0-9_-]+$",
+    error="strings that are not alphanumeric with underscores or hyphens."
+    + " Spaces and special characters are not allowed."
+    + " The string cannot start with `openlayer`.",
+)
+LANGUAGE_CODE_REGEX = ma.validate.Regexp(
+    r"^[a-z]{2}(-[A-Z]{2})?$",
+    error="`language` of the dataset is not in the ISO 639-1 (alpha-2 code) format.",
+)
+
+COLUMN_NAME_VALIDATION_LIST = [
+    ma.validate.Length(
+        min=1,
+        max=60,
+    ),
+    COLUMN_NAME_REGEX,
+]
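
Note the stray chained assignment in `COLUMN_NAME_REGEX = validate = ma.validate.Regexp(...)`, which also binds a module-level `validate` name and looks unintentional. These constants are presumably consumed by the new schema modules; a self-contained sketch of how `COLUMN_NAME_VALIDATION_LIST` would attach to a marshmallow field (the schema and field names here are hypothetical):

    import marshmallow as ma

    from openlayer import constants

    class ColumnsSchema(ma.Schema):
        # marshmallow accepts a list of validators per field.
        timestampColumnName = ma.fields.Str(
            validate=constants.COLUMN_NAME_VALIDATION_LIST,
        )

    ColumnsSchema().load({"timestampColumnName": "openlayer_ts"})
    # Raises ma.ValidationError: the regex rejects names starting with `openlayer`.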

openlayer/datasets.py

Lines changed: 6 additions & 4 deletions
@@ -22,14 +22,16 @@ class DatasetType(Enum):
     Used by the ``dataset_type`` argument of the :meth:`openlayer.OpenlayerClient.add_dataset` and
     :meth:`openlayer.OpenlayerClient.add_dataframe` methods."""

-    #: For validation sets.
-    Validation = "validation"
-    #: For training sets.
-    Training = "training"
+    #: For fine-tuning data.
+    FineTuning = "fine-tuning"
     #: For production data.
     Production = "production"
     #: For reference datasets.
     Reference = "reference"
+    #: For training sets.
+    Training = "training"
+    #: For validation sets.
+    Validation = "validation"


 class Dataset:
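
This hunk adds a `FineTuning` member and reorders the members alphabetically; since only declaration order changes and the string values stay the same, value lookups behave as before. A quick usage sketch:

    from openlayer.datasets import DatasetType

    assert DatasetType.FineTuning.value == "fine-tuning"
    assert DatasetType("validation") is DatasetType.Validation  # lookups unchanged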
