 from typing import Any, Dict, List, Optional
 
 import marshmallow as ma
+import numpy as np
 import pandas as pd
 import pkg_resources
 import yaml
@@ -204,31 +205,25 @@ def _validate_bundle_state(self):
         training_predictions_column_name = None
         validation_predictions_column_name = None
         if "training" in self._bundle_resources:
-            with open(
-                f"{self.bundle_path}/training/dataset_config.yaml",
-                "r",
-                encoding="UTF-8",
-            ) as stream:
-                training_dataset_config = yaml.safe_load(stream)
-
+            training_dataset_config = utils.load_dataset_config_from_bundle(
+                bundle_path=self.bundle_path, label="training"
+            )
             training_predictions_column_name = training_dataset_config.get(
                 "predictionsColumnName"
             )

         if "validation" in self._bundle_resources:
-            with open(
-                f"{self.bundle_path}/validation/dataset_config.yaml",
-                "r",
-                encoding="UTF-8",
-            ) as stream:
-                validation_dataset_config = yaml.safe_load(stream)
-
+            validation_dataset_config = utils.load_dataset_config_from_bundle(
+                bundle_path=self.bundle_path, label="validation"
+            )
             validation_predictions_column_name = validation_dataset_config.get(
                 "predictionsColumnName"
             )

         if "model" in self._bundle_resources:
-            model_config = self._load_model_config_from_bundle()
+            model_config = utils.load_model_config_from_bundle(
+                bundle_path=self.bundle_path
+            )
             model_type = model_config.get("modelType")
             if (
                 training_predictions_column_name is None
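For context, `utils.load_dataset_config_from_bundle` returns the parsed `dataset_config.yaml` for the given label, and this check reads only its `predictionsColumnName` key. A rough illustration of the parsed dict, assuming a bundle laid out as in this diff (the column name value and any key other than `predictionsColumnName` are hypothetical, not taken from the diff):

    # Hypothetical parsed training dataset config (illustrative values only)
    training_dataset_config = {
        "predictionsColumnName": "predictions",  # hypothetical column name
        "label": "training",                     # hypothetical extra key
    }
    # The only lookup performed in _validate_bundle_state:
    training_predictions_column_name = training_dataset_config.get("predictionsColumnName")
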
@@ -306,17 +301,21 @@ def _validate_bundle_resources(self):
 
         if "model" in self._bundle_resources and not self._skip_model_validation:
             model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"
-            model_config = self._load_model_config_from_bundle()
+            model_config = utils.load_model_config_from_bundle(
+                bundle_path=self.bundle_path
+            )
 
             if model_config["modelType"] == "shell":
                 model_validator = ModelValidator(
                     model_config_file_path=model_config_file_path
                 )
             elif model_config["modelType"] == "full":
                 # Use data from the validation as test data
-                validation_dataset_df = self._load_dataset_from_bundle("validation")
-                validation_dataset_config = self._load_dataset_config_from_bundle(
-                    "validation"
+                validation_dataset_df = utils.load_dataset_from_bundle(
+                    bundle_path=self.bundle_path, label="validation"
+                )
+                validation_dataset_config = utils.load_dataset_config_from_bundle(
+                    bundle_path=self.bundle_path, label="validation"
                 )
 
                 sample_data = None
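The full-model branch now loads the validation data through `utils.load_dataset_from_bundle`. Its body is not part of this diff, but it presumably mirrors the `_load_dataset_from_bundle` method deleted in the next hunk, just lifted into the `utils` module. A minimal sketch under that assumption:

    import pandas as pd

    def load_dataset_from_bundle(bundle_path: str, label: str) -> pd.DataFrame:
        """Sketch only: read {bundle_path}/{label}/dataset.csv into a DataFrame."""
        return pd.read_csv(f"{bundle_path}/{label}/dataset.csv")
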
@@ -350,60 +349,6 @@ def _validate_bundle_resources(self):
         # Add the bundle resources failed validations to the list of all failed validations
         self.failed_validations.extend(bundle_resources_failed_validations)
 
-    def _load_dataset_from_bundle(self, label: str) -> pd.DataFrame:
-        """Loads a dataset from a commit bundle.
-
-        Parameters
-        ----------
-        label : str
-            The type of the dataset. Can be either "training" or "validation".
-
-        Returns
-        -------
-        pd.DataFrame
-            The dataset.
-        """
-        dataset_file_path = f"{self.bundle_path}/{label}/dataset.csv"
-
-        dataset_df = pd.read_csv(dataset_file_path)
-
-        return dataset_df
-
-    def _load_dataset_config_from_bundle(self, label: str) -> Dict[str, Any]:
-        """Loads a dataset config from a commit bundle.
-
-        Parameters
-        ----------
-        label : str
-            The type of the dataset. Can be either "training" or "validation".
-
-        Returns
-        -------
-        Dict[str, Any]
-            The dataset config.
-        """
-        dataset_config_file_path = f"{self.bundle_path}/{label}/dataset_config.yaml"
-
-        with open(dataset_config_file_path, "r", encoding="UTF-8") as stream:
-            dataset_config = yaml.safe_load(stream)
-
-        return dataset_config
-
-    def _load_model_config_from_bundle(self) -> Dict[str, Any]:
-        """Loads a model config from a commit bundle.
-
-        Returns
-        -------
-        Dict[str, Any]
-            The model config.
-        """
-        model_config_file_path = f"{self.bundle_path}/model/model_config.yaml"
-
-        with open(model_config_file_path, "r", encoding="UTF-8") as stream:
-            model_config = yaml.safe_load(stream)
-
-        return model_config
-
     def _validate_resource_consistency(self):
         """Validates that the resources in the bundle are consistent with each other.
 
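The config loaders removed above are replaced by `utils.load_dataset_config_from_bundle` and `utils.load_model_config_from_bundle`. Their bodies are not shown in this diff; presumably they carry the same logic as the deleted methods, rewritten as module-level functions that take `bundle_path` explicitly. A sketch under that assumption:

    from typing import Any, Dict

    import yaml

    def load_dataset_config_from_bundle(bundle_path: str, label: str) -> Dict[str, Any]:
        """Sketch only: parse {bundle_path}/{label}/dataset_config.yaml."""
        with open(f"{bundle_path}/{label}/dataset_config.yaml", "r", encoding="UTF-8") as stream:
            return yaml.safe_load(stream)

    def load_model_config_from_bundle(bundle_path: str) -> Dict[str, Any]:
        """Sketch only: parse {bundle_path}/model/model_config.yaml."""
        with open(f"{bundle_path}/model/model_config.yaml", "r", encoding="UTF-8") as stream:
            return yaml.safe_load(stream)
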
@@ -419,10 +364,14 @@ def _validate_resource_consistency(self):
         # Loading the relevant configs
         model_config = {}
         if "model" in self._bundle_resources:
-            model_config = self._load_model_config_from_bundle()
-        training_dataset_config = self._load_dataset_config_from_bundle("training")
-        validation_dataset_config = self._load_dataset_config_from_bundle(
-            "validation"
+            model_config = utils.load_model_config_from_bundle(
+                bundle_path=self.bundle_path
+            )
+        training_dataset_config = utils.load_dataset_config_from_bundle(
+            bundle_path=self.bundle_path, label="training"
+        )
+        validation_dataset_config = utils.load_dataset_config_from_bundle(
+            bundle_path=self.bundle_path, label="validation"
         )
         model_feature_names = model_config.get("featureNames")
         model_class_names = model_config.get("classNames")
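The comparisons that use `model_feature_names` and `model_class_names` sit outside this hunk, so only the config loading changes here. A plausible shape of such a consistency check, purely for illustration (the `featureNames` key on the dataset config and the message text are assumptions, not taken from this diff):

    # Illustrative only -- the actual comparison code is not shown in this hunk
    consistency_failures = []
    dataset_feature_names = training_dataset_config.get("featureNames")  # hypothetical key
    if model_feature_names and dataset_feature_names:
        if set(model_feature_names) != set(dataset_feature_names):
            consistency_failures.append(
                "The feature names in `model_config.yaml` do not match the training dataset's columns."
            )
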
@@ -1113,6 +1062,8 @@ def __init__(
         self.sample_data = sample_data
         self._use_runner = use_runner
         self.failed_validations = []
+        self.model_config = None
+        self.model_output = None
 
     def validate(self) -> List[str]:
         """Runs all model validations.
@@ -1300,6 +1251,8 @@ def _validate_model_config(self):
         if model_config_failed_validations:
             logger.error("`model_config.yaml` failed validations:")
             _list_failed_validation_messages(model_config_failed_validations)
+        else:
+            self.model_config = model_config
 
         # Add the `model_config.yaml` failed validations to the list of all failed validations
         self.failed_validations.extend(model_config_failed_validations)
@@ -1359,7 +1312,9 @@ def _validate_prediction_interface(self):
             # Test `predict_proba` function
             try:
                 with utils.HidePrints():
-                    ml_model.predict_proba(self.sample_data)
+                    self.model_output = ml_model.predict_proba(
+                        self.sample_data
+                    )
             except Exception as exc:
                 exception_stack = utils.get_exception_stacktrace(exc)
                 prediction_interface_failed_validations.append(
@@ -1368,6 +1323,9 @@ def _validate_prediction_interface(self):
                     f"\t{exception_stack}"
                 )
 
+        if self.model_output is not None:
+            self._validate_model_output()
+
         # Print results of the validation
         if prediction_interface_failed_validations:
             logger.error("`prediction_interface.py` failed validations:")
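Capturing the return value of `predict_proba` is what enables the new `_validate_model_output` check added in the next hunk. Per that check, the expected contract is a numpy array of shape (n_samples, n_classes) whose rows sum to 1. A minimal, hypothetical model wrapper that satisfies it (the class name and layout are illustrative, not the library's actual prediction interface):

    import numpy as np

    class ExampleModel:  # hypothetical wrapper, for illustration only
        def predict_proba(self, input_data) -> np.ndarray:
            # Return uniform probabilities over two classes for every sample
            n_samples = len(input_data)
            return np.full((n_samples, 2), 0.5)
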
@@ -1401,6 +1359,48 @@ def _validate_model_runner(self):
         # Add the model runner failed validations to the list of all failed validations
         self.failed_validations.extend(model_runner_failed_validations)
 
+    def _validate_model_output(self):
+        """Validates the model output.
+
+        Checks whether the model output is an array-like object with shape
+        (n_samples, n_classes), and whether it is a probability distribution.
+        """
+        model_output_failed_validations = []
+
+        # Check if the model output is an array-like object
+        if not isinstance(self.model_output, np.ndarray):
+            model_output_failed_validations.append(
+                "The output of the `predict_proba` method in the `prediction_interface.py` "
+                "file is not an array-like object. It should be a numpy array of shape "
+                "(n_samples, n_classes)."
+            )
+        elif self.model_config is not None:
+            # Check if the model output has the correct shape
+            num_rows = len(self.sample_data)
+            num_classes = len(self.model_config.get("classes"))
+            if self.model_output.shape != (num_rows, num_classes):
+                model_output_failed_validations.append(
+                    "The output of the `predict_proba` method in the `prediction_interface.py` "
+                    "file has the wrong shape. It should be a numpy array of shape "
+                    f"({num_rows}, {num_classes}). The current output has shape "
+                    f"{self.model_output.shape}."
+                )
+            # Check if the model output is a probability distribution
+            elif not np.allclose(self.model_output.sum(axis=1), 1, atol=0.05):
+                model_output_failed_validations.append(
+                    "The output of the `predict_proba` method in the `prediction_interface.py` "
+                    "file is not a probability distribution. The sum of the probabilities for "
+                    "each sample should be equal to 1."
+                )
+
+        # Print results of the validation
+        if model_output_failed_validations:
+            logger.error("Model output failed validations:")
+            _list_failed_validation_messages(model_output_failed_validations)
+
+        # Add the model output failed validations to the list of all failed validations
+        self.failed_validations.extend(model_output_failed_validations)
+
 
 class ProjectValidator:
     """Validates the project.