1515from oci .data_science .models import PipelineStepRun
1616
1717from ads .common .auth import create_signer
18- from ads .common .decorator .runtime_dependency import (OptionalDependency ,
19- runtime_dependency )
20- from ads .common .oci_client import OCIClientFactory
18+ from ads .common .decorator .runtime_dependency import (
19+ OptionalDependency ,
20+ runtime_dependency ,
21+ )
22+
2123from ads .model .model_metadata import ModelCustomMetadata
2224from ads .model .runtime .runtime_info import RuntimeInfo
2325from ads .opctl import logger
2426from ads .opctl .backend .base import Backend
2527from ads .opctl .conda .cmds import _install
2628from ads .opctl .config .resolver import ConfigResolver
27- from ads .opctl .constants import (DEFAULT_IMAGE_CONDA_DIR ,
28- DEFAULT_IMAGE_HOME_DIR ,
29- DEFAULT_IMAGE_SCRIPT_DIR ,
30- DEFAULT_MODEL_FOLDER ,
31- DEFAULT_NOTEBOOK_SESSION_CONDA_DIR ,
32- DEFAULT_NOTEBOOK_SESSION_SPARK_CONF_DIR ,
33- ML_JOB_GPU_IMAGE , ML_JOB_IMAGE )
29+ from ads .opctl .constants import (
30+ DEFAULT_IMAGE_CONDA_DIR ,
31+ DEFAULT_IMAGE_HOME_DIR ,
32+ DEFAULT_IMAGE_SCRIPT_DIR ,
33+ DEFAULT_MODEL_FOLDER ,
34+ DEFAULT_NOTEBOOK_SESSION_CONDA_DIR ,
35+ DEFAULT_NOTEBOOK_SESSION_SPARK_CONF_DIR ,
36+ ML_JOB_GPU_IMAGE ,
37+ ML_JOB_IMAGE ,
38+ DEFAULT_MODEL_DEPLOYMENT_FOLDER ,
39+ )
3440from ads .opctl .distributed .cmds import load_ini , local_run
3541from ads .opctl .model .cmds import _download_model
36- from ads .opctl .spark .cmds import (generate_core_site_properties ,
37- generate_core_site_properties_str )
38- from ads .opctl .utils import (build_image , get_docker_client ,
39- is_in_notebook_session , run_command ,
40- run_container )
42+ from ads .opctl .spark .cmds import (
43+ generate_core_site_properties ,
44+ generate_core_site_properties_str ,
45+ )
46+ from ads .opctl .utils import (
47+ build_image ,
48+ get_docker_client ,
49+ is_in_notebook_session ,
50+ run_command ,
51+ run_container ,
52+ )
4153from ads .pipeline .ads_pipeline import Pipeline , PipelineStep
4254
4355
44- class CondaPackNotFound (Exception ): # pragma: no cover
56+ class CondaPackNotFound (Exception ): # pragma: no cover
4557 pass
4658
4759
@@ -55,17 +67,7 @@ def __init__(self, config: Dict) -> None:
5567 config: dict
5668 dictionary of configurations
5769 """
58- self .config = config
59- self .auth_type = config ["execution" ].get ("auth" )
60- self .profile = config ["execution" ].get ("oci_profile" , None )
61- self .oci_config = config ["execution" ].get ("oci_config" , None )
62-
63- self .oci_auth = create_signer (
64- self .auth_type ,
65- self .oci_config ,
66- self .profile ,
67- )
68- self .client = OCIClientFactory (** self .oci_auth ).data_science
70+ super ().__init__ (config = config )
6971
7072 def run (self ):
7173 if self .config .get ("version" ) == "v1.0" :
@@ -190,7 +192,13 @@ def init_vscode_container(self) -> None:
190192 f .write (json .dumps (dev_container , indent = 2 ))
191193 print (f"File { os .path .abspath ('.devcontainer.json' )} created." )
192194
193- def _run_with_conda_pack (self , bind_volumes : Dict , extra_cmd : str = "" , install : bool = False , conda_uri : str = "" ) -> int :
195+ def _run_with_conda_pack (
196+ self ,
197+ bind_volumes : Dict ,
198+ extra_cmd : str = "" ,
199+ install : bool = False ,
200+ conda_uri : str = "" ,
201+ ) -> int :
194202 env_vars = self .config ["execution" ].get ("env_vars" , {})
195203 slug = self .config ["execution" ]["conda_slug" ]
196204 image = self .config ["execution" ].get ("image" , None )
@@ -214,7 +222,7 @@ def _run_with_conda_pack(self, bind_volumes: Dict, extra_cmd: str="", install: b
214222 image , slug , command , bind_volumes , env_vars
215223 )
216224
217- def _build_command_for_conda_run (self , extra_cmd : str = "" ) -> str :
225+ def _build_command_for_conda_run (self , extra_cmd : str = "" ) -> str :
218226 if ConfigResolver (self .config )._is_ads_operator ():
219227 if is_in_notebook_session ():
220228 curr_dir = os .path .dirname (os .path .abspath (__file__ ))
@@ -285,7 +293,7 @@ def _run_with_image(self, bind_volumes: Dict) -> int:
285293 if self .config ["execution" ].get ("source_folder" , None ):
286294 bind_volumes .update (self ._mount_source_folder_if_exists (bind_volumes ))
287295 bind_volumes .update (self .config ["execution" ]["volumes" ])
288-
296+
289297 return run_container (image , bind_volumes , env_vars , command , entrypoint )
290298
291299 def _run_with_image_v1 (self , bind_volumes : Dict ) -> int :
@@ -308,15 +316,22 @@ def _run_with_image_v1(self, bind_volumes: Dict) -> int:
308316 )
309317
310318 def _check_conda_pack_and_install_if_applicable (
311- self , slug : str , bind_volumes : Dict , env_vars : Dict , install : bool = False , conda_uri : str = None
319+ self ,
320+ slug : str ,
321+ bind_volumes : Dict ,
322+ env_vars : Dict ,
323+ install : bool = False ,
324+ conda_uri : str = None ,
312325 ) -> Dict :
313- conda_pack_folder = os .path .abspath (os .path .expanduser (self .config ['execution' ]["conda_pack_folder" ]))
314- conda_pack_path = os .path .join (
315- conda_pack_folder , slug
326+ conda_pack_folder = os .path .abspath (
327+ os .path .expanduser (self .config ["execution" ]["conda_pack_folder" ])
316328 )
329+ conda_pack_path = os .path .join (conda_pack_folder , slug )
317330 if not os .path .exists (conda_pack_path ):
318331 if install :
319- logger .info (f"Downloading the conda pack { slug } to this conda pack { conda_pack_folder } . If this conda pack is already installed locally in a different location, pass in `conda_pack_folder` to avoid downloading it again." )
332+ logger .info (
333+ f"Downloading a `{ slug } ` to the `{ conda_pack_folder } `. If this conda pack is already installed locally in a different location, pass in `conda_pack_folder` to avoid downloading it again."
334+ )
320335 _install (
321336 conda_uri = conda_uri ,
322337 conda_pack_folder = conda_pack_folder ,
@@ -648,7 +663,7 @@ def __init__(self, config: Dict) -> None:
648663 dictionary of configurations
649664 """
650665 super ().__init__ (config )
651-
666+
652667 def predict (self ) -> None :
653668 """
654669 Conducts local verify.
@@ -661,29 +676,63 @@ def predict(self) -> None:
661676 artifact_directory = self .config ["execution" ].get ("artifact_directory" )
662677 ocid = self .config ["execution" ].get ("ocid" )
663678 data = self .config ["execution" ].get ("payload" )
664- model_folder = os .path .expanduser (self .config ["execution" ].get ("model_save_folder" , DEFAULT_MODEL_FOLDER ))
679+ model_folder = os .path .expanduser (
680+ self .config ["execution" ].get ("model_save_folder" , DEFAULT_MODEL_FOLDER )
681+ )
665682 artifact_directory = artifact_directory or os .path .join (model_folder , str (ocid ))
666- if ocid and (not os .path .exists (artifact_directory ) or len (os .listdir (artifact_directory )) == 0 ):
683+ if ocid and (
684+ not os .path .exists (artifact_directory )
685+ or len (os .listdir (artifact_directory )) == 0
686+ ):
667687 region = self .config ["execution" ].get ("region" , None )
668688 bucket_uri = self .config ["execution" ].get ("bucket_uri" , None )
669689 timeout = self .config ["execution" ].get ("timeout" , None )
670- logger .info (f"No cached model found. Downloading the model { ocid } to { artifact_directory } . If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts." )
671-
672- _download_model ( ocid = ocid , artifact_directory = artifact_directory , region = region , bucket_uri = bucket_uri , timeout = timeout )
690+ logger .info (
691+ f"No cached model found. Downloading the model { ocid } to { artifact_directory } . If you already have a copy of the model, specify `artifact_directory` instead of `ocid`. You can specify `model_save_folder` to decide where to store the model artifacts."
692+ )
673693
694+ _download_model (
695+ oci_auth = self .oci_auth ,
696+ ocid = ocid ,
697+ artifact_directory = artifact_directory ,
698+ region = region ,
699+ bucket_uri = bucket_uri ,
700+ timeout = timeout ,
701+ force_overwrite = True ,
702+ )
703+ conda_slug , conda_path = None , None
674704 if ocid :
675705 conda_slug , conda_path = self ._get_conda_info_from_custom_metadata (ocid )
676706 if not conda_path :
677- if not os .path .exists (artifact_directory ) or len (os .listdir (artifact_directory )) == 0 :
678- raise ValueError (f"`artifact_directory` { artifact_directory } does not exist or is empty." )
679- conda_slug , conda_path = self ._get_conda_info_from_runtime (artifact_dir = artifact_directory )
707+ if (
708+ not os .path .exists (artifact_directory )
709+ or len (os .listdir (artifact_directory )) == 0
710+ ):
711+ raise ValueError (
712+ f"`artifact_directory` { artifact_directory } does not exist or is empty."
713+ )
714+ conda_slug , conda_path = self ._get_conda_info_from_runtime (
715+ artifact_dir = artifact_directory
716+ )
680717 if not conda_path or not conda_slug :
681718 raise ValueError ("Conda information cannot be detected." )
682- compartment_id = self .config ["execution" ].get ("compartment_id" , self .config ["infrastructure" ].get ("compartment_id" ))
683- project_id = self .config ["execution" ].get ("project_id" , self .config ["infrastructure" ].get ("project_id" ))
719+ compartment_id = self .config ["execution" ].get (
720+ "compartment_id" , self .config ["infrastructure" ].get ("compartment_id" )
721+ )
722+ project_id = self .config ["execution" ].get (
723+ "project_id" , self .config ["infrastructure" ].get ("project_id" )
724+ )
684725 if not compartment_id or not project_id :
685726 raise ValueError ("`compartment_id` and `project_id` must be provided." )
686- extra_cmd = "/opt/ds/model/deployed_model/ " + data + " " + compartment_id + " " + project_id
727+ extra_cmd = (
728+ DEFAULT_MODEL_DEPLOYMENT_FOLDER
729+ + " "
730+ + data
731+ + " "
732+ + compartment_id
733+ + " "
734+ + project_id
735+ )
687736 bind_volumes = {}
688737 if not is_in_notebook_session ():
689738 bind_volumes = {
@@ -693,16 +742,20 @@ def predict(self) -> None:
693742 }
694743 dir_path = os .path .dirname (os .path .realpath (__file__ ))
695744 script = "script.py"
696- self .config ["execution" ]["source_folder" ] = os .path .abspath (os .path .join (dir_path , ".." ))
745+ self .config ["execution" ]["source_folder" ] = os .path .abspath (
746+ os .path .join (dir_path , ".." )
747+ )
697748 self .config ["execution" ]["entrypoint" ] = script
698- bind_volumes [artifact_directory ] = {"bind" : "/opt/ds/model/deployed_model/" }
749+ bind_volumes [artifact_directory ] = {"bind" : DEFAULT_MODEL_DEPLOYMENT_FOLDER }
699750 if self .config ["execution" ].get ("conda_slug" , conda_slug ):
700751 self .config ["execution" ]["image" ] = ML_JOB_IMAGE
701752 if not self .config ["execution" ].get ("conda_slug" ):
702753 self .config ["execution" ]["conda_slug" ] = conda_slug
703754 self .config ["execution" ]["slug" ] = conda_slug
704755 self .config ["execution" ]["conda_path" ] = conda_path
705- exit_code = self ._run_with_conda_pack (bind_volumes , extra_cmd , install = True , conda_uri = conda_path )
756+ exit_code = self ._run_with_conda_pack (
757+ bind_volumes , extra_cmd , install = True , conda_uri = conda_path
758+ )
706759 else :
707760 raise ValueError ("Either conda pack info or image should be specified." )
708761
@@ -722,14 +775,16 @@ def _get_conda_info_from_custom_metadata(self, ocid):
722775 conda slug and conda path.
723776 """
724777 response = self .client .get_model (ocid )
725- custom_metadata = ModelCustomMetadata ._from_oci_metadata (response .data .custom_metadata_list )
778+ custom_metadata = ModelCustomMetadata ._from_oci_metadata (
779+ response .data .custom_metadata_list
780+ )
726781 conda_slug , conda_path = None , None
727782 if "CondaEnvironmentPath" in custom_metadata .keys :
728- conda_path = custom_metadata [' CondaEnvironmentPath' ].value
783+ conda_path = custom_metadata [" CondaEnvironmentPath" ].value
729784 if "SlugName" in custom_metadata .keys :
730- conda_slug = custom_metadata [' SlugName' ].value
785+ conda_slug = custom_metadata [" SlugName" ].value
731786 return conda_slug , conda_path
732-
787+
733788 @staticmethod
734789 def _get_conda_info_from_runtime (artifact_dir ):
735790 """
@@ -742,7 +797,10 @@ def _get_conda_info_from_runtime(artifact_dir):
742797 """
743798 runtime_yaml_file = os .path .join (artifact_dir , "runtime.yaml" )
744799 runtime_info = RuntimeInfo .from_yaml (uri = runtime_yaml_file )
745- conda_slug = runtime_info .model_deployment .inference_conda_env .inference_env_slug
746- conda_path = runtime_info .model_deployment .inference_conda_env .inference_env_path
800+ conda_slug = (
801+ runtime_info .model_deployment .inference_conda_env .inference_env_slug
802+ )
803+ conda_path = (
804+ runtime_info .model_deployment .inference_conda_env .inference_env_path
805+ )
747806 return conda_slug , conda_path
748-
0 commit comments