Skip to content

Commit a32693f

Browse files
author
Ziqun Ye
committed
add more unit test
1 parent 8df3ec0 commit a32693f

File tree

2 files changed

+124
-9
lines changed

2 files changed

+124
-9
lines changed

ads/opctl/backend/local.py

Lines changed: 7 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,6 @@
77
import copy
88
import json
99
import os
10-
import shutil
1110
import tempfile
1211
from concurrent.futures import Future, ThreadPoolExecutor
1312
from time import sleep
@@ -19,7 +18,6 @@
1918
from ads.common.decorator.runtime_dependency import (OptionalDependency,
2019
runtime_dependency)
2120
from ads.common.oci_client import OCIClientFactory
22-
from ads.model.datascience_model import DataScienceModel
2321
from ads.model.model_metadata import ModelCustomMetadata
2422
from ads.model.runtime.runtime_info import RuntimeInfo
2523
from ads.opctl import logger
@@ -67,7 +65,6 @@ def __init__(self, config: Dict) -> None:
6765
self.oci_config,
6866
self.profile ,
6967
)
70-
7168
self.client = OCIClientFactory(**self.oci_auth).data_science
7269

7370
def run(self):
@@ -671,21 +668,21 @@ def predict(self) -> None:
671668
bucket_uri = self.config["execution"].get("bucket_uri", None)
672669
timeout = self.config["execution"].get("timeout", None)
673670
logger.info(f"No cached model found. Downloading the model {ocid} to {artifact_directory}. If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts.")
671+
674672
_download_model(ocid=ocid, artifact_directory=artifact_directory, region=region, bucket_uri=bucket_uri, timeout=timeout)
675673

676674
if ocid:
677675
conda_slug, conda_path = self._get_conda_info_from_custom_metadata(ocid)
678-
if artifact_directory or not conda_path:
676+
if not conda_path:
679677
if not os.path.exists(artifact_directory) or len(os.listdir(artifact_directory)) == 0:
680678
raise ValueError(f"`artifact_directory` {artifact_directory} does not exist or is empty.")
681679
conda_slug, conda_path = self._get_conda_info_from_runtime(artifact_dir=artifact_directory)
682-
else:
680+
if not conda_path or not conda_slug:
683681
raise ValueError("Conda information cannot be detected.")
684682
compartment_id = self.config["execution"].get("compartment_id", self.config["infrastructure"].get("compartment_id"))
685683
project_id = self.config["execution"].get("project_id", self.config["infrastructure"].get("project_id"))
686684
if not compartment_id or not project_id:
687685
raise ValueError("`compartment_id` and `project_id` must be provided.")
688-
689686
extra_cmd = "/opt/ds/model/deployed_model/ " + data + " " + compartment_id + " " + project_id
690687
bind_volumes = {}
691688
if not is_in_notebook_session():
@@ -727,13 +724,14 @@ def _get_conda_info_from_custom_metadata(self, ocid):
727724
response = self.client.get_model(ocid)
728725
custom_metadata = ModelCustomMetadata._from_oci_metadata(response.data.custom_metadata_list)
729726
conda_slug, conda_path = None, None
730-
if "CondaEnvironmentPath" in custom_metadata:
727+
if "CondaEnvironmentPath" in custom_metadata.keys:
731728
conda_path = custom_metadata['CondaEnvironmentPath'].value
732-
if "SlugName" in custom_metadata:
729+
if "SlugName" in custom_metadata.keys:
733730
conda_slug = custom_metadata['SlugName'].value
734731
return conda_slug, conda_path
735732

736-
def _get_conda_info_from_runtime(self, artifact_dir):
733+
@staticmethod
734+
def _get_conda_info_from_runtime(artifact_dir):
737735
"""
738736
Get conda env info from runtime yaml file.
739737
Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2021, 2023 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
6+
7+
from mock import ANY, patch, MagicMock
8+
import pytest
9+
from ads.opctl.backend.local import LocalModelDeploymentBackend, ModelCustomMetadata, os
10+
11+
12+
class TestLocalModelDeploymentBackend:
13+
14+
@property
15+
def config(self):
16+
return {
17+
"execution": {
18+
"backend": "local",
19+
"use_conda": True,
20+
"debug": False,
21+
"env_var": ["TEST_ENV=test_env"],
22+
"oci_config": "~/.oci/config",
23+
"oci_profile": "DEFAULT",
24+
"image": "ml-job",
25+
"env_vars": {"TEST_ENV": "test_env"},
26+
"job_name": "hello-world",
27+
"auth": "api_key",
28+
"ocid": "fake_id",
29+
"compartment_id": "fake_id",
30+
"project_id": "fake_id",
31+
"payload": "fake_data"
32+
33+
},
34+
"infrastructure": {},
35+
36+
}
37+
38+
@property
39+
def backend(self):
40+
return LocalModelDeploymentBackend(config=self.config)
41+
42+
@property
43+
def custom_metadata(self):
44+
custom_metadata = ModelCustomMetadata()
45+
custom_metadata.add(key="CondaEnvironmentPath", value="fake_path")
46+
custom_metadata.add(key="SlugName", value="fake_slug")
47+
return custom_metadata
48+
49+
@patch.object(ModelCustomMetadata, "_from_oci_metadata")
50+
def test__get_conda_info_from_custom_metadata(self, mock_custom_metadata, ):
51+
52+
mock_custom_metadata.return_value = self.custom_metadata
53+
54+
backend = LocalModelDeploymentBackend(config=self.config)
55+
backend.client.get_model = MagicMock()
56+
conda_slug, conda_path= backend._get_conda_info_from_custom_metadata("fake_id")
57+
assert conda_slug == "fake_slug"
58+
assert conda_path == "fake_path"
59+
60+
def test__get_conda_info_from_runtime(self):
61+
yaml_str = """
62+
MODEL_ARTIFACT_VERSION: '3.0'
63+
MODEL_DEPLOYMENT:
64+
INFERENCE_CONDA_ENV:
65+
INFERENCE_ENV_PATH: fake_path
66+
INFERENCE_ENV_SLUG: fake_slug
67+
INFERENCE_ENV_TYPE: data_science
68+
INFERENCE_PYTHON_VERSION: '3.8'
69+
MODEL_PROVENANCE:
70+
PROJECT_OCID: ''
71+
TENANCY_OCID: ''
72+
TRAINING_CODE:
73+
ARTIFACT_DIRECTORY: fake_dir
74+
TRAINING_COMPARTMENT_OCID: ''
75+
TRAINING_CONDA_ENV:
76+
TRAINING_ENV_PATH: ''
77+
TRAINING_ENV_SLUG: ''
78+
TRAINING_ENV_TYPE: ''
79+
TRAINING_PYTHON_VERSION: ''
80+
TRAINING_REGION: ''
81+
TRAINING_RESOURCE_OCID: ''
82+
USER_OCID: ''
83+
VM_IMAGE_INTERNAL_ID: ''
84+
"""
85+
with open("./runtime.yaml", "w") as f:
86+
f.write(yaml_str)
87+
88+
conda_slug, conda_path = LocalModelDeploymentBackend._get_conda_info_from_runtime("./")
89+
90+
assert conda_slug == "fake_slug"
91+
assert conda_path == "fake_path"
92+
93+
@patch("ads.opctl.backend.local.os.listdir", return_value=["path"])
94+
@patch("ads.opctl.backend.local.os.path.exists", return_value=True)
95+
def test_predict(self, mock_path_exists, mock_list_dir):
96+
with patch("ads.opctl.backend.local._download_model") as mock__download:
97+
with patch.object(LocalModelDeploymentBackend, "_get_conda_info_from_custom_metadata", return_value = ("fake_slug", "fake_path")):
98+
with patch.object(LocalModelDeploymentBackend, "_get_conda_info_from_runtime"):
99+
with patch.object(LocalModelDeploymentBackend, "_run_with_conda_pack", return_value=0) as mock__run_with_conda_pack:
100+
backend = LocalModelDeploymentBackend(self.config)
101+
backend.predict()
102+
mock__download.assert_not_called()
103+
mock__run_with_conda_pack.assert_called_once_with({os.path.expanduser('~/.oci'): {'bind': '/home/datascience/.oci'}, os.path.expanduser('~/.ads_ops/models/fake_id'): {'bind': '/opt/ds/model/deployed_model/'}}, '/opt/ds/model/deployed_model/ fake_data fake_id fake_id', install=True, conda_uri='fake_path')
104+
105+
106+
@patch("ads.opctl.backend.local.os.listdir", return_value=["path"])
107+
@patch("ads.opctl.backend.local.os.path.exists", return_value=False)
108+
def test_predict_download(self, mock_path_exists, mock_list_dir):
109+
with patch("ads.opctl.backend.local._download_model") as mock__download:
110+
with patch.object(LocalModelDeploymentBackend, "_get_conda_info_from_custom_metadata", return_value = ("fake_slug", "fake_path")):
111+
with patch.object(LocalModelDeploymentBackend, "_get_conda_info_from_runtime"):
112+
with patch.object(LocalModelDeploymentBackend, "_run_with_conda_pack", return_value=0) as mock__run_with_conda_pack:
113+
backend = LocalModelDeploymentBackend(self.config)
114+
backend.predict()
115+
mock__download.assert_called_once_with(ocid='fake_id', artifact_directory=os.path.expanduser('~/.ads_ops/models/fake_id'), region=None, bucket_uri=None, timeout=None)
116+
mock__run_with_conda_pack.assert_called_once_with({os.path.expanduser('~/.oci'): {'bind': '/home/datascience/.oci'}, os.path.expanduser('~/.ads_ops/models/fake_id'): {'bind': '/opt/ds/model/deployed_model/'}}, '/opt/ds/model/deployed_model/ fake_data fake_id fake_id', install=True, conda_uri='fake_path')
117+

0 commit comments

Comments
 (0)