@@ -673,7 +673,49 @@ def run(self):
673673 self .run_deepspeed_worker ()
674674
675675
class GenericRunner(TorchRunner, DeepSpeedRunner):
    """Runner for commands other than ``torchrun``, ``deepspeed`` or ``accelerate``."""

    def use_deepspeed(self) -> bool:
        """Indicate if DeepSpeed is used.

        Returns
        -------
        bool
            True when the environment variable named by ``CONST_ENV_DEEPSPEED``
            is set to a non-empty value, False otherwise.
        """
        return bool(os.environ.get(CONST_ENV_DEEPSPEED))

    def set_env_var(self) -> None:
        """Set default environment variables.

        ``WORLD_SIZE``, ``MASTER_ADDR`` and ``MASTER_PORT`` are populated from
        ``self.node_count``, ``self.host_ip`` and ``self.RDZV_PORT`` respectively.
        A variable that already exists in the environment is left untouched,
        so values provided by the user take precedence.
        """
        defaults = {
            "WORLD_SIZE": self.node_count,
            "MASTER_ADDR": self.host_ip,
            "MASTER_PORT": self.RDZV_PORT,
        }
        for name, value in defaults.items():
            if value is None:
                # Nothing meaningful to export; avoid writing the literal text "None".
                continue
            # os.environ accepts only strings. The defaults may be held as
            # non-string values (e.g. an integer port or node count), so they are
            # converted explicitly instead of letting the assignment raise TypeError.
            os.environ.setdefault(name, str(value))

    def run(self) -> None:
        """Runs the user's command.

        Note that for TorchRunner or DeepSpeedRunner,
        we automatically add arguments for some settings,
        like the number of nodes and the host node address.

        This generic runner does not modify the command specified by the user.
        User needs to make sure the command can work on all nodes.
        User may use the environment variables in the command.

        The default environment variables are set first (see ``set_env_var()``).
        When DeepSpeed is used, the host or worker DeepSpeed routine is run
        depending on ``self.is_host``; otherwise the prepared command is run
        through ``time_cmd()`` with the ``env_ld_preload()`` prefix.
        """
        self.set_env_var()

        if not self.use_deepspeed():
            self.time_cmd(cmd=self.prepare_cmd(prefix=self.env_ld_preload()))
            return

        if self.is_host:
            self.run_deepspeed_host()
        else:
            self.run_deepspeed_worker()
714+
715+
676716class AccelerateRunner (TorchRunner , DeepSpeedRunner ):
717+ """Runner for HuggingFace Accelerate."""
718+
677719 # accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
678720 # https://github.com/huggingface/accelerate/blob/70920895e80f78d96d8f91e0beeb3ebdb8e5e5d6/src/accelerate/utils/launch.py#L233
679721 DEFAULT_ARGS = [
@@ -704,11 +746,18 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
704746 self .main_process_ip = None
705747
def use_deepspeed(self):
    """Indicate if DeepSpeed is used.

    Accelerate supports DeepSpeed through the ``--use_deepspeed`` argument, so
    the result is True when either the environment variable named by
    ``CONST_ENV_DEEPSPEED`` is set to a non-empty value or
    ``self.launch_cmd_contains("use_deepspeed")`` is truthy; False otherwise.
    """
    deepspeed_requested = os.environ.get(
        CONST_ENV_DEEPSPEED
    ) or self.launch_cmd_contains("use_deepspeed")
    return bool(deepspeed_requested)
710756
711757 def accelerate_args (self ):
758+ """Gets the default arguments for the accelerate command.
759+ The value of the default arguments are assigned in ``__init__()``.
760+ """
712761 args = []
713762 for arg in self .DEFAULT_ARGS :
714763 arg_val = getattr (self , arg , None )
@@ -720,6 +769,7 @@ def accelerate_args(self):
720769 return args
721770
722771 def run_with_torchrun (self ):
772+ """Runs the job with torchrun."""
723773 launch_args = self .accelerate_args ()
724774 for arg in self .TORCHRUN_ARGS :
725775 if not self .launch_cmd_contains (arg ):
@@ -728,6 +778,7 @@ def run_with_torchrun(self):
728778 self .time_cmd (cmd = cmd )
729779
730780 def run_with_deepspeed (self ):
781+ """Runs the job with DeepSpeed."""
731782 if self .is_host :
732783 launch_args = self .accelerate_args ()
733784 if self .num_machines > 1 :
@@ -758,6 +809,8 @@ def main():
758809 runner_class = DeepSpeedRunner
759810 elif launch_cmd .startswith ("accelerate " ):
760811 runner_class = AccelerateRunner
812+ else :
813+ runner_class = GenericRunner
761814
762815 runner = runner_class ()
763816 runner : Runner
0 commit comments