Skip to content

Commit 54511ba

Browse files
gustavocidornelaswhoseoyster
authored andcommitted
Completes OPEN-3572 Implement model runner
1 parent 804a329 commit 54511ba

File tree

4 files changed

+317
-2
lines changed

4 files changed

+317
-2
lines changed

openlayer/__init__.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -410,6 +410,11 @@ def add_model(
410410
with tempfile.TemporaryDirectory() as temp_dir:
411411
if model_package_dir:
412412
shutil.copytree(model_package_dir, temp_dir, dirs_exist_ok=True)
413+
current_file_dir = os.path.dirname(os.path.abspath(__file__))
414+
shutil.copy(
415+
f"{current_file_dir}/prediction_job.py",
416+
f"{temp_dir}/prediction_job.py",
417+
)
413418
utils.write_python_version(temp_dir)
414419

415420
utils.write_yaml(model_data, f"{temp_dir}/model_config.yaml")
@@ -483,6 +488,7 @@ def add_baseline_model(
483488

484489
# Copy relevant resources to temp directory
485490
with tempfile.TemporaryDirectory() as temp_dir:
491+
shutil.copy("prediction_job.py", f"{temp_dir}/prediction_job.py")
486492
utils.write_yaml(model_data, f"{temp_dir}/model_config.yaml")
487493

488494
self._stage_resource(

openlayer/models.py

Lines changed: 268 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,9 @@
1+
import os
2+
import subprocess
13
from enum import Enum
4+
from typing import List, Set
5+
import tempfile
6+
import pandas as pd
27

38

49
class ModelType(Enum):
@@ -58,3 +63,266 @@ def to_dict(self):
5863
Dict with object properties.
5964
"""
6065
return self._json
66+
67+
68+
class CondaEnvironment:
69+
"""Conda environment manager abstraction.
70+
71+
Parameters
72+
----------
73+
env_name : str
74+
Name of the conda environment.
75+
requirements_file_path : str
76+
Path to the requirements file.
77+
python_version_file_path : str
78+
Path to the python version file.
79+
logs_file_path : str
80+
Path to the logs file.
81+
"""
82+
83+
def __init__(
84+
self,
85+
env_name: str,
86+
requirements_file_path: str,
87+
python_version_file_path: str,
88+
logs_file_path: str,
89+
):
90+
if not self._conda_available():
91+
raise Exception("Conda is not available on this machine.")
92+
93+
self.env_name = env_name
94+
self.requirements_file_path = requirements_file_path
95+
self.python_version_file_path = python_version_file_path
96+
self._conda_prefix = self._get_conda_prefix()
97+
self._logs_file_path = logs_file_path
98+
self._logs_file = subprocess.PIPE
99+
100+
def __enter__(self):
101+
self._logs_file = open(self._logs_file_path, "wb")
102+
existing_envs = self.get_existing_envs()
103+
if self.env_name in existing_envs:
104+
print(f"Found existing conda environment '{self.env_name}'.")
105+
else:
106+
self.create()
107+
self.install_requirements()
108+
return self
109+
110+
def __exit__(self, exc_type, exc_value, traceback):
111+
self.deactivate()
112+
self._logs_file.close()
113+
114+
def _conda_available(self) -> bool:
115+
"""Checks if conda is available on the machine."""
116+
if os.environ.get("CONDA_EXE") is None:
117+
return False
118+
return True
119+
120+
def _get_conda_prefix(self) -> str:
121+
"""Gets the conda base environment prefix.
122+
123+
E.g., '~/miniconda3' or '~/anaconda3'
124+
"""
125+
prefix = subprocess.check_output(["conda", "info", "--base"])
126+
return prefix.decode("UTF-8").strip()
127+
128+
def create(self):
129+
"""Creates a conda environment with the specified name and python version."""
130+
print(f"Creating a new conda environment '{self.env_name}'...")
131+
132+
with open(
133+
self.python_version_file_path, "r", encoding="UTF-8"
134+
) as python_version_file:
135+
python_version = python_version_file.read().split(".")[:2]
136+
python_version = ".".join(python_version)
137+
138+
try:
139+
subprocess.check_call(
140+
[
141+
"conda",
142+
"create",
143+
"-n",
144+
f"{self.env_name}",
145+
f"python={python_version}",
146+
"--yes",
147+
],
148+
stdout=self._logs_file,
149+
stderr=self._logs_file,
150+
)
151+
except subprocess.CalledProcessError as err:
152+
raise Exception(
153+
f"Failed to create conda environment '{self.env_name}' with python "
154+
f"version {python_version}."
155+
f"Error {err.returncode}: {err.output}"
156+
) from None
157+
158+
def delete(self):
159+
"""Deletes the conda environment with the specified name."""
160+
print(f"Deleting conda environment '{self.env_name}'...")
161+
162+
try:
163+
subprocess.check_call(
164+
["conda", "env", "remove", "-n", f"{self.env_name}", "--yes"],
165+
stdout=self._logs_file,
166+
stderr=self._logs_file,
167+
)
168+
except subprocess.CalledProcessError as err:
169+
raise Exception(
170+
f"Failed to delete conda environment '{self.env_name}'."
171+
f"Error {err.returncode}: {err.output}"
172+
) from None
173+
174+
def get_existing_envs(self) -> Set[str]:
175+
"""Gets the names of all existing conda environments."""
176+
print("Checking existing conda environments...")
177+
list_envs_command = """
178+
conda env list | awk '{print $1}'
179+
"""
180+
try:
181+
envs = subprocess.check_output(
182+
list_envs_command,
183+
shell=True,
184+
stderr=self._logs_file,
185+
)
186+
except subprocess.CalledProcessError as err:
187+
raise Exception(
188+
f"Failed to list conda environments."
189+
f"Error {err.returncode}: {err.output}"
190+
) from None
191+
envs = set(envs.decode("UTF-8").split("\n"))
192+
return envs
193+
194+
def activate(self):
195+
"""Activates the conda environment with the specified name."""
196+
print(f"Activating conda environment '{self.env_name}'...")
197+
198+
activation_command = f"""
199+
eval $(conda shell.bash hook)
200+
source {self._conda_prefix}/etc/profile.d/conda.sh
201+
conda activate {self.env_name}"""
202+
203+
try:
204+
subprocess.check_call(
205+
activation_command,
206+
stdout=self._logs_file,
207+
stderr=self._logs_file,
208+
shell=True,
209+
)
210+
except subprocess.CalledProcessError as err:
211+
raise Exception(
212+
f"Failed to activate conda environment '{self.env_name}'."
213+
f"Error {err.returncode}: {err.output}"
214+
) from None
215+
216+
def deactivate(self):
217+
"""Deactivates the conda environment with the specified name."""
218+
print(f"Deactivating conda environment '{self.env_name}'...")
219+
220+
deactivation_command = f"""
221+
eval $(conda shell.bash hook)
222+
source {self._conda_prefix}/etc/profile.d/conda.sh
223+
conda deactivate"""
224+
225+
try:
226+
subprocess.check_call(
227+
deactivation_command,
228+
shell=True,
229+
stdout=self._logs_file,
230+
stderr=self._logs_file,
231+
)
232+
except subprocess.CalledProcessError as err:
233+
raise Exception(
234+
f"Failed to deactivate conda environment '{self.env_name}'."
235+
f"Error {err.returncode}: {err.output}"
236+
) from None
237+
238+
def install_requirements(self):
239+
"""Installs the requirements from the specified requirements file."""
240+
print(f"Installing requirements in conda environment '{self.env_name}'...")
241+
242+
try:
243+
self.run_commands(
244+
["pip", "install", "-r", self.requirements_file_path],
245+
)
246+
except subprocess.CalledProcessError as err:
247+
raise Exception(
248+
f"Failed to install requirements from {self.requirements_file_path}."
249+
f"Error {err.returncode}: {err.output}"
250+
) from None
251+
252+
def run_commands(self, commands: List[str]):
253+
"""Runs the specified commands inside the conda environment.
254+
255+
Parameters
256+
----------
257+
commands : List[str]
258+
List of commands to run.
259+
"""
260+
full_command = f"""
261+
eval $(conda shell.bash hook)
262+
source {self._conda_prefix}/etc/profile.d/conda.sh
263+
conda activate {self.env_name}
264+
{" ".join(commands)}
265+
"""
266+
subprocess.check_call(
267+
full_command,
268+
shell=True,
269+
stdout=self._logs_file,
270+
stderr=self._logs_file,
271+
)
272+
273+
274+
class ModelRunner:
275+
"""Wraps the model package and provides a uniform run method."""
276+
277+
def __init__(self, model_package: str):
278+
self.model_package = model_package
279+
# TODO: change env name to the model id
280+
self._conda_environment = CondaEnvironment(
281+
env_name="new-openlayer",
282+
requirements_file_path=f"{model_package}/requirements.txt",
283+
python_version_file_path=f"{model_package}/python_version",
284+
logs_file_path=f"{model_package}/logs.txt",
285+
)
286+
287+
def run(self, input_data: pd.DataFrame) -> pd.DataFrame:
288+
"""Runs the input data through the model in the conda
289+
environment.
290+
291+
Parameters
292+
----------
293+
input_data : pd.DataFrame
294+
Input data to run the model on.
295+
296+
Returns
297+
-------
298+
pd.DataFrame
299+
Output from the model. The output is a dataframe with a single
300+
column named 'prediction' and lists of class probabilities as values.
301+
"""
302+
with tempfile.TemporaryDirectory() as temp_dir:
303+
# Save the input data to a csv file
304+
input_data.to_csv(f"{temp_dir}/input_data.csv", index=False)
305+
306+
# Run the model in the conda environment
307+
with self._conda_environment as env:
308+
try:
309+
env.run_commands(
310+
[
311+
"python",
312+
f"{self.model_package}/prediction_job.py",
313+
"--input",
314+
f"{temp_dir}/input_data.csv",
315+
"--output",
316+
f"{temp_dir}/output_data.csv",
317+
]
318+
)
319+
except subprocess.CalledProcessError as err:
320+
raise Exception(
321+
f"Failed to run the model in conda environment '{env.env_name}'."
322+
f"Error {err.returncode}: {err.output}"
323+
) from None
324+
325+
# Read the output data from the csv file
326+
output_data = pd.read_csv(f"{temp_dir}/output_data.csv")
327+
328+
return output_data

openlayer/prediction_job.py

Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
"""Script that runs the prediction job.
2+
3+
This file will get copied into the model package when the user uploads a model.
4+
5+
The input and output are written to csv files in
6+
the path specified by the --input and --output flags.
7+
8+
Example usage:
9+
python prediction_job.py --input /path/to/input.csv --output /path/to/output.csv
10+
"""
11+
import argparse
12+
import logging
13+
14+
import pandas as pd
15+
import prediction_interface
16+
17+
if __name__ == "__main__":
18+
# Parse args
19+
logging.info("Parsing args to run the prediction job...")
20+
parser = argparse.ArgumentParser()
21+
parser.add_argument("--input", action="store", dest="input_data_file_path")
22+
parser.add_argument("--output", action="store", dest="output_data_file_path")
23+
args = parser.parse_args()
24+
25+
# Load input data
26+
logging.info("Loading input data...")
27+
input_data = pd.read_csv(args.input_data_file_path)
28+
29+
# Load model module
30+
logging.info("Loading model...")
31+
ml_model = prediction_interface.load_model()
32+
33+
# Run model
34+
logging.info("Running model...")
35+
output_data = pd.DataFrame(
36+
{"predictions": ml_model.predict_proba(input_data).tolist()}
37+
)
38+
39+
# Save output data
40+
logging.info("Saving output data...")
41+
output_data.to_csv(args.output_data_file_path, index=False)

openlayer/validators.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -834,7 +834,7 @@ def validate(self) -> List[str]:
834834
self._validate_dataset_and_config_consistency()
835835

836836
if not self.failed_validations:
837-
print("All dataset validations passed!")
837+
print(f"All {self.dataset_config['label']} dataset validations passed!")
838838

839839
return self.failed_validations
840840

@@ -1126,7 +1126,7 @@ def validate(self) -> List[str]:
11261126
self._validate_model_config()
11271127

11281128
if not self.failed_validations:
1129-
print("All validations passed!")
1129+
print("All model validations passed!")
11301130

11311131
return self.failed_validations
11321132

0 commit comments

Comments
 (0)