Skip to content

Commit ded5a4d

Browse files
author
Ziqun Ye
committed
change based on the comments
1 parent 56a7009 commit ded5a4d

File tree

6 files changed

+282
-119
lines changed

6 files changed

+282
-119
lines changed

ads/opctl/backend/base.py

Lines changed: 10 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,22 +8,24 @@
88
from typing import Dict
99

1010
from ads.common.auth import create_signer
11+
from ads.common.oci_client import OCIClientFactory
1112

1213

1314
class Backend:
1415
"""Interface for backend"""
1516

1617
def __init__(self, config: Dict) -> None:
17-
1818
self.config = config
19-
self.oci_auth = create_signer(
20-
config["execution"].get("auth"),
21-
config["execution"].get("oci_config", None),
22-
config["execution"].get("oci_profile", None),
23-
)
2419
self.auth_type = config["execution"].get("auth")
2520
self.profile = config["execution"].get("oci_profile", None)
21+
self.oci_config = config["execution"].get("oci_config", None)
2622

23+
self.oci_auth = create_signer(
24+
self.auth_type,
25+
self.oci_config,
26+
self.profile,
27+
)
28+
self.client = OCIClientFactory(**self.oci_auth).data_science
2729

2830
@abstractmethod
2931
def run(self) -> Dict:
@@ -98,10 +100,10 @@ def run_diagnostics(self):
98100

99101
def predict(self) -> None:
100102
"""
101-
Deactivate a remote service.
103+
Run model predict.
102104
103105
Returns
104106
-------
105107
None
106108
"""
107-
raise NotImplementedError("`predict` has not been implemented yet.")
109+
raise NotImplementedError("`predict` has not been implemented yet.")

ads/opctl/backend/local.py

Lines changed: 115 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -15,33 +15,45 @@
1515
from oci.data_science.models import PipelineStepRun
1616

1717
from ads.common.auth import create_signer
18-
from ads.common.decorator.runtime_dependency import (OptionalDependency,
19-
runtime_dependency)
20-
from ads.common.oci_client import OCIClientFactory
18+
from ads.common.decorator.runtime_dependency import (
19+
OptionalDependency,
20+
runtime_dependency,
21+
)
22+
2123
from ads.model.model_metadata import ModelCustomMetadata
2224
from ads.model.runtime.runtime_info import RuntimeInfo
2325
from ads.opctl import logger
2426
from ads.opctl.backend.base import Backend
2527
from ads.opctl.conda.cmds import _install
2628
from ads.opctl.config.resolver import ConfigResolver
27-
from ads.opctl.constants import (DEFAULT_IMAGE_CONDA_DIR,
28-
DEFAULT_IMAGE_HOME_DIR,
29-
DEFAULT_IMAGE_SCRIPT_DIR,
30-
DEFAULT_MODEL_FOLDER,
31-
DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
32-
DEFAULT_NOTEBOOK_SESSION_SPARK_CONF_DIR,
33-
ML_JOB_GPU_IMAGE, ML_JOB_IMAGE)
29+
from ads.opctl.constants import (
30+
DEFAULT_IMAGE_CONDA_DIR,
31+
DEFAULT_IMAGE_HOME_DIR,
32+
DEFAULT_IMAGE_SCRIPT_DIR,
33+
DEFAULT_MODEL_FOLDER,
34+
DEFAULT_NOTEBOOK_SESSION_CONDA_DIR,
35+
DEFAULT_NOTEBOOK_SESSION_SPARK_CONF_DIR,
36+
ML_JOB_GPU_IMAGE,
37+
ML_JOB_IMAGE,
38+
DEFAULT_MODEL_DEPLOYMENT_FOLDER,
39+
)
3440
from ads.opctl.distributed.cmds import load_ini, local_run
3541
from ads.opctl.model.cmds import _download_model
36-
from ads.opctl.spark.cmds import (generate_core_site_properties,
37-
generate_core_site_properties_str)
38-
from ads.opctl.utils import (build_image, get_docker_client,
39-
is_in_notebook_session, run_command,
40-
run_container)
42+
from ads.opctl.spark.cmds import (
43+
generate_core_site_properties,
44+
generate_core_site_properties_str,
45+
)
46+
from ads.opctl.utils import (
47+
build_image,
48+
get_docker_client,
49+
is_in_notebook_session,
50+
run_command,
51+
run_container,
52+
)
4153
from ads.pipeline.ads_pipeline import Pipeline, PipelineStep
4254

4355

44-
class CondaPackNotFound(Exception): # pragma: no cover
56+
class CondaPackNotFound(Exception): # pragma: no cover
4557
pass
4658

4759

@@ -55,17 +67,7 @@ def __init__(self, config: Dict) -> None:
5567
config: dict
5668
dictionary of configurations
5769
"""
58-
self.config = config
59-
self.auth_type = config["execution"].get("auth")
60-
self.profile = config["execution"].get("oci_profile", None)
61-
self.oci_config = config["execution"].get("oci_config", None)
62-
63-
self.oci_auth = create_signer(
64-
self.auth_type,
65-
self.oci_config,
66-
self.profile ,
67-
)
68-
self.client = OCIClientFactory(**self.oci_auth).data_science
70+
super().__init__(config=config)
6971

7072
def run(self):
7173
if self.config.get("version") == "v1.0":
@@ -190,7 +192,13 @@ def init_vscode_container(self) -> None:
190192
f.write(json.dumps(dev_container, indent=2))
191193
print(f"File {os.path.abspath('.devcontainer.json')} created.")
192194

193-
def _run_with_conda_pack(self, bind_volumes: Dict, extra_cmd: str="", install: bool=False, conda_uri: str="") -> int:
195+
def _run_with_conda_pack(
196+
self,
197+
bind_volumes: Dict,
198+
extra_cmd: str = "",
199+
install: bool = False,
200+
conda_uri: str = "",
201+
) -> int:
194202
env_vars = self.config["execution"].get("env_vars", {})
195203
slug = self.config["execution"]["conda_slug"]
196204
image = self.config["execution"].get("image", None)
@@ -214,7 +222,7 @@ def _run_with_conda_pack(self, bind_volumes: Dict, extra_cmd: str="", install: b
214222
image, slug, command, bind_volumes, env_vars
215223
)
216224

217-
def _build_command_for_conda_run(self, extra_cmd: str="") -> str:
225+
def _build_command_for_conda_run(self, extra_cmd: str = "") -> str:
218226
if ConfigResolver(self.config)._is_ads_operator():
219227
if is_in_notebook_session():
220228
curr_dir = os.path.dirname(os.path.abspath(__file__))
@@ -285,7 +293,7 @@ def _run_with_image(self, bind_volumes: Dict) -> int:
285293
if self.config["execution"].get("source_folder", None):
286294
bind_volumes.update(self._mount_source_folder_if_exists(bind_volumes))
287295
bind_volumes.update(self.config["execution"]["volumes"])
288-
296+
289297
return run_container(image, bind_volumes, env_vars, command, entrypoint)
290298

291299
def _run_with_image_v1(self, bind_volumes: Dict) -> int:
@@ -308,15 +316,22 @@ def _run_with_image_v1(self, bind_volumes: Dict) -> int:
308316
)
309317

310318
def _check_conda_pack_and_install_if_applicable(
311-
self, slug: str, bind_volumes: Dict, env_vars: Dict, install: bool=False, conda_uri: str = None
319+
self,
320+
slug: str,
321+
bind_volumes: Dict,
322+
env_vars: Dict,
323+
install: bool = False,
324+
conda_uri: str = None,
312325
) -> Dict:
313-
conda_pack_folder = os.path.abspath(os.path.expanduser(self.config['execution']["conda_pack_folder"]))
314-
conda_pack_path = os.path.join(
315-
conda_pack_folder, slug
326+
conda_pack_folder = os.path.abspath(
327+
os.path.expanduser(self.config["execution"]["conda_pack_folder"])
316328
)
329+
conda_pack_path = os.path.join(conda_pack_folder, slug)
317330
if not os.path.exists(conda_pack_path):
318331
if install:
319-
logger.info(f"Downloading the conda pack {slug} to this conda pack {conda_pack_folder}. If this conda pack is already installed locally in a different location, pass in `conda_pack_folder` to avoid downloading it again.")
332+
logger.info(
333+
f"Downloading a `{slug}` to the `{conda_pack_folder}`. If this conda pack is already installed locally in a different location, pass in `conda_pack_folder` to avoid downloading it again."
334+
)
320335
_install(
321336
conda_uri=conda_uri,
322337
conda_pack_folder=conda_pack_folder,
@@ -648,7 +663,7 @@ def __init__(self, config: Dict) -> None:
648663
dictionary of configurations
649664
"""
650665
super().__init__(config)
651-
666+
652667
def predict(self) -> None:
653668
"""
654669
Conducts local verify.
@@ -661,29 +676,63 @@ def predict(self) -> None:
661676
artifact_directory = self.config["execution"].get("artifact_directory")
662677
ocid = self.config["execution"].get("ocid")
663678
data = self.config["execution"].get("payload")
664-
model_folder = os.path.expanduser(self.config["execution"].get("model_save_folder", DEFAULT_MODEL_FOLDER))
679+
model_folder = os.path.expanduser(
680+
self.config["execution"].get("model_save_folder", DEFAULT_MODEL_FOLDER)
681+
)
665682
artifact_directory = artifact_directory or os.path.join(model_folder, str(ocid))
666-
if ocid and (not os.path.exists(artifact_directory) or len(os.listdir(artifact_directory)) == 0):
683+
if ocid and (
684+
not os.path.exists(artifact_directory)
685+
or len(os.listdir(artifact_directory)) == 0
686+
):
667687
region = self.config["execution"].get("region", None)
668688
bucket_uri = self.config["execution"].get("bucket_uri", None)
669689
timeout = self.config["execution"].get("timeout", None)
670-
logger.info(f"No cached model found. Downloading the model {ocid} to {artifact_directory}. If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts.")
671-
672-
_download_model(ocid=ocid, artifact_directory=artifact_directory, region=region, bucket_uri=bucket_uri, timeout=timeout)
690+
logger.info(
691+
f"No cached model found. Downloading the model {ocid} to {artifact_directory}. If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts."
692+
)
673693

694+
_download_model(
695+
oci_auth=self.oci_auth,
696+
ocid=ocid,
697+
artifact_directory=artifact_directory,
698+
region=region,
699+
bucket_uri=bucket_uri,
700+
timeout=timeout,
701+
force_overwrite=True,
702+
)
703+
conda_slug, conda_path = None, None
674704
if ocid:
675705
conda_slug, conda_path = self._get_conda_info_from_custom_metadata(ocid)
676706
if not conda_path:
677-
if not os.path.exists(artifact_directory) or len(os.listdir(artifact_directory)) == 0:
678-
raise ValueError(f"`artifact_directory` {artifact_directory} does not exist or is empty.")
679-
conda_slug, conda_path = self._get_conda_info_from_runtime(artifact_dir=artifact_directory)
707+
if (
708+
not os.path.exists(artifact_directory)
709+
or len(os.listdir(artifact_directory)) == 0
710+
):
711+
raise ValueError(
712+
f"`artifact_directory` {artifact_directory} does not exist or is empty."
713+
)
714+
conda_slug, conda_path = self._get_conda_info_from_runtime(
715+
artifact_dir=artifact_directory
716+
)
680717
if not conda_path or not conda_slug:
681718
raise ValueError("Conda information cannot be detected.")
682-
compartment_id = self.config["execution"].get("compartment_id", self.config["infrastructure"].get("compartment_id"))
683-
project_id = self.config["execution"].get("project_id", self.config["infrastructure"].get("project_id"))
719+
compartment_id = self.config["execution"].get(
720+
"compartment_id", self.config["infrastructure"].get("compartment_id")
721+
)
722+
project_id = self.config["execution"].get(
723+
"project_id", self.config["infrastructure"].get("project_id")
724+
)
684725
if not compartment_id or not project_id:
685726
raise ValueError("`compartment_id` and `project_id` must be provided.")
686-
extra_cmd = "/opt/ds/model/deployed_model/ " + data + " " + compartment_id + " " + project_id
727+
extra_cmd = (
728+
DEFAULT_MODEL_DEPLOYMENT_FOLDER
729+
+ " "
730+
+ data
731+
+ " "
732+
+ compartment_id
733+
+ " "
734+
+ project_id
735+
)
687736
bind_volumes = {}
688737
if not is_in_notebook_session():
689738
bind_volumes = {
@@ -693,16 +742,20 @@ def predict(self) -> None:
693742
}
694743
dir_path = os.path.dirname(os.path.realpath(__file__))
695744
script = "script.py"
696-
self.config["execution"]["source_folder"] = os.path.abspath(os.path.join(dir_path, ".."))
745+
self.config["execution"]["source_folder"] = os.path.abspath(
746+
os.path.join(dir_path, "..")
747+
)
697748
self.config["execution"]["entrypoint"] = script
698-
bind_volumes[artifact_directory] = {"bind": "/opt/ds/model/deployed_model/"}
749+
bind_volumes[artifact_directory] = {"bind": DEFAULT_MODEL_DEPLOYMENT_FOLDER}
699750
if self.config["execution"].get("conda_slug", conda_slug):
700751
self.config["execution"]["image"] = ML_JOB_IMAGE
701752
if not self.config["execution"].get("conda_slug"):
702753
self.config["execution"]["conda_slug"] = conda_slug
703754
self.config["execution"]["slug"] = conda_slug
704755
self.config["execution"]["conda_path"] = conda_path
705-
exit_code = self._run_with_conda_pack(bind_volumes, extra_cmd, install=True, conda_uri=conda_path)
756+
exit_code = self._run_with_conda_pack(
757+
bind_volumes, extra_cmd, install=True, conda_uri=conda_path
758+
)
706759
else:
707760
raise ValueError("Either conda pack info or image should be specified.")
708761

@@ -722,14 +775,16 @@ def _get_conda_info_from_custom_metadata(self, ocid):
722775
conda slug and conda path.
723776
"""
724777
response = self.client.get_model(ocid)
725-
custom_metadata = ModelCustomMetadata._from_oci_metadata(response.data.custom_metadata_list)
778+
custom_metadata = ModelCustomMetadata._from_oci_metadata(
779+
response.data.custom_metadata_list
780+
)
726781
conda_slug, conda_path = None, None
727782
if "CondaEnvironmentPath" in custom_metadata.keys:
728-
conda_path = custom_metadata['CondaEnvironmentPath'].value
783+
conda_path = custom_metadata["CondaEnvironmentPath"].value
729784
if "SlugName" in custom_metadata.keys:
730-
conda_slug = custom_metadata['SlugName'].value
785+
conda_slug = custom_metadata["SlugName"].value
731786
return conda_slug, conda_path
732-
787+
733788
@staticmethod
734789
def _get_conda_info_from_runtime(artifact_dir):
735790
"""
@@ -742,7 +797,10 @@ def _get_conda_info_from_runtime(artifact_dir):
742797
"""
743798
runtime_yaml_file = os.path.join(artifact_dir, "runtime.yaml")
744799
runtime_info = RuntimeInfo.from_yaml(uri=runtime_yaml_file)
745-
conda_slug = runtime_info.model_deployment.inference_conda_env.inference_env_slug
746-
conda_path = runtime_info.model_deployment.inference_conda_env.inference_env_path
800+
conda_slug = (
801+
runtime_info.model_deployment.inference_conda_env.inference_env_slug
802+
)
803+
conda_path = (
804+
runtime_info.model_deployment.inference_conda_env.inference_env_path
805+
)
747806
return conda_slug, conda_path
748-

0 commit comments

Comments
 (0)