Skip to content

Commit bf4a836

Browse files
author
Ziqun Ye
committed
reorgnize code
1 parent 5c6fa49 commit bf4a836

File tree

1 file changed

+6
-47
lines changed

1 file changed

+6
-47
lines changed

ads/opctl/backend/local.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@
2121
from ads.common.oci_client import OCIClientFactory
2222
from ads.model.datascience_model import DataScienceModel
2323
from ads.model.model_metadata import ModelCustomMetadata
24+
from ads.model.runtime.runtime_info import RuntimeInfo
2425
from ads.opctl import logger
2526
from ads.opctl.backend.base import Backend
2627
from ads.opctl.conda.cmds import _install
@@ -33,13 +34,13 @@
3334
DEFAULT_NOTEBOOK_SESSION_SPARK_CONF_DIR,
3435
ML_JOB_GPU_IMAGE, ML_JOB_IMAGE)
3536
from ads.opctl.distributed.cmds import load_ini, local_run
37+
from ads.opctl.model.cmds import _download_model
3638
from ads.opctl.spark.cmds import (generate_core_site_properties,
3739
generate_core_site_properties_str)
3840
from ads.opctl.utils import (build_image, get_docker_client,
3941
is_in_notebook_session, run_command,
4042
run_container)
4143
from ads.pipeline.ads_pipeline import Pipeline, PipelineStep
42-
from ads.model.runtime.runtime_info import RuntimeInfo
4344

4445

4546
class CondaPackNotFound(Exception):
@@ -654,7 +655,7 @@ def predict(self) -> None:
654655
bucket_uri = self.config["execution"].get("bucket_uri", None)
655656
timeout = self.config["execution"].get("timeout", None)
656657
logger.info(f"No cached model found. Downloading the model {ocid} to {artifact_directory}. If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts.")
657-
self._download_model(ocid=ocid, artifact_directory=artifact_directory, region=region, bucket_uri=bucket_uri, timeout=timeout)
658+
_download_model(ocid=ocid, artifact_directory=artifact_directory, region=region, bucket_uri=bucket_uri, timeout=timeout)
658659

659660
if ocid:
660661
conda_slug, conda_path = self._get_conda_info_from_catalog(ocid)
@@ -682,10 +683,7 @@ def predict(self) -> None:
682683
self.config["execution"]["source_folder"] = os.path.abspath(os.path.join(dir_path, ".."))
683684
self.config["execution"]["entrypoint"] = script
684685
bind_volumes[artifact_directory] = {"bind": "/opt/ds/model/deployed_model/"}
685-
686-
if self.config["execution"].get("image"):
687-
exit_code = self._run_with_image(bind_volumes)
688-
elif self.config["execution"].get("conda_slug", conda_slug):
686+
if self.config["execution"].get("conda_slug", conda_slug):
689687
self.config["execution"]["image"] = ML_JOB_IMAGE
690688
if not self.config["execution"].get("conda_slug"):
691689
self.config["execution"]["conda_slug"] = conda_slug
@@ -700,28 +698,7 @@ def predict(self) -> None:
700698
f"`predict` did not complete successfully. Exit code: {exit_code}. "
701699
f"Run with the --debug argument to view container logs."
702700
)
703-
704-
def _download_model(self, ocid, artifact_directory, region, bucket_uri, timeout):
705-
os.makedirs(artifact_directory, exist_ok=True)
706-
os.chmod(artifact_directory, 777)
707-
708-
try:
709-
dsc_model = DataScienceModel.from_id(ocid)
710-
dsc_model.download_artifact(
711-
target_dir=artifact_directory,
712-
force_overwrite=True,
713-
overwrite_existing_artifact=True,
714-
remove_existing_artifact=True,
715-
auth=self.oci_auth,
716-
region=region,
717-
timeout=timeout,
718-
bucket_uri=bucket_uri,
719-
)
720-
721-
except Exception as e :
722-
shutil.rmtree(artifact_directory, ignore_errors=True)
723-
raise e
724-
701+
725702
def _get_conda_info_from_catalog(self, ocid):
726703
response = self.client.get_model(ocid)
727704
custom_metadata = ModelCustomMetadata._from_oci_metadata(response.data.custom_metadata_list)
@@ -735,22 +712,4 @@ def _get_conda_info_from_runtime(self, artifact_dir):
735712
conda_slug = runtime_info.model_deployment.inference_conda_env.inference_env_slug
736713
conda_path = runtime_info.model_deployment.inference_conda_env.inference_env_path
737714
return conda_slug, conda_path
738-
739-
740-
def _run_with_image(self, bind_volumes):
741-
ocid = self.config["execution"].get("ocid")
742-
data = self.config["execution"].get("data")
743-
image = self.config["execution"].get("image")
744-
env_vars = self.config["execution"]["env_vars"]
745-
# compartment_id = self.config["execution"].get("compartment_id", self.config["infrastructure"].get("compartment_id"))
746-
# project_id = self.config["execution"].get("project_id", self.config["infrastructure"].get("project_id"))
747-
entrypoint = self.config["execution"].get("entrypoint", None)
748-
command = self.config["execution"].get("command", None)
749-
if self.config["execution"].get("source_folder", None):
750-
bind_volumes.update(self._mount_source_folder_if_exists(bind_volumes))
751-
bind_volumes.update(self.config["execution"]["volumes"])
752-
753-
return run_container(image, bind_volumes, env_vars, command, entrypoint)
754-
755-
def _run_with_local_env(self, ):
756-
pass
715+

0 commit comments

Comments
 (0)