Skip to content

Commit a419b07

Browse files
committed
Cancel all job runs in distributed training if there is an error creating one of them.
1 parent ea73eca commit a419b07

File tree

1 file changed

+32
-20
lines changed

1 file changed

+32
-20
lines changed

ads/jobs/builders/runtimes/pytorch_runtime.py

Lines changed: 32 additions & 20 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,6 @@
11
import json
2+
import time
3+
import traceback
24
from ads.jobs.builders.runtimes.artifact import PythonArtifact, GitPythonArtifact
35
from ads.jobs.builders.runtimes.python_runtime import (
46
PythonRuntime,
@@ -192,28 +194,38 @@ def use_deepspeed(self):
192194
return False
193195

194196
def run(self, dsc_job, **kwargs):
195-
"""Starts the job runs
196-
"""
197+
"""Starts the job runs"""
197198
replicas = self.replica if self.replica else 1
198199
main_run = None
199-
for i in range(replicas):
200-
replica_kwargs = kwargs.copy()
201-
envs = replica_kwargs.get("environment_variables")
202-
if not envs:
203-
envs = {}
204-
# Huggingface accelerate requires machine rank
205-
envs["OCI__NODE_RANK"] = str(i)
206-
if main_run:
207-
envs["MAIN_JOB_RUN_OCID"] = main_run.id
208-
name = replica_kwargs.get("display_name")
209-
if not name:
210-
name = dsc_job.display_name
211-
212-
replica_kwargs["display_name"] = f"{name}-{str(i)}"
213-
replica_kwargs["environment_variables"] = envs
214-
run = dsc_job.run(**replica_kwargs)
215-
if i == 0:
216-
main_run = run
200+
job_runs = []
201+
try:
202+
for i in range(replicas):
203+
replica_kwargs = kwargs.copy()
204+
envs = replica_kwargs.get("environment_variables")
205+
if not envs:
206+
envs = {}
207+
# Huggingface accelerate requires machine rank
208+
envs["OCI__NODE_RANK"] = str(i)
209+
if main_run:
210+
envs["MAIN_JOB_RUN_OCID"] = main_run.id
211+
name = replica_kwargs.get("display_name")
212+
if not name:
213+
name = dsc_job.display_name
214+
215+
replica_kwargs["display_name"] = f"{name}-{str(i)}"
216+
replica_kwargs["environment_variables"] = envs
217+
run = dsc_job.run(**replica_kwargs)
218+
job_runs.append(run)
219+
if i == 0:
220+
main_run = run
221+
except Exception:
222+
traceback.print_exc()
223+
# Wait a few seconds to avoid the job run being in a transient state.
224+
time.sleep(2)
225+
# If there is any error when creating the job runs,
226+
# cancel all the job runs.
227+
for run in job_runs:
228+
run.cancel()
217229
return main_run
218230

219231

0 commit comments

Comments
 (0)