Commit bde06d2

Support deepspeed with accelerate launch.
1 parent 0e78ad3 commit bde06d2

File tree: 3 files changed, +105 −54 lines

- ads/jobs/builders/infrastructure/dsc_job_runtime.py
- ads/jobs/builders/runtimes/pytorch_runtime.py
- ads/jobs/templates/driver_pytorch.py

ads/jobs/builders/infrastructure/dsc_job_runtime.py

Lines changed: 11 additions & 1 deletion
@@ -1045,6 +1045,7 @@ class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
     RUNTIME_CLASS = PyTorchDistributedRuntime
     CONST_WORKER_COUNT = "OCI__WORKER_COUNT"
     CONST_COMMAND = "OCI__LAUNCH_CMD"
+    CONST_DEEPSPEED = "OCI__DEEPSPEED"

     GIT_SPEC_MAPPINGS = {
         cluster_config_helper.OCI__RUNTIME_URI: GitPythonRuntime.CONST_GIT_URL,
@@ -1054,14 +1055,19 @@ class PyTorchDistributedRuntimeHandler(PythonRuntimeHandler):
     }

     SPEC_MAPPINGS = PythonRuntimeHandler.SPEC_MAPPINGS
-    SPEC_MAPPINGS.update({PyTorchDistributedRuntime.CONST_COMMAND: CONST_COMMAND})
+    SPEC_MAPPINGS.update(
+        {
+            PyTorchDistributedRuntime.CONST_COMMAND: CONST_COMMAND,
+        }
+    )

     def _translate_artifact(self, runtime: PyTorchDistributedRuntime):
         return PyTorchDistributedArtifact(runtime.source_uri, runtime)

     def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict:
         envs = super()._translate_env(runtime)
         replica = runtime.replica if runtime.replica else 1
+        # WORKER_COUNT = REPLICA - 1
         envs[self.CONST_WORKER_COUNT] = str(replica - 1)
         envs[self.CONST_JOB_ENTRYPOINT] = PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT
         if runtime.inputs:
@@ -1080,6 +1086,8 @@ def _translate_env(self, runtime: PyTorchDistributedRuntime) -> dict:
             envs[driver_utils.CONST_ENV_PIP_REQ] = runtime.dependencies[
                 PyTorchDistributedRuntime.CONST_PIP_REQ
             ]
+        if runtime.use_deepspeed:
+            envs[self.CONST_DEEPSPEED] = "1"
         return envs

     def _extract_envs(self, dsc_job) -> dict:
@@ -1117,6 +1125,8 @@ def _extract_envs(self, dsc_job) -> dict:
             )
         if dep:
             spec[PyTorchDistributedRuntime.CONST_DEP] = dep
+        if envs.pop(self.CONST_DEEPSPEED, None):
+            spec[PyTorchDistributedRuntime.CONST_DEEPSPEED] = True
         # Envs
         if envs:
             spec[PythonRuntime.CONST_ENV_VAR] = envs
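
Taken together, these handler changes round-trip the DeepSpeed flag through a job environment variable: _translate_env sets OCI__DEEPSPEED=1 when the runtime's deepspeed spec is enabled, and _extract_envs pops the variable back out into the runtime spec as deepspeed: True. A minimal sketch of that contract, using plain dicts and hypothetical helper names rather than the actual handler classes:

# Minimal sketch of the OCI__DEEPSPEED round trip (plain dicts, hypothetical
# helpers; the real logic lives in PyTorchDistributedRuntimeHandler).
CONST_DEEPSPEED_ENV = "OCI__DEEPSPEED"   # env var name used by the handler
CONST_DEEPSPEED_SPEC = "deepspeed"       # spec key used by PyTorchDistributedRuntime

def spec_to_envs(spec: dict) -> dict:
    """Mirror of _translate_env: set the env var only when deepspeed is enabled."""
    envs = {}
    if spec.get(CONST_DEEPSPEED_SPEC):
        envs[CONST_DEEPSPEED_ENV] = "1"
    return envs

def envs_to_spec(envs: dict) -> dict:
    """Mirror of _extract_envs: pop the env var back into the spec."""
    spec = {}
    if envs.pop(CONST_DEEPSPEED_ENV, None):
        spec[CONST_DEEPSPEED_SPEC] = True
    return spec

assert envs_to_spec(spec_to_envs({"deepspeed": True})) == {"deepspeed": True}
assert envs_to_spec(spec_to_envs({})) == {}

Popping the variable (rather than just reading it) keeps OCI__DEEPSPEED out of the user-visible environment variables that are later copied into the spec.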

ads/jobs/builders/runtimes/pytorch_runtime.py

Lines changed: 10 additions & 1 deletion
@@ -14,6 +14,7 @@ class PyTorchDistributedRuntime(PythonRuntime):
     CONST_PIP_REQ = "pipRequirements"
     CONST_PIP_PKG = "pipPackages"
     CONST_COMMAND = "command"
+    CONST_DEEPSPEED = "deepspeed"

     def with_git(
         self, url: str, branch: str = None, commit: str = None, secret_ocid: str = None
@@ -77,13 +78,21 @@ def with_dependency(self, pip_req=None, pip_pkg=None):
     def dependencies(self) -> dict:
         return self.get_spec(self.CONST_DEP)

-    def with_command(self, command: str):
+    def with_command(self, command: str, use_deepspeed=False):
+        if use_deepspeed:
+            self.set_spec(self.CONST_DEEPSPEED, True)
         return self.set_spec(self.CONST_COMMAND, command)

     @property
     def command(self):
         return self.get_spec(self.CONST_COMMAND)

+    @property
+    def use_deepspeed(self):
+        if self.get_spec(self.CONST_DEEPSPEED):
+            return True
+        return False
+
     def run(self, dsc_job, **kwargs):
         replicas = self.replica if self.replica else 1
         main_run = None
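
With this change a user can opt into DeepSpeed directly from the runtime builder. A minimal usage sketch, assuming PyTorchDistributedRuntime is exported from ads.jobs and using placeholder repository, requirements, and script names:

# Hypothetical usage sketch; the repository URL, requirements file, and
# training command below are placeholders, not values from this commit.
from ads.jobs import PyTorchDistributedRuntime

runtime = (
    PyTorchDistributedRuntime()
    .with_git(url="https://github.com/your-org/your-training-repo.git")
    .with_dependency(pip_req="requirements.txt")
    # use_deepspeed=True sets the "deepspeed" spec, which the infrastructure
    # handler translates into the OCI__DEEPSPEED=1 environment variable.
    .with_command("accelerate launch train.py --bf16", use_deepspeed=True)
)
assert runtime.use_deepspeed is True

The use_deepspeed property is what the infrastructure handler reads when translating the runtime into job environment variables.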

ads/jobs/templates/driver_pytorch.py

Lines changed: 84 additions & 52 deletions
@@ -43,7 +43,7 @@
 CONST_ENV_HOST_JOB_RUN_OCID = "MAIN_JOB_RUN_OCID"
 CONST_ENV_LD_PRELOAD = "LD_PRELOAD"
 CONST_ENV_LAUNCH_CMD = "OCI__LAUNCH_CMD"
-CONST_ENV_LAUNCH_ARGS = "OCI__LAUNCH_ARGS"
+CONST_ENV_DEEPSPEED = "OCI__DEEPSPEED"
 LOG_PREFIX_HOST_IP = "Distributed Training HOST IP: "
 LOG_PREFIX_NODE_IP = "Node IP: "
 LOG_PREFIX_PUBLIC_KEY = "HOST PUBLIC KEY: "
@@ -246,6 +246,8 @@ def get_entrypoint_with_args(self, prefix=""):
         return cmd

     def prepare_cmd(self, launch_args, prefix=None):
+        if not launch_args:
+            launch_args = []
         # Append launch cmd args specified by the user.
         if self.launch_cmd:
             launch_args.append(self.launch_cmd[len(self.LAUNCHER) + 1 :])
@@ -486,67 +488,90 @@ def save_deepspeed_env(self):
         logger.debug("Environment variables saved to %s", self.ENV_FILE)
         self.run_command(f"cat {self.ENV_FILE}")

-    def run(self):
-        if self.is_host:
-            self.generate_key_pair().generate_hostfile()
-            self.save_deepspeed_env()
-            # Wait for nodes to be ready
-            for run in self.node_runs:
-                self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
+    def run_deepspeed_host(self, launch_args=None):
+        """Prepares the host and launch the deepspeed training.

-            for node_ip in self.node_ip_list:
-                self.run_command(
-                    f"ssh-keyscan -H {node_ip} >> {SSH_DIR}/known_hosts",
-                    level=logging.DEBUG,
-                    check=True,
-                )
-            # For DeepSpeed, we only need to run the cmd on the host
-            launch_args = [f"--hostfile={self.HOST_FILE}"]
+        Parameters
+        ----------
+        launch_args : str, optional
+            Additional command line arguments, by default None.
+            The deepspeed host file should be specified in the launch args.
+            For "deepspeed": --hostfile
+            For "accelerate launch": --deepspeed_hostfile
+        """
+        self.generate_key_pair().generate_hostfile()
+        self.save_deepspeed_env()
+        # Wait for nodes to be ready
+        for run in self.node_runs:
+            self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
+
+        for node_ip in self.node_ip_list:
+            self.run_command(
+                f"ssh-keyscan -H {node_ip} >> {SSH_DIR}/known_hosts",
+                level=logging.DEBUG,
+                check=True,
+            )

-            cmd = self.prepare_cmd(launch_args)
+        cmd = self.prepare_cmd(launch_args)
+        # For DeepSpeed, we only need to run the cmd on the host
+        try:
+            self.time_cmd(cmd)
+        except:
+            # Caution: file will not be generated if job run is killed from the console.
+            self.touch_file(self.ERROR_FILE)
+            raise
+        # Signal stop
+        self.touch_file(self.STOP_FILE)
+
+    def run_deepspeed_worker(self):
+        self.fetch_host_public_key()
+        # Keep the job run alive until host job run is finished.
+        while not os.path.exists(self.STOP_FILE):
+            time.sleep(60)
+            # Stop the node if the host touched the error file.
+            if os.path.exists(self.ERROR_FILE):
+                logger.error("There is an error in the host job run.")
+                sys.exit(1)
+            # Stop the node if the host job run is CANCELLED or in unexpected state.
+            self.host_job_run.sync()
+            if self.host_job_run.status not in [
+                "ACCEPTED",
+                "IN_PROGRESS",
+                "SUCCEEDED",
+            ]:
+                logger.info(
+                    "Host job run status is %s. Stopping job run...",
+                    self.host_job_run.status,
+                )
+                sys.exit(2)
+        logger.info("Job finished successfully. Stopping job run...")

-            try:
-                self.time_cmd(cmd)
-            except:
-                # Caution: file will not be generated if job run is killed from the console.
-                self.touch_file(self.ERROR_FILE)
-                raise
-            # Signal stop
-            self.touch_file(self.STOP_FILE)
+    def run(self):
+        if self.is_host:
+            launch_args = [f"--hostfile={self.HOST_FILE}"]
+            self.run_deepspeed_host(launch_args)
         else:
-            self.fetch_host_public_key()
-            # Keep the job run alive until host job run is finished.
-            while not os.path.exists(self.STOP_FILE):
-                time.sleep(60)
-                # Stop the node if the host touched the error file.
-                if os.path.exists(self.ERROR_FILE):
-                    logger.error("There is an error in the host job run.")
-                    sys.exit(1)
-                # Stop the node if the host job run is CANCELLED or in unexpected state.
-                self.host_job_run.sync()
-                if self.host_job_run.status not in [
-                    "ACCEPTED",
-                    "IN_PROGRESS",
-                    "SUCCEEDED",
-                ]:
-                    logger.info(
-                        "Host job run status is %s. Stopping job run...",
-                        self.host_job_run.status,
-                    )
-                    sys.exit(2)
-            logger.info("Job finished successfully. Stopping job run...")
+            self.run_deepspeed_worker()


 class AccelerateRunner(TorchRunner, DeepSpeedRunner):
-    DEFAULT_ARGS = ["num_processes", "num_machines", "machine_rank"]
-    TORCHRUN_ARGS = ["main_process_ip", "main_process_port"]
+    # accelerate launch will add main_process_port for deepspeed cmd even if it is not needed.
+    # https://github.com/huggingface/accelerate/blob/70920895e80f78d96d8f91e0beeb3ebdb8e5e5d6/src/accelerate/utils/launch.py#L233
+    DEFAULT_ARGS = [
+        "num_processes",
+        "num_machines",
+        "machine_rank",
+        "main_process_ip",
+        "main_process_port",
+    ]
+    TORCHRUN_ARGS = []
     LAUNCHER = "accelerate launch"

     def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
         super().__init__(code_dir)
         # For "accelerate launch", only one of the following options can be used at one time
         # `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp`.
-        # When a config file is not provided, 
+        # When a config file is not provided,
         # --multi_gpu will be set automatically if there is more than 1 GPU
         # self.multi_gpu = bool(self.node_count > 1 or self.gpu_count > 1)
         self.num_machines = self.node_count
@@ -560,7 +585,9 @@ def __init__(self, code_dir: str = driver_utils.DEFAULT_CODE_DIR) -> None:
         self.main_process_ip = None

     def use_deepspeed(self):
-        return self.launch_cmd_contains("use_deepspeed")
+        return os.environ.get(CONST_ENV_DEEPSPEED) or self.launch_cmd_contains(
+            "use_deepspeed"
+        )

     def accelerate_args(self):
         args = []
@@ -584,13 +611,18 @@ def run_with_torchrun(self):
         self.time_cmd(cmd=cmd)

     def run_with_deepspeed(self):
-        raise NotImplementedError
+        if self.is_host:
+            launch_args = self.accelerate_args()
+            launch_args.append(f"--deepspeed_hostfile={self.HOST_FILE}")
+            self.run_deepspeed_host(launch_args)
+        else:
+            self.run_deepspeed_worker()

     def run(self):
         # Check if any default argument is provided by the user
         for arg in self.DEFAULT_ARGS:
             if self.launch_cmd_contains(arg):
-                logger.debug("%s found in launch args.", arg)
+                logger.debug("%s found in command.", arg)
                 setattr(self, arg, None)
         if self.use_deepspeed():
             self.run_with_deepspeed()
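
On the host node, the accelerate-with-DeepSpeed path now builds its launcher arguments from accelerate_args() plus a --deepspeed_hostfile entry, and prepare_cmd appends whatever follows "accelerate launch" in the user's OCI__LAUNCH_CMD. An illustrative sketch of how such a command comes together, with placeholder values for the process counts, addresses, and hostfile path (none of these values come from this commit, and the exact argument formatting in the driver may differ):

# Illustrative sketch only; mirrors the shape of the host-side command, not the
# driver's actual implementation.
num_processes = 16        # total GPU processes across nodes (placeholder)
num_machines = 2          # node count: 1 host + 1 worker (placeholder)
machine_rank = 0          # 0 on the host
main_process_ip = "10.0.0.2"          # placeholder host IP
main_process_port = 29400             # placeholder port
hostfile = "/home/datascience/hostfile"  # placeholder hostfile path
user_cmd = "accelerate launch train.py --bf16"  # OCI__LAUNCH_CMD from the runtime

launch_args = [
    f"--num_processes={num_processes}",
    f"--num_machines={num_machines}",
    f"--machine_rank={machine_rank}",
    f"--main_process_ip={main_process_ip}",
    f"--main_process_port={main_process_port}",
    f"--deepspeed_hostfile={hostfile}",  # appended by run_with_deepspeed
]
# prepare_cmd appends the user's arguments after stripping the launcher prefix.
cmd = "accelerate launch " + " ".join(launch_args) + " " + user_cmd[len("accelerate launch") + 1 :]
print(cmd)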
