Skip to content

Commit b09f843

Browse files
committed
Setup SSH for deepspeed only for multi-node.
1 parent 7ea5ef5 commit b09f843

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

ads/jobs/templates/driver_pytorch.py

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,6 @@
2222
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
2323
PythonRuntimeHandler,
2424
)
25-
from ads.jobs.templates import driver_utils
2625
from ads.opctl.distributed.common import cluster_config_helper
2726

2827
try:
@@ -458,6 +457,7 @@ def test_ssh_connection(self, host):
458457
logger.debug("SSH connection to %s - FAILED", host)
459458

460459
def touch_file(self, filename):
460+
"""Creates an empty file with specific name on all the worker nodes."""
461461
for node_ip in self.node_ip_list:
462462
logger.debug("Sending stop file to %s", node_ip)
463463
self.run_command(
@@ -499,18 +499,19 @@ def run_deepspeed_host(self, launch_args=None):
499499
For "deepspeed": --hostfile
500500
For "accelerate launch": --deepspeed_hostfile
501501
"""
502-
self.generate_key_pair().generate_hostfile()
503-
self.save_deepspeed_env()
504-
# Wait for nodes to be ready
505-
for run in self.node_runs:
506-
self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
502+
if self.node_count > 1:
503+
self.generate_key_pair().generate_hostfile()
504+
self.save_deepspeed_env()
505+
# Wait for nodes to be ready
506+
for run in self.node_runs:
507+
self.wait_for_log(run, LOG_PREFIX_PUBLIC_KEY)
507508

508-
for node_ip in self.node_ip_list:
509-
self.run_command(
510-
f"ssh-keyscan -H {node_ip} >> {SSH_DIR}/known_hosts",
511-
level=logging.DEBUG,
512-
check=True,
513-
)
509+
for node_ip in self.node_ip_list:
510+
self.run_command(
511+
f"ssh-keyscan -H {node_ip} >> {SSH_DIR}/known_hosts",
512+
level=logging.DEBUG,
513+
check=True,
514+
)
514515

515516
cmd = self.prepare_cmd(launch_args)
516517
# For DeepSpeed, we only need to run the cmd on the host
@@ -533,7 +534,11 @@ def run_deepspeed_worker(self):
533534
logger.error("There is an error in the host job run.")
534535
sys.exit(1)
535536
# Stop the node if the host job run is CANCELLED or in unexpected state.
536-
self.host_job_run.sync()
537+
try:
538+
self.host_job_run.sync()
539+
except oci.exceptions.TransientServiceError:
540+
# Ignore the transient error and try again next time.
541+
continue
537542
if self.host_job_run.status not in [
538543
"ACCEPTED",
539544
"IN_PROGRESS",
@@ -548,7 +553,8 @@ def run_deepspeed_worker(self):
548553

549554
def run(self):
550555
if self.is_host:
551-
launch_args = [f"--hostfile={self.HOST_FILE}"]
556+
if self.node_count > 1:
557+
launch_args = [f"--hostfile={self.HOST_FILE}"]
552558
self.run_deepspeed_host(launch_args)
553559
else:
554560
self.run_deepspeed_worker()
@@ -601,8 +607,6 @@ def accelerate_args(self):
601607
return args
602608

603609
def run_with_torchrun(self):
604-
self.main_process_ip = self.host_ip
605-
606610
launch_args = self.accelerate_args()
607611
for arg in self.TORCHRUN_ARGS:
608612
if not self.launch_cmd_contains(arg):
@@ -613,12 +617,14 @@ def run_with_torchrun(self):
613617
def run_with_deepspeed(self):
614618
if self.is_host:
615619
launch_args = self.accelerate_args()
616-
launch_args.append(f"--deepspeed_hostfile={self.HOST_FILE}")
620+
if self.num_machines > 1:
621+
launch_args.append(f"--deepspeed_hostfile={self.HOST_FILE}")
617622
self.run_deepspeed_host(launch_args)
618623
else:
619624
self.run_deepspeed_worker()
620625

621626
def run(self):
627+
self.main_process_ip = self.host_ip
622628
# Check if any default argument is provided by the user
623629
for arg in self.DEFAULT_ARGS:
624630
if self.launch_cmd_contains(arg):

0 commit comments

Comments
 (0)