 from .inference_pipelines import InferencePipeline
 from .project_versions import ProjectVersion
 from .projects import Project
-from .schemas import BaselineModelSchema, DatasetSchema, ModelSchema
+from .schemas import dataset_schemas, model_schemas
 from .tasks import TaskType
 from .validators import (
     baseline_model_validators,
@@ -334,7 +334,9 @@ def add_model(
         # Load model config and augment with defaults
         if model_config_file_path is not None:
             model_config = utils.read_yaml(model_config_file_path)
-        model_data = ModelSchema().load({"task_type": task_type.value, **model_config})
+        model_data = model_schemas.ModelSchema().load(
+            {"task_type": task_type.value, **model_config}
+        )

         # Copy relevant resources to temp directory
         with tempfile.TemporaryDirectory() as temp_dir:
@@ -432,7 +434,7 @@ def add_baseline_model(
         if model_config_file_path is not None:
             model_config = utils.read_yaml(model_config_file_path)
         model_config["modelType"] = "baseline"
-        model_data = BaselineModelSchema().load(
+        model_data = model_schemas.BaselineModelSchema().load(
             {"task_type": task_type.value, **model_config}
         )

@@ -481,7 +483,7 @@ def add_dataset(
         # Load dataset config and augment with defaults
         if dataset_config_file_path is not None:
             dataset_config = utils.read_yaml(dataset_config_file_path)
-        dataset_data = DatasetSchema().load(
+        dataset_data = dataset_schemas.DatasetSchema().load(
             {"task_type": task_type.value, **dataset_config}
         )
         if dataset_data.get("columnNames") is None:
@@ -930,7 +932,7 @@ def create_inference_pipeline(
                 " upload.",
             ) from None

-        reference_dataset_data = DatasetSchema().load(
+        reference_dataset_data = dataset_schemas.ReferenceDatasetSchema().load(
            {"task_type": task_type.value, **reference_dataset_config}
        )

@@ -1034,7 +1036,7 @@ def upload_reference_dataset(
             ) from None

         # Load dataset config and augment with defaults
-        dataset_data = DatasetSchema().load(
+        dataset_data = dataset_schemas.ReferenceDatasetSchema().load(
             {"task_type": task_type.value, **dataset_config}
         )

@@ -1116,7 +1118,10 @@ def stream_data(
         stream_config, stream_df = self._add_default_columns(
             config=stream_config, df=stream_df
         )
-        stream_config = self._strip_read_only_fields(stream_config)
+
+        # Remove the `label` for the upload
+        stream_config.pop("label", None)
+
         body = {
             "config": stream_config,
             "rows": stream_df.to_dict(orient="records"),
@@ -1129,13 +1134,6 @@ def stream_data(
         if self.verbose:
             print("Stream published!")

-    def _strip_read_only_fields(self, config: Dict[str, any]) -> Dict[str, any]:
-        """Strips read-only fields from the config."""
-        stripped_config = copy.deepcopy(config)
-        for field in ["columnNames", "label"]:
-            stripped_config.pop(field, None)
-        return stripped_config
-
     def publish_batch_data(
         self,
         inference_pipeline_id: str,
@@ -1245,7 +1243,9 @@ def _validate_production_data_and_load_config(
                 "Make sure to fix all of the issues listed above before the upload.",
             ) from None

-        config = DatasetSchema().load({"task_type": task_type.value, **config})
+        config = dataset_schemas.ProductionDataSchema().load(
+            {"task_type": task_type.value, **config}
+        )

         return config

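
Note: the new imports pull `dataset_schemas` and `model_schemas` submodules from `.schemas` instead of the old top-level classes. The sketch below is a hypothetical layout of that refactored package, assuming marshmallow-style schemas (the `.load()` calls above match marshmallow's API); only the class names come from this diff, and the base class and field definitions are placeholders.

# schemas/dataset_schemas.py -- hypothetical sketch, placeholder fields only
import marshmallow as ma


class BaseDatasetSchema(ma.Schema):
    # Minimal shared field; the real schemas define many more.
    task_type = ma.fields.Str(required=True)


class DatasetSchema(BaseDatasetSchema):
    """Config for datasets added with add_dataset."""


class ReferenceDatasetSchema(BaseDatasetSchema):
    """Config for an inference pipeline's reference dataset."""


class ProductionDataSchema(BaseDatasetSchema):
    """Config for streamed or batch-published production data."""


# schemas/model_schemas.py -- hypothetical sketch
class ModelSchema(ma.Schema):
    task_type = ma.fields.Str(required=True)


class BaselineModelSchema(ModelSchema):
    """Config for baseline models (modelType == "baseline")."""

With a layout like this, the call sites in the diff, such as `dataset_schemas.ReferenceDatasetSchema().load({...})`, resolve without the removed top-level names.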