Skip to content

Commit ce7db20

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
Completes OPEN-5005 Accept configs as dicts
1 parent 4208c18 commit ce7db20

File tree

4 files changed

+132
-36
lines changed

4 files changed

+132
-36
lines changed

openlayer/__init__.py

Lines changed: 78 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
import time
2929
import uuid
3030
import warnings
31-
from typing import Optional
31+
from typing import Dict, Optional
3232

3333
import pandas as pd
3434
import yaml
@@ -252,8 +252,9 @@ def create_or_load_project(
252252

253253
def add_model(
254254
self,
255-
model_config_file_path: str,
256255
task_type: TaskType,
256+
model_config: Optional[Dict[str, any]] = None,
257+
model_config_file_path: Optional[str] = None,
257258
model_package_dir: Optional[str] = None,
258259
sample_data: Optional[pd.DataFrame] = None,
259260
force: bool = False,
@@ -263,8 +264,19 @@ def add_model(
263264
264265
Parameters
265266
----------
267+
model_config : Dict[str, any]
268+
Dictionary containing the model configuration. This is not needed if
269+
``model_config_file_path`` is provided.
270+
271+
.. admonition:: What's in the model config dict?
272+
273+
The model configuration depends on the :obj:`TaskType`.
274+
Refer to the `documentation <https://docs.openlayer.com/docs/tabular-classification-model-config>`_
275+
for examples.
276+
266277
model_config_file_path : str
267-
Path to the model configuration YAML file.
278+
Path to the model configuration YAML file. This is not needed if
279+
``model_config`` is provided.
268280
269281
.. admonition:: What's in the model config file?
270282
@@ -407,10 +419,15 @@ def add_model(
407419
"The sample data must contain at least 2 rows, but only"
408420
f"{len(sample_data)} rows were provided."
409421
)
422+
if model_config is None and model_config_file_path is None:
423+
raise ValueError(
424+
"Either `model_config` or `model_config_file_path` must be provided."
425+
)
410426

411427
# Validate model package
412428
model_validator = model_validators.get_validator(
413429
task_type=task_type,
430+
model_config=model_config,
414431
model_package_dir=model_package_dir,
415432
model_config_file_path=model_config_file_path,
416433
sample_data=sample_data,
@@ -424,7 +441,8 @@ def add_model(
424441
) from None
425442

426443
# Load model config and augment with defaults
427-
model_config = utils.read_yaml(model_config_file_path)
444+
if model_config_file_path is not None:
445+
model_config = utils.read_yaml(model_config_file_path)
428446
model_data = ModelSchema().load({"task_type": task_type.value, **model_config})
429447

430448
# Copy relevant resources to temp directory
@@ -451,6 +469,7 @@ def add_baseline_model(
451469
self,
452470
project_id: str,
453471
task_type: TaskType,
472+
model_config: Optional[Dict[str, any]] = None,
454473
model_config_file_path: Optional[str] = None,
455474
force: bool = False,
456475
):
@@ -469,9 +488,23 @@ def add_baseline_model(
469488
470489
Parameters
471490
----------
491+
model_config : Dict[str, any], optional
492+
Dictionary containing the model configuration. This is not needed if
493+
``model_config_file_path`` is provided. If none of these are provided,
494+
the default model config will be used.
495+
496+
.. admonition:: What's on the model config file?
497+
498+
For baseline models, the config should contain:
499+
500+
- ``metadata`` : Dict[str, any], default {}
501+
Dictionary containing metadata about the model. This is the
502+
metadata that will be displayed on the Openlayer platform.
503+
472504
model_config_file_path : str, optional
473-
Path to the model configuration YAML file. If not provided, the default
474-
model config will be used.
505+
Path to the model configuration YAML file. This is not needed if
506+
``model_config`` is provided. If none of these are provided,
507+
the default model config will be used.
475508
476509
.. admonition:: What's on the model config file?
477510
@@ -490,9 +523,9 @@ def add_baseline_model(
490523
)
491524

492525
# Validate the baseline model
493-
494526
baseline_model_validator = baseline_model_validators.get_validator(
495527
task_type=task_type,
528+
model_config=model_config,
496529
model_config_file_path=model_config_file_path,
497530
)
498531
failed_validations = baseline_model_validator.validate()
@@ -504,7 +537,7 @@ def add_baseline_model(
504537
) from None
505538

506539
# Load model config and augment with defaults
507-
model_config = {}
540+
model_config = {} or model_config
508541
if model_config_file_path is not None:
509542
model_config = utils.read_yaml(model_config_file_path)
510543
model_config["modelType"] = "baseline"
@@ -527,7 +560,8 @@ def add_dataset(
527560
self,
528561
file_path: str,
529562
task_type: TaskType,
530-
dataset_config_file_path: str,
563+
dataset_config: Optional[Dict[str, any]] = None,
564+
dataset_config_file_path: Optional[str] = None,
531565
project_id: str = None,
532566
force: bool = False,
533567
):
@@ -537,8 +571,19 @@ def add_dataset(
537571
----------
538572
file_path : str
539573
Path to the csv file containing the dataset.
574+
dataset_config: Dict[str, any]
575+
Dictionary containing the dataset configuration. This is not needed if
576+
``dataset_config_file_path`` is provided.
577+
578+
.. admonition:: What's in the dataset config?
579+
580+
The dataset configuration depends on the :obj:`TaskType`.
581+
Refer to the `documentation <https://docs.openlayer.com/docs/tabular-classification-dataset-config>`_
582+
for examples.
583+
540584
dataset_config_file_path : str
541-
Path to the dataset configuration YAML file.
585+
Path to the dataset configuration YAML file. This is not needed if
586+
``dataset_config`` is provided.
542587
543588
.. admonition:: What's in the dataset config file?
544589
@@ -668,9 +713,15 @@ def add_dataset(
668713
>>> project.commit("Initial dataset commit.")
669714
>>> project.push()
670715
"""
716+
if dataset_config is None and dataset_config_file_path is None:
717+
raise ValueError(
718+
"Either `dataset_config` or `dataset_config_file_path` must be"
719+
" provided."
720+
)
671721
# Validate dataset
672722
dataset_validator = dataset_validators.get_validator(
673723
task_type=task_type,
724+
dataset_config=dataset_config,
674725
dataset_config_file_path=dataset_config_file_path,
675726
dataset_file_path=file_path,
676727
)
@@ -683,7 +734,8 @@ def add_dataset(
683734
) from None
684735

685736
# Load dataset config and augment with defaults
686-
dataset_config = utils.read_yaml(dataset_config_file_path)
737+
if dataset_config_file_path is not None:
738+
dataset_config = utils.read_yaml(dataset_config_file_path)
687739
dataset_data = DatasetSchema().load(
688740
{"task_type": task_type.value, **dataset_config}
689741
)
@@ -704,7 +756,8 @@ def add_dataframe(
704756
self,
705757
dataset_df: pd.DataFrame,
706758
task_type: TaskType,
707-
dataset_config_file_path: str,
759+
dataset_config: Optional[Dict[str, any]] = None,
760+
dataset_config_file_path: Optional[str] = None,
708761
project_id: str = None,
709762
force: bool = False,
710763
):
@@ -714,8 +767,19 @@ def add_dataframe(
714767
----------
715768
dataset_df : pd.DataFrame
716769
Dataframe containing your dataset.
770+
dataset_config: Dict[str, any]
771+
Dictionary containing the dataset configuration. This is not needed if
772+
``dataset_config_file_path`` is provided.
773+
774+
.. admonition:: What's in the dataset config?
775+
776+
The dataset configuration depends on the :obj:`TaskType`.
777+
Refer to the `documentation <https://docs.openlayer.com/docs/tabular-classification-dataset-config>`_
778+
for examples.
779+
717780
dataset_config_file_path : str
718-
Path to the dataset configuration YAML file.
781+
Path to the dataset configuration YAML file. This is not needed if
782+
``dataset_config`` is provided.
719783
720784
.. admonition:: What's in the dataset config file?
721785
@@ -856,6 +920,7 @@ def add_dataframe(
856920
file_path=file_path,
857921
project_id=project_id,
858922
dataset_config_file_path=dataset_config_file_path,
923+
dataset_config=dataset_config,
859924
force=force,
860925
task_type=task_type,
861926
)

openlayer/schemas.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -404,6 +404,7 @@ class ProjectSchema(ma.Schema):
404404
min=1,
405405
max=140,
406406
),
407+
allow_none=True,
407408
)
408409
name = ma.fields.Str(
409410
required=True,

openlayer/validators/baseline_model_validators.py

Lines changed: 16 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
"""
33
import logging
44
import os
5-
from typing import List, Optional
5+
from typing import Dict, List, Optional
66

77
import marshmallow as ma
88
import yaml
@@ -20,15 +20,21 @@ class BaseBaselineModelValidator(BaseValidator):
2020
----------
2121
task_type : tasks.TaskType
2222
The task type.
23+
model_config : Optional[Dict[str, any]], optional
24+
The model config, by default None
2325
model_config_file_path : Optional[str], optional
2426
The path to the model config file, by default None
2527
"""
2628

2729
def __init__(
28-
self, task_type: tasks.TaskType, model_config_file_path: Optional[str] = None
30+
self,
31+
task_type: tasks.TaskType,
32+
model_config: Optional[Dict[str, any]] = None,
33+
model_config_file_path: Optional[str] = None,
2934
):
3035
super().__init__(resource_display_name="baseline model")
3136
self.task_type = task_type
37+
self.model_config = model_config
3238
self.model_config_file_path = model_config_file_path
3339

3440
def _validate(self) -> List[str]:
@@ -38,7 +44,7 @@ def _validate(self) -> List[str]:
3844
List[str]
3945
The list of failed validations.
4046
"""
41-
if self.model_config_file_path:
47+
if self.model_config_file_path or self.model_config:
4248
self._validate_model_config()
4349

4450
def _validate_model_config(self):
@@ -51,13 +57,13 @@ def _validate_model_config(self):
5157
)
5258
else:
5359
with open(self.model_config_file_path, "r", encoding="UTF-8") as stream:
54-
model_config = yaml.safe_load(stream)
60+
self.model_config = yaml.safe_load(stream)
5561

56-
if model_config:
62+
if self.model_config:
5763
baseline_model_schema = schemas.BaselineModelSchema()
5864
try:
5965
baseline_model_schema.load(
60-
{"task_type": self.task_type.value, **model_config}
66+
{"task_type": self.task_type.value, **self.model_config}
6167
)
6268
except ma.ValidationError as err:
6369
self.failed_validations.extend(
@@ -74,13 +80,15 @@ class TabularClassificationBaselineModelValidator(BaseBaselineModelValidator):
7480
# ----------------------------- Factory function ----------------------------- #
7581
def get_validator(
7682
task_type: tasks.TaskType,
77-
model_config_file_path: str,
83+
model_config: Optional[Dict[str, any]] = None,
84+
model_config_file_path: Optional[str] = None,
7885
) -> BaseBaselineModelValidator:
7986
"""Factory function to get the correct baseline model validator.
8087
8188
Parameters
8289
----------
8390
task_type: The task type of the model.
91+
model_config: The model config.
8492
model_config_file_path: Path to the model config file.
8593
8694
Returns
@@ -89,6 +97,7 @@ def get_validator(
8997
"""
9098
if task_type == tasks.TaskType.TabularClassification:
9199
return TabularClassificationBaselineModelValidator(
100+
model_config=model_config,
92101
model_config_file_path=model_config_file_path,
93102
task_type=task_type,
94103
)

0 commit comments

Comments
 (0)