4343CONST_ENV_HOST_JOB_RUN_OCID = "MAIN_JOB_RUN_OCID"
4444CONST_ENV_LD_PRELOAD = "LD_PRELOAD"
4545CONST_ENV_LAUNCH_CMD = "OCI__LAUNCH_CMD"
46- CONST_ENV_LAUNCH_ARGS = "OCI__LAUNCH_ARGS "
46+ CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED "
4747LOG_PREFIX_HOST_IP = "Distributed Training HOST IP: "
4848LOG_PREFIX_NODE_IP = "Node IP: "
4949LOG_PREFIX_PUBLIC_KEY = "HOST PUBLIC KEY: "
@@ -246,6 +246,8 @@ def get_entrypoint_with_args(self, prefix=""):
246246 return cmd
247247
248248 def prepare_cmd (self , launch_args , prefix = None ):
249+ if not launch_args :
250+ launch_args = []
249251 # Append launch cmd args specified by the user.
250252 if self .launch_cmd :
251253 launch_args .append (self .launch_cmd [len (self .LAUNCHER ) + 1 :])
@@ -486,67 +488,90 @@ def save_deepspeed_env(self):
486488 logger .debug ("Environment variables saved to %s" , self .ENV_FILE )
487489 self .run_command (f"cat { self .ENV_FILE } " )
488490
489- def run (self ):
490- if self .is_host :
491- self .generate_key_pair ().generate_hostfile ()
492- self .save_deepspeed_env ()
493- # Wait for nodes to be ready
494- for run in self .node_runs :
495- self .wait_for_log (run , LOG_PREFIX_PUBLIC_KEY )
491+ def run_deepspeed_host (self , launch_args = None ):
492+ """Prepares the host and launch the deepspeed training.
496493
497- for node_ip in self .node_ip_list :
498- self .run_command (
499- f"ssh-keyscan -H { node_ip } >> { SSH_DIR } /known_hosts" ,
500- level = logging .DEBUG ,
501- check = True ,
502- )
503- # For DeepSpeed, we only need to run the cmd on the host
504- launch_args = [f"--hostfile={ self .HOST_FILE } " ]
494+ Parameters
495+ ----------
496+ launch_args : str, optional
497+ Additional command line arguments, by default None.
498+ The deepspeed host file should be specified in the launch args.
499+ For "deepspeed": --hostfile
500+ For "accelerate launch": --deepspeed_hostfile
501+ """
502+ self .generate_key_pair ().generate_hostfile ()
503+ self .save_deepspeed_env ()
504+ # Wait for nodes to be ready
505+ for run in self .node_runs :
506+ self .wait_for_log (run , LOG_PREFIX_PUBLIC_KEY )
507+
508+ for node_ip in self .node_ip_list :
509+ self .run_command (
510+ f"ssh-keyscan -H { node_ip } >> { SSH_DIR } /known_hosts" ,
511+ level = logging .DEBUG ,
512+ check = True ,
513+ )
505514
506- cmd = self .prepare_cmd (launch_args )
515+ cmd = self .prepare_cmd (launch_args )
516+ # For DeepSpeed, we only need to run the cmd on the host
517+ try :
518+ self .time_cmd (cmd )
519+ except :
520+ # Caution: file will not be generated if job run is killed from the console.
521+ self .touch_file (self .ERROR_FILE )
522+ raise
523+ # Signal stop
524+ self .touch_file (self .STOP_FILE )
525+
526+ def run_deepspeed_worker (self ):
527+ self .fetch_host_public_key ()
528+ # Keep the job run alive until host job run is finished.
529+ while not os .path .exists (self .STOP_FILE ):
530+ time .sleep (60 )
531+ # Stop the node if the host touched the error file.
532+ if os .path .exists (self .ERROR_FILE ):
533+ logger .error ("There is an error in the host job run." )
534+ sys .exit (1 )
535+ # Stop the node if the host job run is CANCELLED or in unexpected state.
536+ self .host_job_run .sync ()
537+ if self .host_job_run .status not in [
538+ "ACCEPTED" ,
539+ "IN_PROGRESS" ,
540+ "SUCCEEDED" ,
541+ ]:
542+ logger .info (
543+ "Host job run status is %s. Stopping job run..." ,
544+ self .host_job_run .status ,
545+ )
546+ sys .exit (2 )
547+ logger .info ("Job finished successfully. Stopping job run..." )
507548
508- try :
509- self .time_cmd (cmd )
510- except :
511- # Caution: file will not be generated if job run is killed from the console.
512- self .touch_file (self .ERROR_FILE )
513- raise
514- # Signal stop
515- self .touch_file (self .STOP_FILE )
549+ def run (self ):
550+ if self .is_host :
551+ launch_args = [f"--hostfile={ self .HOST_FILE } " ]
552+ self .run_deepspeed_host (launch_args )
516553 else :
517- self .fetch_host_public_key ()
518- # Keep the job run alive until host job run is finished.
519- while not os .path .exists (self .STOP_FILE ):
520- time .sleep (60 )
521- # Stop the node if the host touched the error file.
522- if os .path .exists (self .ERROR_FILE ):
523- logger .error ("There is an error in the host job run." )
524- sys .exit (1 )
525- # Stop the node if the host job run is CANCELLED or in unexpected state.
526- self .host_job_run .sync ()
527- if self .host_job_run .status not in [
528- "ACCEPTED" ,
529- "IN_PROGRESS" ,
530- "SUCCEEDED" ,
531- ]:
532- logger .info (
533- "Host job run status is %s. Stopping job run..." ,
534- self .host_job_run .status ,
535- )
536- sys .exit (2 )
537- logger .info ("Job finished successfully. Stopping job run..." )
554+ self .run_deepspeed_worker ()
538555
539556
540557class AccelerateRunner (TorchRunner , DeepSpeedRunner ):
541- DEFAULT_ARGS = ["num_processes" , "num_machines" , "machine_rank" ]
542- TORCHRUN_ARGS = ["main_process_ip" , "main_process_port" ]
558+ # accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
559+ # https://github.com/huggingface/accelerate/blob/70920895e80f78d96d8f91e0beeb3ebdb8e5e5d6/src/accelerate/utils/launch.py#L233
560+ DEFAULT_ARGS = [
561+ "num_processes" ,
562+ "num_machines" ,
563+ "machine_rank" ,
564+ "main_process_ip" ,
565+ "main_process_port" ,
566+ ]
567+ TORCHRUN_ARGS = []
543568 LAUNCHER = "accelerate launch"
544569
545570 def __init__ (self , code_dir : str = driver_utils .DEFAULT_CODE_DIR ) -> None :
546571 super ().__init__ (code_dir )
547572 # For "accelerate launch", only one of the following options can be used at one time
548573 # `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`.
549- # When a config file is not provided,
574+ # When a config file is not provided,
550575 # --multi_gpu will be set automatically if there is more than 1 GPU
551576 # self.multi_gpu = bool(self.node_count > 1 or self.gpu_count > 1)
552577 self .num_machines = self .node_count
@@ -560,7 +585,9 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
560585 self .main_process_ip = None
561586
562587 def use_deepspeed (self ):
563- return self .launch_cmd_contains ("use_deepspeed" )
588+ return os .environ .get (CONST_ENV_DEEPSPEED ) or self .launch_cmd_contains (
589+ "use_deepspeed"
590+ )
564591
565592 def accelerate_args (self ):
566593 args = []
@@ -584,13 +611,18 @@ def run_with_torchrun(self):
584611 self .time_cmd (cmd = cmd )
585612
586613 def run_with_deepspeed (self ):
587- raise NotImplementedError
614+ if self .is_host :
615+ launch_args = self .accelerate_args ()
616+ launch_args .append (f"--deepspeed_hostfile={ self .HOST_FILE } " )
617+ self .run_deepspeed_host (launch_args )
618+ else :
619+ self .run_deepspeed_worker ()
588620
589621 def run (self ):
590622 # Check if any default argument is provided by the user
591623 for arg in self .DEFAULT_ARGS :
592624 if self .launch_cmd_contains (arg ):
593- logger .debug ("%s found in launch args ." , arg )
625+ logger .debug ("%s found in command ." , arg )
594626 setattr (self , arg , None )
595627 if self .use_deepspeed ():
596628 self .run_with_deepspeed ()
0 commit comments