2121import pkg_resources
2222import yaml
2323
24- from . import schemas , utils
24+ from . import models , schemas , utils
2525
2626
2727class BaselineModelValidator :
@@ -96,18 +96,22 @@ class CommitBundleValidator:
9696 Whether to skip model validation, by default False
9797 skip_dataset_validation : bool
9898 Whether to skip dataset validation, by default False
99+ use_runner : bool
100+ Whether to use the runner to validate the model, by default False.
99101 """
100102
101103 def __init__ (
102104 self ,
103105 bundle_path : str ,
104106 skip_model_validation : bool = False ,
105107 skip_dataset_validation : bool = False ,
108+ use_runner : bool = False ,
106109 ):
107110 self .bundle_path = bundle_path
108111 self ._bundle_resources = utils .list_resources_in_bundle (bundle_path )
109112 self ._skip_model_validation = skip_model_validation
110113 self ._skip_dataset_validation = skip_dataset_validation
114+ self ._use_runner = use_runner
111115 self .failed_validations = []
112116
113117 def _validate_bundle_state (self ):
@@ -268,6 +272,7 @@ def _validate_bundle_resources(self):
268272 model_config_file_path = f"{ self .bundle_path } /model/model_config.yaml" ,
269273 model_package_dir = f"{ self .bundle_path } /model" ,
270274 sample_data = sample_data ,
275+ use_runner = self ._use_runner ,
271276 )
272277 bundle_resources_failed_validations .extend (model_validator .validate ())
273278
@@ -844,6 +849,8 @@ class ModelValidator:
844849
845850 Parameters
846851 ----------
852+ model_config_file_path: str
853+ Path to the model config file.
847854 model_package_dir : str
848855 Path to the model package directory.
849856 sample_data : pd.DataFrame
@@ -862,6 +869,7 @@ class ModelValidator:
862869 >>> from openlayer import ModelValidator
863870 >>>
864871 >>> model_validator = ModelValidator(
872+ ... model_config_file_path="/path/to/model/config/file",
865873 ... model_package_dir="/path/to/model/package",
866874 ... sample_data=df,
867875 ... )
@@ -872,12 +880,14 @@ class ModelValidator:
872880 def __init__ (
873881 self ,
874882 model_config_file_path : str ,
883+ use_runner : bool = False ,
875884 model_package_dir : Optional [str ] = None ,
876885 sample_data : Optional [pd .DataFrame ] = None ,
877886 ):
878887 self .model_config_file_path = model_config_file_path
879888 self .model_package_dir = model_package_dir
880889 self .sample_data = sample_data
890+ self ._use_runner = use_runner
881891 self .failed_validations = []
882892
883893 def _validate_model_package_dir (self ):
@@ -932,7 +942,7 @@ def _validate_model_package_dir(self):
932942 # Add the model package failed validations to the list of all failed validations
933943 self .failed_validations .extend (model_package_failed_validations )
934944
935- def _validate_requirements (self ):
945+ def _validate_requirements_file (self ):
936946 """Validates the requirements.txt file.
937947
938948 Checks for the existence of the file and parses it to check for
@@ -1109,6 +1119,33 @@ def _validate_prediction_interface(self):
11091119 # Add the `prediction_interface.py` failed validations to the list of all failed validations
11101120 self .failed_validations .extend (prediction_interface_failed_validations )
11111121
1122+ def _validate_model_runner (self ):
1123+ """Validates the model using the model runner.
1124+
1125+ This is mostly meant to be used by the platform, to validate the model. It will
1126+ create the model's environment and use it to run the model.
1127+ """
1128+ model_runner_failed_validations = []
1129+
1130+ model_runner = models .ModelRunner (self .model_package_dir )
1131+
1132+ # Try to run some data through the runner
1133+ # Will create the model environment if it doesn't exist
1134+ try :
1135+ model_runner .run (self .sample_data )
1136+ except Exception as exc :
1137+ model_runner_failed_validations .append (
1138+ f"Failed to run the model with the following error: \n { exc } "
1139+ )
1140+
1141+ # Print results of the validation
1142+ if model_runner_failed_validations :
1143+ print ("Model runner failed validations: \n " )
1144+ _list_failed_validation_messages (model_runner_failed_validations )
1145+
1146+ # Add the model runner failed validations to the list of all failed validations
1147+ self .failed_validations .extend (model_runner_failed_validations )
1148+
11121149 def validate (self ) -> List [str ]:
11131150 """Runs all model validations.
11141151
@@ -1121,8 +1158,11 @@ def validate(self) -> List[str]:
11211158 """
11221159 if self .model_package_dir :
11231160 self ._validate_model_package_dir ()
1124- self ._validate_requirements ()
1125- self ._validate_prediction_interface ()
1161+ if self ._use_runner :
1162+ self ._validate_model_runner ()
1163+ else :
1164+ self ._validate_requirements_file ()
1165+ self ._validate_prediction_interface ()
11261166 self ._validate_model_config ()
11271167
11281168 if not self .failed_validations :
0 commit comments