2222from ads .jobs .builders .infrastructure .dsc_job_runtime import (
2323 PythonRuntimeHandler ,
2424)
25- from ads .jobs .templates import driver_utils
2625from ads .opctl .distributed .common import cluster_config_helper
2726
2827try :
@@ -458,6 +457,7 @@ def test_ssh_connection(self, host):
458457 logger .debug ("SSH connection to %s - FAILED" , host )
459458
460459 def touch_file (self , filename ):
460+ """Creates an empty file with specific name on all the worker nodes."""
461461 for node_ip in self .node_ip_list :
462462 logger .debug ("Sending stop file to %s" , node_ip )
463463 self .run_command (
@@ -499,18 +499,19 @@ def run_deepspeed_host(self, launch_args=None):
499499 For "deepspeed": --hostfile
500500 For "accelerate launch": --deepspeed_hostfile
501501 """
502- self .generate_key_pair ().generate_hostfile ()
503- self .save_deepspeed_env ()
504- # Wait for nodes to be ready
505- for run in self .node_runs :
506- self .wait_for_log (run , LOG_PREFIX_PUBLIC_KEY )
502+ if self .node_count > 1 :
503+ self .generate_key_pair ().generate_hostfile ()
504+ self .save_deepspeed_env ()
505+ # Wait for nodes to be ready
506+ for run in self .node_runs :
507+ self .wait_for_log (run , LOG_PREFIX_PUBLIC_KEY )
507508
508- for node_ip in self .node_ip_list :
509- self .run_command (
510- f"ssh-keyscan -H { node_ip } >> { SSH_DIR } /known_hosts" ,
511- level = logging .DEBUG ,
512- check = True ,
513- )
509+ for node_ip in self .node_ip_list :
510+ self .run_command (
511+ f"ssh-keyscan -H { node_ip } >> { SSH_DIR } /known_hosts" ,
512+ level = logging .DEBUG ,
513+ check = True ,
514+ )
514515
515516 cmd = self .prepare_cmd (launch_args )
516517 # For DeepSpeed, we only need to run the cmd on the host
@@ -533,7 +534,11 @@ def run_deepspeed_worker(self):
533534 logger .error ("There is an error in the host job run." )
534535 sys .exit (1 )
535536 # Stop the node if the host job run is CANCELLED or in unexpected state.
536- self .host_job_run .sync ()
537+ try :
538+ self .host_job_run .sync ()
539+ except oci .exceptions .TransientServiceError :
540+ # Ignore the transient error and try again next time.
541+ continue
537542 if self .host_job_run .status not in [
538543 "ACCEPTED" ,
539544 "IN_PROGRESS" ,
@@ -548,7 +553,8 @@ def run_deepspeed_worker(self):
548553
549554 def run (self ):
550555 if self .is_host :
551- launch_args = [f"--hostfile={ self .HOST_FILE } " ]
556+ if self .node_count > 1 :
557+ launch_args = [f"--hostfile={ self .HOST_FILE } " ]
552558 self .run_deepspeed_host (launch_args )
553559 else :
554560 self .run_deepspeed_worker ()
@@ -601,8 +607,6 @@ def accelerate_args(self):
601607 return args
602608
603609 def run_with_torchrun (self ):
604- self .main_process_ip = self .host_ip
605-
606610 launch_args = self .accelerate_args ()
607611 for arg in self .TORCHRUN_ARGS :
608612 if not self .launch_cmd_contains (arg ):
@@ -613,12 +617,14 @@ def run_with_torchrun(self):
613617 def run_with_deepspeed (self ):
614618 if self .is_host :
615619 launch_args = self .accelerate_args ()
616- launch_args .append (f"--deepspeed_hostfile={ self .HOST_FILE } " )
620+ if self .num_machines > 1 :
621+ launch_args .append (f"--deepspeed_hostfile={ self .HOST_FILE } " )
617622 self .run_deepspeed_host (launch_args )
618623 else :
619624 self .run_deepspeed_worker ()
620625
621626 def run (self ):
627+ self .main_process_ip = self .host_ip
622628 # Check if any default argument is provided by the user
623629 for arg in self .DEFAULT_ARGS :
624630 if self .launch_cmd_contains (arg ):
0 commit comments