
Commit 4d49e56

Update dsc_job_runtime.py and driver_pytorch.py
1 parent 958657a commit 4d49e56

2 files changed: +110 -19 lines changed

ads/jobs/builders/infrastructure/dsc_job_runtime.py

Lines changed: 1 addition & 1 deletion
@@ -1071,7 +1071,7 @@ def _translate_artifact(self, runtime: PyTorchDistributedRuntime):
    def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict:
        envs = super()._translate_env(runtime)
        replica = runtime.replica if runtime.replica else 1
-        # WORKER_COUNT = REPLICA - 1
+        # WORKER_COUNT = REPLICA - 1 so that it will be the same as distributed training
        envs[self.CONST_WORKER_COUNT] = str(replica - 1)
        envs[self.CONST_JOB_ENTRYPOINT] = PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT
        if runtime.inputs:
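
For reference, a minimal sketch of the worker-count convention the new comment describes; the replica value below is a made-up example, and only the OCI__WORKER_COUNT key mirrors the code above:

# Hypothetical illustration: one of the job runs acts as the host, so the
# number of worker runs is the replica count minus one.
replica = 4  # assumed total number of job runs (host + workers)
envs = {"OCI__WORKER_COUNT": str(replica - 1)}
print(envs)  # {'OCI__WORKER_COUNT': '3'}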

ads/jobs/templates/driver_pytorch.py

Lines changed: 109 additions & 18 deletions
@@ -13,6 +13,7 @@
import shlex
import socket
import sys
+import traceback

import oci
import psutil
@@ -40,6 +41,7 @@


CONST_ENV_HOST_JOB_RUN_OCID = "MAIN_JOB_RUN_OCID"
+CONST_ENV_JOB_RUN_OCID = "JOB_RUN_OCID"
CONST_ENV_LD_PRELOAD = "LD_PRELOAD"
CONST_ENV_LAUNCH_CMD = "OCI__LAUNCH_CMD"
CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED"
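
As a small aside, a sketch of how a constant like CONST_ENV_JOB_RUN_OCID is typically read defensively; this mirrors the os.environ.get() change further down, and the fallback behaviour shown is standard Python rather than code from this commit:

import os

CONST_ENV_JOB_RUN_OCID = "JOB_RUN_OCID"

# os.environ.get() returns None when the variable is not set (e.g. when the
# script runs outside of a job run), whereas os.environ[...] would raise KeyError.
host_ocid = os.environ.get(CONST_ENV_JOB_RUN_OCID)
print(host_ocid)  # None unless JOB_RUN_OCID is set in the environment
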
@@ -51,12 +53,53 @@
OCI__WORKER_COUNT = "OCI__WORKER_COUNT"
DEFAULT_LAUNCHER = "torchrun"

-set_auth("resource_principal")
+# Set the authentication method to resource principal.
+# This script is expected to be running inside a job run.
+if "OCI_RESOURCE_PRINCIPAL_VERSION" in os.environ:
+    set_auth("resource_principal")
+
+
+class LazyEvaluate:
+    """A class to delay a function call until its return value is needed, for logging purposes.
+
+    Example::
+        logger.debug("The value is %s", LazyEvaluate(the_function, *args, **kwargs))
+
+    Python logging will only call the __str__() method when the value is needed.
+
+    In the above example, if the log level is INFO or above,
+    the_function() will not be called/evaluated.
+    If the log level is DEBUG, the_function() will be called,
+    and if there is an error, the error will be logged.
+    The program will continue to run even if an error happens during logging.
+    """
+
+    def __init__(self, func, *args, **kwargs) -> None:
+        self.func = func
+        self.args = args
+        self.kwargs = kwargs
+
+    def eval(self):
+        """Evaluates the function call."""
+        return self.func(*self.args, **self.kwargs)
+
+    def __str__(self) -> str:
+        """Evaluates the function call and converts the return value to a string."""
+        try:
+            val = str(self.eval())
+        except Exception as ex:
+            logger.debug(traceback.format_exc())
+            val = f"ERROR: {str(ex)}"
+        return val


class Runner(driver_utils.JobRunner):
    """Base runner class for PyTorch training job"""

+    # LAUNCHER stores the main command for launching the training job,
+    # e.g. torchrun, deepspeed, accelerate, etc.
    LAUNCHER = ""

    def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
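
To illustrate how LazyEvaluate defers work to logging time, here is a small sketch. It assumes the LazyEvaluate class added above is in scope; the slow_lookup function and the logger configuration are made up for the example only:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def slow_lookup():
    # Stand-in for an expensive or failure-prone call, such as a reverse DNS lookup.
    return "10.0.0.7"

# At INFO level this DEBUG record is dropped, so slow_lookup() is never called.
logger.debug("Host: %s", LazyEvaluate(slow_lookup))

# At DEBUG level the %s formatting triggers __str__(), which calls slow_lookup();
# if it raised, the message would show "ERROR: ..." instead of crashing the run.
logger.setLevel(logging.DEBUG)
logger.debug("Host: %s", LazyEvaluate(slow_lookup))
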
@@ -82,7 +125,7 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
        else:
            # Print the host IP address to logs so that it can be obtained by the nodes.
            print(f"{LOG_PREFIX_HOST_IP}{self.ip}")
-            self.host_ocid = os.environ["JOB_RUN_OCID"]
+            self.host_ocid = os.environ.get(CONST_ENV_JOB_RUN_OCID)
            self.host_ip = self.ip
            self.is_host = True

@@ -96,10 +139,11 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:

        logger.debug("Runner initialized.")

-    def launch_cmd_contains(self, arg):
+    def launch_cmd_contains(self, arg) -> bool:
+        """Checks if the cmd for launching the training contains a specific keyword argument."""
        return f"--{arg}" in self.launch_cmd

-    def wait_for_host_ip_address(self, timeout=15 * 60):
+    def wait_for_host_ip_address(self, timeout=15 * 60) -> str:
        """Waits until the IP address of the host is obtained.

        Parameters
@@ -117,7 +161,7 @@ def wait_for_host_ip_address(self, timeout=15 * 60):
        self.host_ip = self.wait_for_ip_address(self.host_job_run, timeout)
        return self

-    def wait_for_ip_address(self, job_run, timeout=15 * 60):
+    def wait_for_ip_address(self, job_run, timeout=15 * 60) -> str:
        """Waits until the IP address of a particular job run is obtained.

        Parameters
@@ -137,11 +181,11 @@ def wait_for_ip_address(self, job_run, timeout=15 * 60):
            log_prefix = LOG_PREFIX_HOST_IP
        else:
            log_prefix = LOG_PREFIX_NODE_IP
-        ip_address = self.wait_for_log(job_run, log_prefix, timeout)
+        ip_address = self.wait_for_log(job_run, log_prefix, timeout).strip()
        logger.info("IP of %s: %s", job_run.id[-6:], ip_address)
        return ip_address

-    def wait_for_log(self, job_run, log_prefix, timeout=15 * 60):
+    def wait_for_log(self, job_run, log_prefix, timeout=15 * 60) -> str:
        """Waits until a log message with specific prefix is found in the logs of a job run.

        Parameters
@@ -156,12 +200,12 @@ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60):
        Returns
        -------
        str
-            _description_
+            The log message without the prefix.

        Raises
        ------
        TimeoutError
-            _description_
+            Failed to obtain the log message within the specified timeout.
        """
        logger.debug(
            "Waiting for logs with prefix '%s' from %s.", log_prefix, job_run.id
@@ -180,7 +224,21 @@ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60):
                return log

    @staticmethod
-    def check_job_run_logs(job_run, log_prefix):
+    def check_job_run_logs(job_run, log_prefix: str) -> str:
+        """Checks the logs of a specific job run and finds the log message with a specific prefix.
+
+        Parameters
+        ----------
+        job_run : DataScienceJobRun
+            The job run object from which the logs will be obtained.
+        log_prefix : str
+            The prefix to look for.
+
+        Returns
+        -------
+        str
+            The log message without the prefix.
+        """
        logger.debug("Checking logs for job run %s", job_run.id)
        logs = job_run.logs()
        for log in logs:
@@ -195,8 +253,10 @@ def find_self_ip(self):
        """
        hostname = socket.gethostname()
        logger.debug("Hostname: %s", hostname)
-        logger.debug("Get Host by Addr: %s", socket.gethostbyaddr(socket.gethostname()))
-        logger.debug("FQDN: %s", socket.getfqdn(socket.gethostname()))
+        logger.debug(
+            "Get Host by Addr: %s", LazyEvaluate(socket.gethostbyaddr, hostname)
+        )
+        logger.debug("FQDN: %s", LazyEvaluate(socket.getfqdn, hostname))
        if os.environ.get("JOB_OCID"):
            subnet_id = self.ds_client.get_job(
                os.environ["JOB_OCID"]
@@ -213,19 +273,20 @@ def find_self_ip(self):
                    os.environ["GLOO_SOCKET_IFNAME"] = interface
                    os.environ["NCCL_SOCKET_IFNAME"] = interface
                    return ip
-            logger.critical("Unable to determine node IP address.")
-            return None
+            raise EnvironmentError("Unable to determine node IP address.")
        else:
            ip = socket.gethostbyname(hostname)
            logger.info("Node IP address: %s", ip)
            return ip

    def fetch_code(self):
+        """Fetches source code from Git if repo uri is specified."""
        if cluster_config_helper.OCI__RUNTIME_URI in os.environ:
            self._fetch_git(code_dir=self.code_dir)
        return self

    def _fetch_git(self, code_dir):
+        """Fetches source code from Git repository."""
        uri = os.environ.get(cluster_config_helper.OCI__RUNTIME_URI)
        branch = os.environ.get(cluster_config_helper.OCI__RUNTIME_GIT_BRANCH)
        commit = os.environ.get(cluster_config_helper.OCI__RUNTIME_GIT_COMMIT)
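
One way to read the switch from returning None to raising: a node that cannot determine its IP now fails fast instead of letting None propagate into the launcher arguments. A rough, self-contained sketch of that behaviour; find_ip_or_fail and the candidate list are hypothetical stand-ins, not the commit's code:

import logging

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

def find_ip_or_fail(candidates):
    # Hypothetical stand-in for Runner.find_self_ip(): raise instead of returning None.
    for ip in candidates:
        if ip:
            return ip
    raise EnvironmentError("Unable to determine node IP address.")

try:
    ip = find_ip_or_fail([None, "", "10.0.0.7"])
    logger.info("Node IP address: %s", ip)
except EnvironmentError as ex:
    # Raising makes the job run fail with a clear message instead of continuing with ip=None.
    logger.critical(str(ex))
    raise
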
@@ -236,22 +297,52 @@ def _fetch_git(self, code_dir):
            branch=branch, commit=commit
        )

-    def get_entrypoint_with_args(self, prefix=""):
+    def get_cmd_with_entrypoint_and_args(self, prefix: str = "") -> str:
+        """Gets the command based on the entrypoint and arguments.
+
+        Parameters
+        ----------
+        prefix : str, optional
+            Command prefix, by default "".
+            This can be used to set environment variables for the command,
+            e.g. ENV=1 command
+
+        Returns
+        -------
+        str
+            The command including the prefix, entrypoint and arguments.
+        """
        cmd = os.environ[self.entrypoint_env]
        if prefix:
            cmd = prefix + " " + cmd
        if sys.argv[1:]:
            cmd += " " + " ".join(sys.argv[1:])
        return cmd

-    def prepare_cmd(self, launch_args, prefix=None):
+    def prepare_cmd(self, launch_args: list = None, prefix=""):
+        """Prepares the command for starting the training.
+
+        Parameters
+        ----------
+        launch_args : list
+            The command and arguments for starting the training, as a list.
+        prefix : str, optional
+            The prefix to be added to the launch_args in the command, by default "".
+            This can be used to set environment variables for the command,
+            e.g. ENV=1 command
+
+        Returns
+        -------
+        str
+            The command for starting the training.
+        """
        if not launch_args:
            launch_args = []
        # Append launch cmd args specified by the user.
        if self.launch_cmd:
            launch_args.append(self.launch_cmd[len(self.LAUNCHER) + 1 :])
        else:
-            launch_args.append(self.get_entrypoint_with_args())
+            launch_args.append(self.get_cmd_with_entrypoint_and_args())

        if prefix:
            launcher = f"{prefix} {self.LAUNCHER}"
@@ -581,7 +672,7 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
        # --multi_gpu will be set automatically if there is more than 1 GPU
        # self.multi_gpu = bool(self.node_count > 1 or self.gpu_count > 1)
        self.num_machines = self.node_count
-        self.machine_rank = os.environ["OCI__NODE_RANK"]
+        self.machine_rank = os.environ["RANK"]
        # Total number of processes across all nodes
        # Here we assume all nodes are having the same shape
        self.num_processes = (self.gpu_count if self.gpu_count else 1) * self.node_count
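
For the accelerate-based runner, the rank and process-count math can be sanity-checked with a small sketch; the node and GPU counts are assumptions, and RANK defaults to 0 here when unset:

import os

node_count = 2   # assumed number of job runs (machines) in the cluster
gpu_count = 4    # assumed GPUs per node; 0 means CPU-only

num_machines = node_count
machine_rank = os.environ.get("RANK", "0")  # rank of this node, set per job run
# Total processes across all nodes, assuming every node has the same shape.
num_processes = (gpu_count if gpu_count else 1) * node_count
print(num_machines, machine_rank, num_processes)  # 2 0 8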
