2828import time
2929import uuid
3030import warnings
31- from typing import Optional
31+ from typing import Dict , Optional
3232
3333import pandas as pd
3434import yaml
@@ -252,8 +252,9 @@ def create_or_load_project(
252252
253253 def add_model (
254254 self ,
255- model_config_file_path : str ,
256255 task_type : TaskType ,
256+ model_config : Optional [Dict [str , any ]] = None ,
257+ model_config_file_path : Optional [str ] = None ,
257258 model_package_dir : Optional [str ] = None ,
258259 sample_data : Optional [pd .DataFrame ] = None ,
259260 force : bool = False ,
@@ -263,8 +264,19 @@ def add_model(
263264
264265 Parameters
265266 ----------
267+ model_config : Dict[str, any]
268+ Dictionary containing the model configuration. This is not needed if
269+ ``model_config_file_path`` is provided.
270+
271+ .. admonition:: What's in the model config dict?
272+
273+ The model configuration depends on the :obj:`TaskType`.
274+ Refer to the `documentation <https://docs.openlayer.com/docs/tabular-classification-model-config>`_
275+ for examples.
276+
266277 model_config_file_path : str
267- Path to the model configuration YAML file.
278+ Path to the model configuration YAML file. This is not needed if
279+ ``model_config`` is provided.
268280
269281 .. admonition:: What's in the model config file?
270282
@@ -407,10 +419,15 @@ def add_model(
407419 "The sample data must contain at least 2 rows, but only"
408420 f"{ len (sample_data )} rows were provided."
409421 )
422+ if model_config is None and model_config_file_path is None :
423+ raise ValueError (
424+ "Either `model_config` or `model_config_file_path` must be provided."
425+ )
410426
411427 # Validate model package
412428 model_validator = model_validators .get_validator (
413429 task_type = task_type ,
430+ model_config = model_config ,
414431 model_package_dir = model_package_dir ,
415432 model_config_file_path = model_config_file_path ,
416433 sample_data = sample_data ,
@@ -424,7 +441,8 @@ def add_model(
424441 ) from None
425442
426443 # Load model config and augment with defaults
427- model_config = utils .read_yaml (model_config_file_path )
444+ if model_config_file_path is not None :
445+ model_config = utils .read_yaml (model_config_file_path )
428446 model_data = ModelSchema ().load ({"task_type" : task_type .value , ** model_config })
429447
430448 # Copy relevant resources to temp directory
@@ -451,6 +469,7 @@ def add_baseline_model(
451469 self ,
452470 project_id : str ,
453471 task_type : TaskType ,
472+ model_config : Optional [Dict [str , any ]] = None ,
454473 model_config_file_path : Optional [str ] = None ,
455474 force : bool = False ,
456475 ):
@@ -469,9 +488,23 @@ def add_baseline_model(
469488
470489 Parameters
471490 ----------
491+ model_config : Dict[str, any], optional
492+ Dictionary containing the model configuration. This is not needed if
493+ ``model_config_file_path`` is provided. If none of these are provided,
494+ the default model config will be used.
495+
496+ .. admonition:: What's on the model config file?
497+
498+ For baseline models, the config should contain:
499+
500+ - ``metadata`` : Dict[str, any], default {}
501+ Dictionary containing metadata about the model. This is the
502+ metadata that will be displayed on the Openlayer platform.
503+
472504 model_config_file_path : str, optional
473- Path to the model configuration YAML file. If not provided, the default
474- model config will be used.
505+ Path to the model configuration YAML file. This is not needed if
506+ ``model_config`` is provided. If none of these are provided,
507+ the default model config will be used.
475508
476509 .. admonition:: What's on the model config file?
477510
@@ -490,9 +523,9 @@ def add_baseline_model(
490523 )
491524
492525 # Validate the baseline model
493-
494526 baseline_model_validator = baseline_model_validators .get_validator (
495527 task_type = task_type ,
528+ model_config = model_config ,
496529 model_config_file_path = model_config_file_path ,
497530 )
498531 failed_validations = baseline_model_validator .validate ()
@@ -504,7 +537,7 @@ def add_baseline_model(
504537 ) from None
505538
506539 # Load model config and augment with defaults
507- model_config = {}
540+ model_config = {} or model_config
508541 if model_config_file_path is not None :
509542 model_config = utils .read_yaml (model_config_file_path )
510543 model_config ["modelType" ] = "baseline"
@@ -527,7 +560,8 @@ def add_dataset(
527560 self ,
528561 file_path : str ,
529562 task_type : TaskType ,
530- dataset_config_file_path : str ,
563+ dataset_config : Optional [Dict [str , any ]] = None ,
564+ dataset_config_file_path : Optional [str ] = None ,
531565 project_id : str = None ,
532566 force : bool = False ,
533567 ):
@@ -537,8 +571,19 @@ def add_dataset(
537571 ----------
538572 file_path : str
539573 Path to the csv file containing the dataset.
574+ dataset_config: Dict[str, any]
575+ Dictionary containing the dataset configuration. This is not needed if
576+ ``dataset_config_file_path`` is provided.
577+
578+ .. admonition:: What's in the dataset config?
579+
580+ The dataset configuration depends on the :obj:`TaskType`.
581+ Refer to the `documentation <https://docs.openlayer.com/docs/tabular-classification-dataset-config>`_
582+ for examples.
583+
540584 dataset_config_file_path : str
541- Path to the dataset configuration YAML file.
585+ Path to the dataset configuration YAML file. This is not needed if
586+ ``dataset_config`` is provided.
542587
543588 .. admonition:: What's in the dataset config file?
544589
@@ -668,9 +713,15 @@ def add_dataset(
668713 >>> project.commit("Initial dataset commit.")
669714 >>> project.push()
670715 """
716+ if dataset_config is None and dataset_config_file_path is None :
717+ raise ValueError (
718+ "Either `dataset_config` or `dataset_config_file_path` must be"
719+ " provided."
720+ )
671721 # Validate dataset
672722 dataset_validator = dataset_validators .get_validator (
673723 task_type = task_type ,
724+ dataset_config = dataset_config ,
674725 dataset_config_file_path = dataset_config_file_path ,
675726 dataset_file_path = file_path ,
676727 )
@@ -683,7 +734,8 @@ def add_dataset(
683734 ) from None
684735
685736 # Load dataset config and augment with defaults
686- dataset_config = utils .read_yaml (dataset_config_file_path )
737+ if dataset_config_file_path is not None :
738+ dataset_config = utils .read_yaml (dataset_config_file_path )
687739 dataset_data = DatasetSchema ().load (
688740 {"task_type" : task_type .value , ** dataset_config }
689741 )
@@ -704,7 +756,8 @@ def add_dataframe(
704756 self ,
705757 dataset_df : pd .DataFrame ,
706758 task_type : TaskType ,
707- dataset_config_file_path : str ,
759+ dataset_config : Optional [Dict [str , any ]] = None ,
760+ dataset_config_file_path : Optional [str ] = None ,
708761 project_id : str = None ,
709762 force : bool = False ,
710763 ):
@@ -714,8 +767,19 @@ def add_dataframe(
714767 ----------
715768 dataset_df : pd.DataFrame
716769 Dataframe containing your dataset.
770+ dataset_config: Dict[str, any]
771+ Dictionary containing the dataset configuration. This is not needed if
772+ ``dataset_config_file_path`` is provided.
773+
774+ .. admonition:: What's in the dataset config?
775+
776+ The dataset configuration depends on the :obj:`TaskType`.
777+ Refer to the `documentation <https://docs.openlayer.com/docs/tabular-classification-dataset-config>`_
778+ for examples.
779+
717780 dataset_config_file_path : str
718- Path to the dataset configuration YAML file.
781+ Path to the dataset configuration YAML file. This is not needed if
782+ ``dataset_config`` is provided.
719783
720784 .. admonition:: What's in the dataset config file?
721785
@@ -856,6 +920,7 @@ def add_dataframe(
856920 file_path = file_path ,
857921 project_id = project_id ,
858922 dataset_config_file_path = dataset_config_file_path ,
923+ dataset_config = dataset_config ,
859924 force = force ,
860925 task_type = task_type ,
861926 )
0 commit comments