Skip to content

Commit 8df3ec0

Browse files
author
Ziqun Ye
committed
add test
1 parent bf4a836 commit 8df3ec0

File tree

5 files changed

+77
-9
lines changed

5 files changed

+77
-9
lines changed

ads/opctl/backend/local.py

Lines changed: 40 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -642,9 +642,25 @@ def _log_orchestration_message(self, str: str) -> None:
642642

643643
class LocalModelDeploymentBackend(LocalBackend):
644644
def __init__(self, config: Dict) -> None:
645+
"""
646+
Initialize a LocalModelDeploymentBackend object with given config.
647+
648+
Parameters
649+
----------
650+
config: dict
651+
dictionary of configurations
652+
"""
645653
super().__init__(config)
646654

647655
def predict(self) -> None:
656+
"""
657+
Conducts local verify.
658+
659+
Returns
660+
-------
661+
None
662+
Nothing.
663+
"""
648664
artifact_directory = self.config["execution"].get("artifact_directory")
649665
ocid = self.config["execution"].get("ocid")
650666
data = self.config["execution"].get("payload")
@@ -658,8 +674,8 @@ def predict(self) -> None:
658674
_download_model(ocid=ocid, artifact_directory=artifact_directory, region=region, bucket_uri=bucket_uri, timeout=timeout)
659675

660676
if ocid:
661-
conda_slug, conda_path = self._get_conda_info_from_catalog(ocid)
662-
elif artifact_directory:
677+
conda_slug, conda_path = self._get_conda_info_from_custom_metadata(ocid)
678+
if artifact_directory or not conda_path:
663679
if not os.path.exists(artifact_directory) or len(os.listdir(artifact_directory)) == 0:
664680
raise ValueError(f"`artifact_directory` {artifact_directory} does not exist or is empty.")
665681
conda_slug, conda_path = self._get_conda_info_from_runtime(artifact_dir=artifact_directory)
@@ -699,14 +715,33 @@ def predict(self) -> None:
699715
f"Run with the --debug argument to view container logs."
700716
)
701717

702-
def _get_conda_info_from_catalog(self, ocid):
718+
def _get_conda_info_from_custom_metadata(self, ocid):
719+
"""
720+
Get conda env info from custom metadata from model catalog.
721+
722+
Returns
723+
-------
724+
(str, str)
725+
conda slug and conda path.
726+
"""
703727
response = self.client.get_model(ocid)
704728
custom_metadata = ModelCustomMetadata._from_oci_metadata(response.data.custom_metadata_list)
705-
conda_path = custom_metadata['CondaEnvironmentPath'].value
706-
conda_slug = custom_metadata['SlugName'].value
729+
conda_slug, conda_path = None, None
730+
if "CondaEnvironmentPath" in custom_metadata:
731+
conda_path = custom_metadata['CondaEnvironmentPath'].value
732+
if "SlugName" in custom_metadata:
733+
conda_slug = custom_metadata['SlugName'].value
707734
return conda_slug, conda_path
708735

709736
def _get_conda_info_from_runtime(self, artifact_dir):
737+
"""
738+
Get conda env info from runtime yaml file.
739+
740+
Returns
741+
-------
742+
(str, str)
743+
conda slug and conda path.
744+
"""
710745
runtime_yaml_file = os.path.join(artifact_dir, "runtime.yaml")
711746
runtime_info = RuntimeInfo.from_yaml(uri=runtime_yaml_file)
712747
conda_slug = runtime_info.model_deployment.inference_conda_env.inference_env_slug

ads/opctl/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,7 @@
2828
from ads.opctl.utils import suppress_traceback
2929
from ads.opctl.config.merger import ConfigMerger
3030
from ads.opctl.constants import BACKEND_NAME
31-
from ads.opctl.backend.local import DEFAULT_MODEL_FOLDER
31+
from ads.opctl.constants import DEFAULT_MODEL_FOLDER
3232

3333
import ads.opctl.conda.cli
3434
import ads.opctl.spark.cli

ads/opctl/model/cli.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@
88
from ads.common.auth import AuthType
99
from ads.opctl.utils import suppress_traceback
1010
from ads.opctl.model.cmds import download_model as download_model_cmd
11-
from ads.opctl.backend.local import DEFAULT_MODEL_FOLDER
11+
from ads.opctl.constants import DEFAULT_MODEL_FOLDER
1212

1313

1414
@click.group("model")

ads/opctl/model/cmds.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
from ads.common.auth import create_signer
55
from ads.model.datascience_model import DataScienceModel
66
from ads.opctl import logger
7-
from ads.opctl.backend.local import DEFAULT_MODEL_FOLDER
7+
from ads.opctl.constants import DEFAULT_MODEL_FOLDER
88
from ads.opctl.config.base import ConfigProcessor
99
from ads.opctl.config.merger import ConfigMerger
1010

@@ -55,4 +55,5 @@ def _download_model(ocid, artifact_directory, oci_auth, region, bucket_uri, time
5555
)
5656
except Exception as e:
5757
print(type(e))
58-
shutil.rmtree(artifact_directory, ignore_errors=True)
58+
shutil.rmtree(artifact_directory, ignore_errors=True)
59+
raise e
Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,32 @@
1+
from ads.opctl.model.cmds import _download_model, download_model
2+
import pytest
3+
from unittest.mock import ANY, call, patch
4+
from ads.model.datascience_model import DataScienceModel
5+
from unittest.mock import MagicMock, Mock
6+
from ads.opctl.model.cmds import create_signer
7+
8+
9+
@patch.object(DataScienceModel, "from_id")
10+
def test_model__download_model(mock_from_id):
11+
mock_datascience_model = MagicMock()
12+
mock_from_id.return_value = mock_datascience_model
13+
_download_model("fake_model_id", "fake_dir", "fake_auth", "region", "bucket_uri", 36, False)
14+
mock_from_id.assert_called_with("fake_model_id")
15+
mock_datascience_model.download_artifact.assert_called_with(target_dir='fake_dir', force_overwrite=False, overwrite_existing_artifact=True, remove_existing_artifact=True, auth='fake_auth', region='region', timeout=36, bucket_uri='bucket_uri')
16+
17+
18+
@patch.object(DataScienceModel, "from_id", side_effect=Exception("Fake error."))
19+
def test_model__download_model_error(mock_from_id):
20+
with pytest.raises(Exception, match="Fake error."):
21+
_download_model("fake_model_id", "fake_dir", "fake_auth", "region", "bucket_uri", 36, False)
22+
23+
24+
@patch("ads.opctl.model.cmds._download_model")
25+
@patch("ads.opctl.model.cmds.create_signer")
26+
def test_download_model(mock_create_signer, mock__download_model):
27+
auth_mock = MagicMock()
28+
mock_create_signer.return_value = auth_mock
29+
download_model(ocid = "fake_model_id")
30+
mock_create_signer.assert_called_once()
31+
mock__download_model.assert_called_once_with(ocid='fake_model_id', artifact_directory='/Users/ziye/.ads_ops/models/fake_model_id', region=None, bucket_uri=None, timeout=None, force_overwrite=False, oci_auth=auth_mock)
32+

0 commit comments

Comments
 (0)