2121from ads .common .oci_client import OCIClientFactory
2222from ads .model .datascience_model import DataScienceModel
2323from ads .model .model_metadata import ModelCustomMetadata
24+ from ads .model .runtime .runtime_info import RuntimeInfo
2425from ads .opctl import logger
2526from ads .opctl .backend .base import Backend
2627from ads .opctl .conda .cmds import _install
3334 DEFAULT_NOTEBOOK_SESSION_SPARK_CONF_DIR ,
3435 ML_JOB_GPU_IMAGE , ML_JOB_IMAGE )
3536from ads .opctl .distributed .cmds import load_ini , local_run
37+ from ads .opctl .model .cmds import _download_model
3638from ads .opctl .spark .cmds import (generate_core_site_properties ,
3739 generate_core_site_properties_str )
3840from ads .opctl .utils import (build_image , get_docker_client ,
3941 is_in_notebook_session , run_command ,
4042 run_container )
4143from ads .pipeline .ads_pipeline import Pipeline , PipelineStep
42- from ads .model .runtime .runtime_info import RuntimeInfo
4344
4445
4546class CondaPackNotFound (Exception ):
@@ -654,7 +655,7 @@ def predict(self) -> None:
654655 bucket_uri = self .config ["execution" ].get ("bucket_uri" , None )
655656 timeout = self .config ["execution" ].get ("timeout" , None )
656657 logger .info (f"No cached model found. Downloading the model { ocid } to { artifact_directory } . If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts." )
657- self . _download_model (ocid = ocid , artifact_directory = artifact_directory , region = region , bucket_uri = bucket_uri , timeout = timeout )
658+ _download_model (ocid = ocid , artifact_directory = artifact_directory , region = region , bucket_uri = bucket_uri , timeout = timeout )
658659
659660 if ocid :
660661 conda_slug , conda_path = self ._get_conda_info_from_catalog (ocid )
@@ -682,10 +683,7 @@ def predict(self) -> None:
682683 self .config ["execution" ]["source_folder" ] = os .path .abspath (os .path .join (dir_path , ".." ))
683684 self .config ["execution" ]["entrypoint" ] = script
684685 bind_volumes [artifact_directory ] = {"bind" : "/opt/ds/model/deployed_model/" }
685-
686- if self .config ["execution" ].get ("image" ):
687- exit_code = self ._run_with_image (bind_volumes )
688- elif self .config ["execution" ].get ("conda_slug" , conda_slug ):
686+ if self .config ["execution" ].get ("conda_slug" , conda_slug ):
689687 self .config ["execution" ]["image" ] = ML_JOB_IMAGE
690688 if not self .config ["execution" ].get ("conda_slug" ):
691689 self .config ["execution" ]["conda_slug" ] = conda_slug
@@ -700,28 +698,7 @@ def predict(self) -> None:
700698 f"`predict` did not complete successfully. Exit code: { exit_code } . "
701699 f"Run with the --debug argument to view container logs."
702700 )
703-
704- def _download_model (self , ocid , artifact_directory , region , bucket_uri , timeout ):
705- os .makedirs (artifact_directory , exist_ok = True )
706- os .chmod (artifact_directory , 777 )
707-
708- try :
709- dsc_model = DataScienceModel .from_id (ocid )
710- dsc_model .download_artifact (
711- target_dir = artifact_directory ,
712- force_overwrite = True ,
713- overwrite_existing_artifact = True ,
714- remove_existing_artifact = True ,
715- auth = self .oci_auth ,
716- region = region ,
717- timeout = timeout ,
718- bucket_uri = bucket_uri ,
719- )
720-
721- except Exception as e :
722- shutil .rmtree (artifact_directory , ignore_errors = True )
723- raise e
724-
701+
725702 def _get_conda_info_from_catalog (self , ocid ):
726703 response = self .client .get_model (ocid )
727704 custom_metadata = ModelCustomMetadata ._from_oci_metadata (response .data .custom_metadata_list )
@@ -735,22 +712,4 @@ def _get_conda_info_from_runtime(self, artifact_dir):
735712 conda_slug = runtime_info .model_deployment .inference_conda_env .inference_env_slug
736713 conda_path = runtime_info .model_deployment .inference_conda_env .inference_env_path
737714 return conda_slug , conda_path
738-
739-
740- def _run_with_image (self , bind_volumes ):
741- ocid = self .config ["execution" ].get ("ocid" )
742- data = self .config ["execution" ].get ("data" )
743- image = self .config ["execution" ].get ("image" )
744- env_vars = self .config ["execution" ]["env_vars" ]
745- # compartment_id = self.config["execution"].get("compartment_id", self.config["infrastructure"].get("compartment_id"))
746- # project_id = self.config["execution"].get("project_id", self.config["infrastructure"].get("project_id"))
747- entrypoint = self .config ["execution" ].get ("entrypoint" , None )
748- command = self .config ["execution" ].get ("command" , None )
749- if self .config ["execution" ].get ("source_folder" , None ):
750- bind_volumes .update (self ._mount_source_folder_if_exists (bind_volumes ))
751- bind_volumes .update (self .config ["execution" ]["volumes" ])
752-
753- return run_container (image , bind_volumes , env_vars , command , entrypoint )
754-
755- def _run_with_local_env (self , ):
756- pass
715+