2929 GitPythonRuntime ,
3030)
3131from ads .jobs .builders .runtimes .container_runtime import ContainerRuntime
32+ from ads .jobs .builders .runtimes .pytorch_runtime import (
33+ PyTorchDistributedRuntime ,
34+ PyTorchDistributedArtifact ,
35+ )
3236from ads .jobs .builders .runtimes .artifact import (
3337 ScriptArtifact ,
3438 NotebookArtifact ,
3539 PythonArtifact ,
3640 GitPythonArtifact ,
3741)
42+ from ads .opctl .distributed .common import cluster_config_helper
3843from ads .jobs .builders .infrastructure .utils import get_value
44+ from ads .jobs .templates import driver_utils
3945
4046
4147class IncompatibleRuntime (Exception ): # pragma: no cover
@@ -184,7 +190,7 @@ def _translate_config(self, runtime: Runtime) -> dict:
184190 if runtime .args :
185191 # shlex.join() is not available until python 3.8
186192 job_configuration_details ["command_line_arguments" ] = " " .join (
187- shlex .quote (arg ) for arg in runtime .get_spec (runtime .CONST_ARGS )
193+ shlex .quote (str ( arg ) ) for arg in runtime .get_spec (runtime .CONST_ARGS )
188194 )
189195 return job_configuration_details
190196
@@ -653,7 +659,7 @@ def _translate_env(self, runtime: PythonRuntime) -> dict:
653659
654660 if runtime .entrypoint :
655661 envs [self .CONST_CODE_ENTRYPOINT ] = runtime .entrypoint
656- else :
662+ elif runtime . script_uri :
657663 envs [self .CONST_CODE_ENTRYPOINT ] = os .path .basename (runtime .script_uri )
658664
659665 envs [self .CONST_JOB_ENTRYPOINT ] = PythonArtifact .CONST_DRIVER_SCRIPT
@@ -674,9 +680,13 @@ def _extract_envs(self, dsc_job) -> dict:
674680 """
675681 spec = super ()._extract_envs (dsc_job )
676682 envs = spec .pop (PythonRuntime .CONST_ENV_VAR , {})
677- if self .CONST_CODE_ENTRYPOINT not in envs :
683+ if (
684+ self .__class__ == PythonRuntimeHandler
685+ and self .CONST_CODE_ENTRYPOINT not in envs
686+ ):
678687 raise IncompatibleRuntime ()
679- envs .pop (PythonRuntimeHandler .CONST_JOB_ENTRYPOINT )
688+ # PyTorchDistributedRuntime does not require entrypoint.
689+ envs .pop (PythonRuntimeHandler .CONST_JOB_ENTRYPOINT , None )
680690 spec .update (self ._extract_specs (envs , self .SPEC_MAPPINGS ))
681691 if PythonRuntime .CONST_PYTHON_PATH in spec :
682692 spec [PythonRuntime .CONST_PYTHON_PATH ] = spec [
@@ -1035,6 +1045,98 @@ def _extract_envs(self, dsc_job):
10351045 return spec
10361046
10371047
class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
    """Runtime handler for ``PyTorchDistributedRuntime``.

    Translates the runtime into the artifact and environment variables of a
    job, and rebuilds the runtime specification from those environment
    variables.
    """

    RUNTIME_CLASS = PyTorchDistributedRuntime
    CONST_WORKER_COUNT = "OCI__WORKER_COUNT"
    CONST_COMMAND = "OCI__LAUNCH_CMD"
    CONST_DEEPSPEED = "OCI__DEEPSPEED"

    # Environment variable name -> key of the git specification.
    GIT_SPEC_MAPPINGS = {
        cluster_config_helper.OCI__RUNTIME_URI: GitPythonRuntime.CONST_GIT_URL,
        cluster_config_helper.OCI__RUNTIME_GIT_BRANCH: GitPythonRuntime.CONST_BRANCH,
        cluster_config_helper.OCI__RUNTIME_GIT_COMMIT: GitPythonRuntime.CONST_COMMIT,
        cluster_config_helper.OCI__RUNTIME_GIT_SECRET_ID: GitPythonRuntime.CONST_GIT_SSH_SECRET_ID,
    }

    # Build a NEW dict instead of aliasing and updating the parent's mapping:
    # updating ``PythonRuntimeHandler.SPEC_MAPPINGS`` in place would add the
    # launch-command entry to the parent class (and to every other class
    # sharing that dict object) as an import-time side effect.
    SPEC_MAPPINGS = {
        **PythonRuntimeHandler.SPEC_MAPPINGS,
        PyTorchDistributedRuntime.CONST_COMMAND: CONST_COMMAND,
    }

    def _translate_artifact(self, runtime: PyTorchDistributedRuntime):
        """Return the ``PyTorchDistributedArtifact`` built from the runtime's ``source_uri``."""
        return PyTorchDistributedArtifact(runtime.source_uri, runtime)

    def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict:
        """Translate the runtime into job environment variables.

        Starts from the environment variables produced by the parent handler
        and adds the worker count, the driver script entrypoint, and, when
        set on the runtime, the input mappings, git settings, pip
        dependencies and the DeepSpeed flag.

        Parameters
        ----------
        runtime : PyTorchDistributedRuntime
            The runtime to be translated.

        Returns
        -------
        dict
            Environment variables for the job.
        """
        envs = super()._translate_env(runtime)
        # A missing (or zero) replica setting is treated as a single replica.
        replica = runtime.replica or 1
        # WORKER_COUNT = REPLICA - 1 so that it will be same as distributed training
        envs[self.CONST_WORKER_COUNT] = str(replica - 1)
        envs[self.CONST_JOB_ENTRYPOINT] = PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT
        if runtime.inputs:
            envs[driver_utils.CONST_ENV_INPUT_MAPPINGS] = json.dumps(runtime.inputs)
        if runtime.git:
            for env_key, spec_key in self.GIT_SPEC_MAPPINGS.items():
                value = runtime.git.get(spec_key)
                # Only export git settings that are actually set.
                if value:
                    envs[env_key] = value
        if runtime.dependencies:
            dependency_envs = (
                (PyTorchDistributedRuntime.CONST_PIP_PKG, driver_utils.CONST_ENV_PIP_PKG),
                (PyTorchDistributedRuntime.CONST_PIP_REQ, driver_utils.CONST_ENV_PIP_REQ),
            )
            for spec_key, env_key in dependency_envs:
                if spec_key in runtime.dependencies:
                    envs[env_key] = runtime.dependencies[spec_key]
        if runtime.use_deepspeed:
            envs[self.CONST_DEEPSPEED] = "1"
        return envs

    def _extract_envs(self, dsc_job) -> dict:
        """Rebuild the runtime specification from the job's environment variables.

        Parameters
        ----------
        dsc_job
            The job object, passed through to the parent handler's
            ``_extract_envs()``.

        Returns
        -------
        dict
            Runtime specification. Environment variables that were not
            consumed by this handler are kept under
            ``PythonRuntime.CONST_ENV_VAR``.

        Raises
        ------
        IncompatibleRuntime
            If the worker count environment variable is not present.
        """
        spec = super()._extract_envs(dsc_job)
        envs = spec.pop(PythonRuntime.CONST_ENV_VAR, {})
        # The worker count variable is what identifies a job of this runtime.
        if self.CONST_WORKER_COUNT not in envs:
            raise IncompatibleRuntime()
        # Replicas (inverse of the translation in _translate_env()).
        spec[PyTorchDistributedRuntime.CONST_REPLICA] = (
            int(envs.pop(self.CONST_WORKER_COUNT)) + 1
        )
        # Git settings are extracted only when the runtime URI variable is present.
        if cluster_config_helper.OCI__RUNTIME_URI in envs:
            spec[PyTorchDistributedRuntime.CONST_GIT] = {
                spec_key: envs.pop(env_key)
                for env_key, spec_key in self.GIT_SPEC_MAPPINGS.items()
                if env_key in envs
            }
        # Inputs
        input_mappings = envs.pop(driver_utils.CONST_ENV_INPUT_MAPPINGS, None)
        if input_mappings:
            try:
                spec[PyTorchDistributedRuntime.CONST_INPUT] = json.loads(input_mappings)
            except ValueError:
                # Keep the raw value when it is not valid JSON.
                spec[PyTorchDistributedRuntime.CONST_INPUT] = input_mappings
        # Dependencies
        dependency_envs = (
            (PyTorchDistributedRuntime.CONST_PIP_PKG, driver_utils.CONST_ENV_PIP_PKG),
            (PyTorchDistributedRuntime.CONST_PIP_REQ, driver_utils.CONST_ENV_PIP_REQ),
        )
        dependencies = {
            spec_key: envs.pop(env_key)
            for spec_key, env_key in dependency_envs
            if env_key in envs
        }
        if dependencies:
            spec[PyTorchDistributedRuntime.CONST_DEP] = dependencies
        # DeepSpeed: any non-empty value enables the flag.
        if envs.pop(self.CONST_DEEPSPEED, None):
            spec[PyTorchDistributedRuntime.CONST_DEEPSPEED] = True
        # Remaining environment variables are kept as-is.
        if envs:
            spec[PythonRuntime.CONST_ENV_VAR] = envs
        return spec
1139+
10381140class DataScienceJobRuntimeManager (RuntimeHandler ):
10391141 """This class is used by the DataScienceJob infrastructure to handle the runtime conversion.
10401142 The translate() method determines the actual runtime handler by matching the RUNTIME_CLASS.
@@ -1046,6 +1148,7 @@ class DataScienceJobRuntimeManager(RuntimeHandler):
10461148
10471149 runtime_handlers = [
10481150 ContainerRuntimeHandler ,
1151+ PyTorchDistributedRuntimeHandler ,
10491152 GitPythonRuntimeHandler ,
10501153 NotebookRuntimeHandler ,
10511154 PythonRuntimeHandler ,
0 commit comments