Skip to content

Commit f49e17e

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
Completes OPEN-3569 Unify ModelSchema and ShellModelSchema for the Python API
1 parent 8c97b58 commit f49e17e

File tree

7 files changed

+20
-57
lines changed

7 files changed

+20
-57
lines changed

examples/tabular-classification/sklearn/churn-classifier/churn-classifier-sklearn.ipynb

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -495,7 +495,10 @@
495495
" \"model_type\": \"Logistic Regression\",\n",
496496
" \"regularization\": \"None\",\n",
497497
" \"encoder_used\": \"One Hot\", \n",
498-
" }\n",
498+
" },\n",
499+
" \"classNames\": class_names,\n",
500+
" \"featureNames\": feature_names,\n",
501+
" \"categoricalFeatureNames\": categorical_feature_names,\n",
499502
"}\n",
500503
"\n",
501504
"with open(\"model_config.yaml\", \"w\") as model_config_file:\n",
@@ -835,7 +838,7 @@
835838
"name": "python",
836839
"nbconvert_exporter": "python",
837840
"pygments_lexer": "ipython3",
838-
"version": "3.8.10"
841+
"version": "3.8.13"
839842
}
840843
},
841844
"nbformat": 4,

examples/tabular-classification/sklearn/fetal-health/fetal-health-sklearn.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -423,7 +423,9 @@
423423
" \"metadata\": { # Can add anything here, as long as it is a dict\n",
424424
" \"model_type\": \"Logistic Regression\",\n",
425425
" \"regularization\": \"L1\",\n",
426-
" }\n",
426+
" },\n",
427+
" \"classNames\": class_names,\n",
428+
" \"featureNames\": feature_names,\n",
427429
"}\n",
428430
"\n",
429431
"with open(\"model_config.yaml\", \"w\") as model_config_file:\n",

examples/tabular-classification/sklearn/fraud-detection/fraud-classifier-sklearn.ipynb

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -524,7 +524,10 @@
524524
" \"model_type\": \"Gradient Boosting\",\n",
525525
" \"regularization\": \"None\",\n",
526526
" \"encoder_used\": \"One Hot\", \n",
527-
" }\n",
527+
" },\n",
528+
" \"classNames\": class_names,\n",
529+
" \"featureNames\": feature_names,\n",
530+
" \"categoricalFeatureNames\": categorical_feature_names,\n",
528531
"}\n",
529532
"\n",
530533
"with open(\"model_config.yaml\", \"w\") as model_config_file:\n",

examples/tabular-classification/sklearn/iris-classifier/iris-tabular-sklearn.ipynb

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -375,7 +375,9 @@
375375
" \"metadata\": { # Can add anything here, as long as it is a dict\n",
376376
" \"model_type\": \"Logistic Regression\",\n",
377377
" \"regularization\": \"None\",\n",
378-
" }\n",
378+
" },\n",
379+
" \"classNames\": class_names,\n",
380+
" \"featureNames\": feature_names,\n",
379381
"}\n",
380382
"\n",
381383
"with open(\"model_config.yaml\", \"w\") as model_config_file:\n",

openlayer/__init__.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111

1212
from . import api, exceptions, utils, validators
1313
from .projects import Project
14-
from .schemas import BaselineModelSchema, DatasetSchema, ModelSchema, ShellModelSchema
14+
from .schemas import BaselineModelSchema, DatasetSchema, ModelSchema
1515
from .tasks import TaskType
1616
from .version import __version__ # noqa: F401
1717

@@ -230,24 +230,7 @@ def add_model(
230230
231231
.. admonition:: What's on the model config file?
232232
233-
The content of the YAML file depends on whether you are adding a shell
234-
model or a model package.
235-
236-
**If you are adding a shell model**, the model configuration file
237-
must contain the following fields:
238-
239-
- ``name`` : str
240-
Name of the model.
241-
- ``architectureType`` : str
242-
The model's framework. Must be one of the supported frameworks
243-
on :obj:`ModelType`.
244-
- ``metadata`` : Dict[str, any], default {}
245-
Dictionary containing metadata about the model. This is the
246-
metadata that will be displayed on the Openlayer platform.
247-
248-
**Alternatively, if you are adding a model package** (i.e., with model
249-
artifacts and prediction interface), the model configuration file must
250-
contain the following fields:
233+
The model configuration YAML file must contain the following fields:
251234
252235
- ``name`` : str
253236
Name of the model.
@@ -421,10 +404,7 @@ def add_model(
421404

422405
# Load model config and augment with defaults
423406
model_config = utils.read_yaml(model_config_file_path)
424-
if model_package_dir:
425-
model_data = ModelSchema().load(model_config)
426-
else:
427-
model_data = ShellModelSchema().load(model_config)
407+
model_data = ModelSchema().load(model_config)
428408

429409
# Copy relevant resources to temp directory
430410
with tempfile.TemporaryDirectory() as temp_dir:

openlayer/schemas.py

Lines changed: 1 addition & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ class ModelSchema(ma.Schema):
104104
)
105105
classNames = ma.fields.List(
106106
ma.fields.Str(),
107+
required=True,
107108
)
108109
name = ma.fields.Str(
109110
required=True,
@@ -156,28 +157,3 @@ class ProjectSchema(ma.Schema):
156157
+ " https://reference.openlayer.com/reference/api/openlayer.TaskType.html.\n ",
157158
),
158159
)
159-
160-
161-
class ShellModelSchema(ma.Schema):
162-
"""Schema for models without artifacts (i.e., shell model)."""
163-
164-
name = ma.fields.Str(
165-
required=True,
166-
validate=ma.validate.Length(
167-
min=1,
168-
max=64,
169-
),
170-
)
171-
metadata = ma.fields.Dict(
172-
allow_none=True,
173-
load_default={},
174-
)
175-
architectureType = ma.fields.Str(
176-
validate=ma.validate.OneOf(
177-
[model_framework.value for model_framework in ModelType],
178-
error="`architectureType` must be one of the supported frameworks. "
179-
+ "Check out our API reference for a full list"
180-
+ " https://reference.openlayer.com/reference/api/openlayer.ModelType.html.\n ",
181-
),
182-
required=True,
183-
)

openlayer/validators.py

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1021,10 +1021,7 @@ def _validate_model_config(self):
10211021
with open(self.model_config_file_path, "r", encoding="UTF-8") as stream:
10221022
model_config = yaml.safe_load(stream)
10231023

1024-
if self.model_package_dir:
1025-
model_schema = schemas.ModelSchema()
1026-
else:
1027-
model_schema = schemas.ShellModelSchema()
1024+
model_schema = schemas.ModelSchema()
10281025
try:
10291026
model_schema.load(model_config)
10301027
except ma.ValidationError as err:

0 commit comments

Comments
 (0)