1313import shlex
1414import socket
1515import sys
16+ import traceback
1617
1718import oci
1819import psutil
4041
4142
4243CONST_ENV_HOST_JOB_RUN_OCID = "MAIN_JOB_RUN_OCID"
44+ CONST_ENV_JOB_RUN_OCID = "JOB_RUN_OCID"
4345CONST_ENV_LD_PRELOAD = "LD_PRELOAD"
4446CONST_ENV_LAUNCH_CMD = "OCI__LAUNCH_CMD"
4547CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED"
5153OCI__WORKER_COUNT = "OCI__WORKER_COUNT"
5254DEFAULT_LAUNCHER = "torchrun"
5355
# Set authentication method to resource principal
# This script is expected to be running inside the job run,
# where OCI_RESOURCE_PRINCIPAL_VERSION is set by the service.
# Outside a job run (e.g. local testing) the default auth is kept unchanged.
if "OCI_RESOURCE_PRINCIPAL_VERSION" in os.environ:
    set_auth("resource_principal")
61+
class LazyEvaluate:
    """Defers a function call until its string value is needed, for logging purposes.

    Example::

        logger.debug("The value is %s", LazyEvaluate(the_function, *args, **kwargs))

    Python logging only invokes ``__str__()`` when the message is actually
    emitted. In the example above, if the log level is INFO or above,
    ``the_function()`` is never called/evaluated. At DEBUG level it is called,
    and any error raised during evaluation is logged and rendered as an error
    string, so the program keeps running even if logging itself fails.
    """

    def __init__(self, func, *args, **kwargs) -> None:
        # Nothing is executed here; the call is stored for later evaluation.
        self.func = func
        self.args = args
        self.kwargs = kwargs

    def eval(self):
        """Evaluates the function call."""
        return self.func(*self.args, **self.kwargs)

    def __str__(self) -> str:
        """Evaluates the deferred call and converts the return value to a string."""
        try:
            return str(self.eval())
        except Exception as ex:
            # Log the full traceback but never propagate the error to the caller.
            logger.debug(traceback.format_exc())
            return f"ERROR: {str(ex)}"
5596
5697
5798class Runner (driver_utils .JobRunner ):
5899 """Base runner class for PyTorch training job"""
59100
101+ # LAUNCHER stores the main command for launching the training job.
102+ # e.g. torchrun, deepspeed, accelerate, etc.
60103 LAUNCHER = ""
61104
62105 def __init__ (self , code_dir : str = driver_utils .DEFAULT_CODE_DIR ) -> None :
@@ -82,7 +125,7 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
82125 else :
83126 # Print the host IP address to logs so that it can be obtained by the nodes.
84127 print (f"{ LOG_PREFIX_HOST_IP } { self .ip } " )
85- self .host_ocid = os .environ [ "JOB_RUN_OCID" ]
128+ self .host_ocid = os .environ . get ( CONST_ENV_JOB_RUN_OCID )
86129 self .host_ip = self .ip
87130 self .is_host = True
88131
@@ -96,10 +139,11 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
96139
97140 logger .debug ("Runner initialized." )
98141
99- def launch_cmd_contains (self , arg ):
142+ def launch_cmd_contains (self , arg ) -> bool :
143+ """Checks if the cmd for launching the training contains specific keyword argument."""
100144 return f"--{ arg } " in self .launch_cmd
101145
102- def wait_for_host_ip_address (self , timeout = 15 * 60 ):
146+ def wait_for_host_ip_address (self , timeout = 15 * 60 ) -> str :
103147 """Waits until the IP address of the host is obtained.
104148
105149 Parameters
@@ -117,7 +161,7 @@ def wait_for_host_ip_address(self, timeout=15 * 60):
117161 self .host_ip = self .wait_for_ip_address (self .host_job_run , timeout )
118162 return self
119163
120- def wait_for_ip_address (self , job_run , timeout = 15 * 60 ):
164+ def wait_for_ip_address (self , job_run , timeout = 15 * 60 ) -> str :
121165 """Waits until the IP address of a particular job run is obtained.
122166
123167 Parameters
@@ -137,11 +181,11 @@ def wait_for_ip_address(self, job_run, timeout=15 * 60):
137181 log_prefix = LOG_PREFIX_HOST_IP
138182 else :
139183 log_prefix = LOG_PREFIX_NODE_IP
140- ip_address = self .wait_for_log (job_run , log_prefix , timeout )
184+ ip_address = self .wait_for_log (job_run , log_prefix , timeout ). strip ()
141185 logger .info ("IP of %s: %s" , job_run .id [- 6 :], ip_address )
142186 return ip_address
143187
144- def wait_for_log (self , job_run , log_prefix , timeout = 15 * 60 ):
188+ def wait_for_log (self , job_run , log_prefix , timeout = 15 * 60 ) -> str :
145189 """Waits until a log message with specific prefix is found in the logs of a job run.
146190
147191 Parameters
@@ -156,12 +200,12 @@ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60):
156200 Returns
157201 -------
158202 str
159- _description_
203+ The log message without the prefix.
160204
161205 Raises
162206 ------
163207 TimeoutError
164- _description_
208+ Failed to obtain the log message within the specific timeout.
165209 """
166210 logger .debug (
167211 "Waiting for logs with prefix '%s' from %s." , log_prefix , job_run .id
@@ -180,7 +224,21 @@ def wait_for_log(self, job_run, log_prefix, timeout=15 * 60):
180224 return log
181225
182226 @staticmethod
183- def check_job_run_logs (job_run , log_prefix ):
227+ def check_job_run_logs (job_run , log_prefix : str ) -> str :
228+ """Checks the logs of a specific job run and finds the log message with a specific prefix.
229+
230+ Parameters
231+ ----------
232+ job_run : DataScienceJobRun
233+ The Job run object from which the logs will be obtained.
234+ log_prefix : str
235+ The prefix to look for.
236+
237+ Returns
238+ -------
239+ str
240+ The log message without the prefix.
241+ """
184242 logger .debug ("Checking logs for job run %s" , job_run .id )
185243 logs = job_run .logs ()
186244 for log in logs :
@@ -195,8 +253,10 @@ def find_self_ip(self):
195253 """
196254 hostname = socket .gethostname ()
197255 logger .debug ("Hostname: %s" , hostname )
198- logger .debug ("Get Host by Addr: %s" , socket .gethostbyaddr (socket .gethostname ()))
199- logger .debug ("FQDN: %s" , socket .getfqdn (socket .gethostname ()))
256+ logger .debug (
257+ "Get Host by Addr: %s" , LazyEvaluate (socket .gethostbyaddr , hostname )
258+ )
259+ logger .debug ("FQDN: %s" , LazyEvaluate (socket .getfqdn , hostname ))
200260 if os .environ .get ("JOB_OCID" ):
201261 subnet_id = self .ds_client .get_job (
202262 os .environ ["JOB_OCID" ]
@@ -213,19 +273,20 @@ def find_self_ip(self):
213273 os .environ ["GLOO_SOCKET_IFNAME" ] = interface
214274 os .environ ["NCCL_SOCKET_IFNAME" ] = interface
215275 return ip
216- logger .critical ("Unable to determine node IP address." )
217- return None
276+ raise EnvironmentError ("Unable to determine node IP address." )
218277 else :
219278 ip = socket .gethostbyname (hostname )
220279 logger .info ("Node IP address: %s" , ip )
221280 return ip
222281
223282 def fetch_code (self ):
283+ """Fetches source code from Git if repo uri is specified."""
224284 if cluster_config_helper .OCI__RUNTIME_URI in os .environ :
225285 self ._fetch_git (code_dir = self .code_dir )
226286 return self
227287
228288 def _fetch_git (self , code_dir ):
289+ """Fetches source code from Git repository."""
229290 uri = os .environ .get (cluster_config_helper .OCI__RUNTIME_URI )
230291 branch = os .environ .get (cluster_config_helper .OCI__RUNTIME_GIT_BRANCH )
231292 commit = os .environ .get (cluster_config_helper .OCI__RUNTIME_GIT_COMMIT )
@@ -236,22 +297,52 @@ def _fetch_git(self, code_dir):
236297 branch = branch , commit = commit
237298 )
238299
239- def get_entrypoint_with_args (self , prefix = "" ):
300+ def get_cmd_with_entrypoint_and_args (self , prefix : str = "" ) -> str :
301+ """Gets the command based on entrypoint and arguments.
302+
303+ Parameters
304+ ----------
305+ prefix : str, optional
306+ Command prefix, by default ""
307+ This can be used to set environment variables for the command.
308+ e.g. ENV=1 command
309+
310+ Returns
311+ -------
312+ str
313+ The command including the prefix, entrypoint and arguments.
314+ """
240315 cmd = os .environ [self .entrypoint_env ]
241316 if prefix :
242317 cmd = prefix + " " + cmd
243318 if sys .argv [1 :]:
244319 cmd += " " + " " .join (sys .argv [1 :])
245320 return cmd
246321
247- def prepare_cmd (self , launch_args , prefix = None ):
322+ def prepare_cmd (self , launch_args : list = None , prefix = "" ):
323+ """Prepares the command for starting the training.
324+
325+ Parameters
326+ ----------
327+ launch_args : list
328+ The command and arguments for starting the training as a list.
329+ prefix : str, optional
330+ The prefix to be added to the launch_args in the command, by default ""
331+ This can be used to set environment variables for the command.
332+ e.g. ENV=1 command
333+
334+ Returns
335+ -------
336+ str
337+ The command for starting the training.
338+ """
248339 if not launch_args :
249340 launch_args = []
250341 # Append launch cmd args specified by the user.
251342 if self .launch_cmd :
252343 launch_args .append (self .launch_cmd [len (self .LAUNCHER ) + 1 :])
253344 else :
254- launch_args .append (self .get_entrypoint_with_args ())
345+ launch_args .append (self .get_cmd_with_entrypoint_and_args ())
255346
256347 if prefix :
257348 launcher = f"{ prefix } { self .LAUNCHER } "
@@ -581,7 +672,7 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
581672 # --multi_gpu will be set automatically if there is more than 1 GPU
582673 # self.multi_gpu = bool(self.node_count > 1 or self.gpu_count > 1)
583674 self .num_machines = self .node_count
584- self .machine_rank = os .environ ["OCI__NODE_RANK " ]
675+ self .machine_rank = os .environ ["RANK " ]
585676 # Total number of processes across all nodes
586677 # Here we assume all nodes are having the same shape
587678 self .num_processes = (self .gpu_count if self .gpu_count else 1 ) * self .node_count
0 commit comments