Skip to content

Commit 617a95b

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
Closes OPEN-3602 Optionally use the model runner for model validations
1 parent c899f69 commit 617a95b

File tree

3 files changed

+57
-10
lines changed

3 files changed

+57
-10
lines changed

openlayer/__init__.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -410,11 +410,6 @@ def add_model(
410410
with tempfile.TemporaryDirectory() as temp_dir:
411411
if model_package_dir:
412412
shutil.copytree(model_package_dir, temp_dir, dirs_exist_ok=True)
413-
current_file_dir = os.path.dirname(os.path.abspath(__file__))
414-
shutil.copy(
415-
f"{current_file_dir}/prediction_job.py",
416-
f"{temp_dir}/prediction_job.py",
417-
)
418413
utils.write_python_version(temp_dir)
419414

420415
utils.write_yaml(model_data, f"{temp_dir}/model_config.yaml")

openlayer/models.py

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
11
import os
2+
import shutil
23
import subprocess
4+
import tempfile
35
from enum import Enum
46
from typing import List, Set
5-
import tempfile
7+
68
import pandas as pd
79

810

@@ -284,6 +286,9 @@ def __init__(self, model_package: str):
284286
logs_file_path=f"{model_package}/logs.txt",
285287
)
286288

289+
def __del__(self):
290+
self._conda_environment.delete()
291+
287292
def run(self, input_data: pd.DataFrame) -> pd.DataFrame:
288293
"""Runs the input data through the model in the conda
289294
environment.
@@ -299,6 +304,13 @@ def run(self, input_data: pd.DataFrame) -> pd.DataFrame:
299304
Output from the model. The output is a dataframe with a single
300305
column named 'prediction' and lists of class probabilities as values.
301306
"""
307+
# Copy the prediction job script to the model package
308+
current_file_dir = os.path.dirname(os.path.abspath(__file__))
309+
shutil.copy(
310+
f"{current_file_dir}/prediction_job.py",
311+
f"{self.model_package}/prediction_job.py",
312+
)
313+
302314
with tempfile.TemporaryDirectory() as temp_dir:
303315
# Save the input data to a csv file
304316
input_data.to_csv(f"{temp_dir}/input_data.csv", index=False)

openlayer/validators.py

Lines changed: 44 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@
2121
import pkg_resources
2222
import yaml
2323

24-
from . import schemas, utils
24+
from . import models, schemas, utils
2525

2626

2727
class BaselineModelValidator:
@@ -96,18 +96,22 @@ class CommitBundleValidator:
9696
Whether to skip model validation, by default False
9797
skip_dataset_validation : bool
9898
Whether to skip dataset validation, by default False
99+
use_runner : bool
100+
Whether to use the runner to validate the model, by default False.
99101
"""
100102

101103
def __init__(
102104
self,
103105
bundle_path: str,
104106
skip_model_validation: bool = False,
105107
skip_dataset_validation: bool = False,
108+
use_runner: bool = False,
106109
):
107110
self.bundle_path = bundle_path
108111
self._bundle_resources = utils.list_resources_in_bundle(bundle_path)
109112
self._skip_model_validation = skip_model_validation
110113
self._skip_dataset_validation = skip_dataset_validation
114+
self._use_runner = use_runner
111115
self.failed_validations = []
112116

113117
def _validate_bundle_state(self):
@@ -268,6 +272,7 @@ def _validate_bundle_resources(self):
268272
model_config_file_path=f"{self.bundle_path}/model/model_config.yaml",
269273
model_package_dir=f"{self.bundle_path}/model",
270274
sample_data=sample_data,
275+
use_runner=self._use_runner,
271276
)
272277
bundle_resources_failed_validations.extend(model_validator.validate())
273278

@@ -844,6 +849,8 @@ class ModelValidator:
844849
845850
Parameters
846851
----------
852+
model_config_file_path: str
853+
Path to the model config file.
847854
model_package_dir : str
848855
Path to the model package directory.
849856
sample_data : pd.DataFrame
@@ -862,6 +869,7 @@ class ModelValidator:
862869
>>> from openlayer import ModelValidator
863870
>>>
864871
>>> model_validator = ModelValidator(
872+
... model_config_file_path="/path/to/model/config/file",
865873
... model_package_dir="/path/to/model/package",
866874
... sample_data=df,
867875
... )
@@ -872,12 +880,14 @@ class ModelValidator:
872880
def __init__(
873881
self,
874882
model_config_file_path: str,
883+
use_runner: bool = False,
875884
model_package_dir: Optional[str] = None,
876885
sample_data: Optional[pd.DataFrame] = None,
877886
):
878887
self.model_config_file_path = model_config_file_path
879888
self.model_package_dir = model_package_dir
880889
self.sample_data = sample_data
890+
self._use_runner = use_runner
881891
self.failed_validations = []
882892

883893
def _validate_model_package_dir(self):
@@ -932,7 +942,7 @@ def _validate_model_package_dir(self):
932942
# Add the model package failed validations to the list of all failed validations
933943
self.failed_validations.extend(model_package_failed_validations)
934944

935-
def _validate_requirements(self):
945+
def _validate_requirements_file(self):
936946
"""Validates the requirements.txt file.
937947
938948
Checks for the existence of the file and parses it to check for
@@ -1109,6 +1119,33 @@ def _validate_prediction_interface(self):
11091119
# Add the `prediction_interface.py` failed validations to the list of all failed validations
11101120
self.failed_validations.extend(prediction_interface_failed_validations)
11111121

1122+
def _validate_model_runner(self):
1123+
"""Validates the model using the model runner.
1124+
1125+
This is mostly meant to be used by the platform, to validate the model. It will
1126+
create the model's environment and use it to run the model.
1127+
"""
1128+
model_runner_failed_validations = []
1129+
1130+
model_runner = models.ModelRunner(self.model_package_dir)
1131+
1132+
# Try to run some data through the runner
1133+
# Will create the model environment if it doesn't exist
1134+
try:
1135+
model_runner.run(self.sample_data)
1136+
except Exception as exc:
1137+
model_runner_failed_validations.append(
1138+
f"Failed to run the model with the following error: \n {exc}"
1139+
)
1140+
1141+
# Print results of the validation
1142+
if model_runner_failed_validations:
1143+
print("Model runner failed validations: \n")
1144+
_list_failed_validation_messages(model_runner_failed_validations)
1145+
1146+
# Add the model runner failed validations to the list of all failed validations
1147+
self.failed_validations.extend(model_runner_failed_validations)
1148+
11121149
def validate(self) -> List[str]:
11131150
"""Runs all model validations.
11141151
@@ -1121,8 +1158,11 @@ def validate(self) -> List[str]:
11211158
"""
11221159
if self.model_package_dir:
11231160
self._validate_model_package_dir()
1124-
self._validate_requirements()
1125-
self._validate_prediction_interface()
1161+
if self._use_runner:
1162+
self._validate_model_runner()
1163+
else:
1164+
self._validate_requirements_file()
1165+
self._validate_prediction_interface()
11261166
self._validate_model_config()
11271167

11281168
if not self.failed_validations:

0 commit comments

Comments
 (0)