@@ -147,8 +147,8 @@ def _validate_bundle_state(self):
147147 """Checks whether the bundle is in a valid state.
148148
149149 This includes:
150- - When a "model" is included, you always need to provide predictions for both
151- "validation" and "training" (regardless of artifact or no artifact) .
150+ - When a "model" (shell or full) is included, you always need to provide predictions for both
151+ "validation" and "training".
152152 - When a "baseline-model" is included, you always need to provide a "training"
153153 and "validation" set without predictions.
154154 - When a "model" nor a "baseline-model" are included, you always need to NOT
@@ -186,33 +186,35 @@ def _validate_bundle_state(self):
186186 )
187187
188188 if "model" in self ._bundle_resources :
189+ model_config = self ._load_model_config_from_bundle ()
190+ model_type = model_config .get ("modelType" )
189191 if (
190192 training_predictions_column_name is None
191193 or validation_predictions_column_name is None
192- ):
194+ ) and model_type != "baseline" :
193195 bundle_state_failed_validations .append (
194196 "To push a model to the platform, you must provide "
195197 "training and a validation sets with predictions in the column "
196198 "`predictions_column_name`."
197199 )
198- elif "baseline-model" in self . _bundle_resources :
199- if (
200- "training" not in self ._bundle_resources
201- or "validation" not in self ._bundle_resources
202- ):
203- bundle_state_failed_validations .append (
204- "To push a baseline model to the platform, you must provide "
205- "training and validation sets."
206- )
207- elif (
208- training_predictions_column_name is not None
209- and validation_predictions_column_name is not None
210- ):
211- bundle_state_failed_validations .append (
212- "To push a baseline model to the platform, you must not provide "
213- "training and a validation sets without predictions in the column "
214- "`predictions_column_name`."
215- )
200+ if model_type == "baseline" :
201+ if (
202+ "training" not in self ._bundle_resources
203+ or "validation" not in self ._bundle_resources
204+ ):
205+ bundle_state_failed_validations .append (
206+ "To push a baseline model to the platform, you must provide "
207+ "training and validation sets."
208+ )
209+ elif (
210+ training_predictions_column_name is not None
211+ and validation_predictions_column_name is not None
212+ ):
213+ bundle_state_failed_validations .append (
214+ "To push a baseline model to the platform, you must provide "
215+ "training and validation sets without predictions in the column "
216+ "`predictions_column_name`."
217+ )
216218 else :
217219 if (
218220 "training" in self ._bundle_resources
@@ -260,26 +262,15 @@ def _validate_bundle_resources(self):
260262 validation_set_validator .validate ()
261263 )
262264
263- if (
264- "baseline-model" in self ._bundle_resources
265- and not self ._skip_model_validation
266- ):
267- baseline_model_validator = BaselineModelValidator (
268- model_config_file_path = f"{ self .bundle_path } /baseline-model/model_config.yaml"
269- )
270- bundle_resources_failed_validations .extend (
271- baseline_model_validator .validate ()
272- )
273-
274265 if "model" in self ._bundle_resources and not self ._skip_model_validation :
275- model_files = os .listdir (f"{ self .bundle_path } /model" )
276- # Shell model
277- if len (model_files ) == 1 :
266+ model_config_file_path = f"{ self .bundle_path } /model/model_config.yaml"
267+ model_config = self ._load_model_config_from_bundle ()
268+
269+ if model_config ["modelType" ] == "shell" :
278270 model_validator = ModelValidator (
279- model_config_file_path = f" { self . bundle_path } /model/model_config.yaml"
271+ model_config_file_path = model_config_file_path
280272 )
281- # Model package
282- else :
273+ elif model_config ["modelType" ] == "full" :
283274 # Use data from the validation as test data
284275 validation_dataset_df = self ._load_dataset_from_bundle ("validation" )
285276 validation_dataset_config = self ._load_dataset_config_from_bundle (
@@ -298,12 +289,21 @@ def _validate_bundle_resources(self):
298289 ].head ()
299290
300291 model_validator = ModelValidator (
301- model_config_file_path = f" { self . bundle_path } /model/model_config.yaml" ,
292+ model_config_file_path = model_config_file_path ,
302293 model_package_dir = f"{ self .bundle_path } /model" ,
303294 sample_data = sample_data ,
304295 use_runner = self ._use_runner ,
305296 )
306- bundle_resources_failed_validations .extend (model_validator .validate ())
297+ elif model_config ["modelType" ] == "baseline" :
298+ model_validator = BaselineModelValidator (
299+ model_config_file_path = model_config_file_path
300+ )
301+ else :
302+ raise ValueError (
303+ f"Invalid model type: { model_config ['modelType' ]} . "
304+ "The model type must be one of 'shell', 'full' or 'baseline'."
305+ )
306+ bundle_resources_failed_validations .extend (model_validator .validate ())
307307
308308 # Add the bundle resources failed validations to the list of all failed validations
309309 self .failed_validations .extend (bundle_resources_failed_validations )
@@ -347,6 +347,21 @@ def _load_dataset_config_from_bundle(self, label: str) -> Dict[str, Any]:
347347
348348 return dataset_config
349349
350+ def _load_model_config_from_bundle (self ) -> Dict [str , Any ]:
351+ """Loads a model config from a commit bundle.
352+
353+ Returns
354+ -------
355+ Dict[str, Any]
356+ The model config.
357+ """
358+ model_config_file_path = f"{ self .bundle_path } /model/model_config.yaml"
359+
360+ with open (model_config_file_path , "r" , encoding = "UTF-8" ) as stream :
361+ model_config = yaml .safe_load (stream )
362+
363+ return model_config
364+
350365 def validate (self ) -> List [str ]:
351366 """Validates the commit bundle.
352367
0 commit comments