|
1 | 1 | import json |
| 2 | +import time |
| 3 | +import traceback |
2 | 4 | from ads.jobs.builders.runtimes.artifact import PythonArtifact, GitPythonArtifact |
3 | 5 | from ads.jobs.builders.runtimes.python_runtime import ( |
4 | 6 | PythonRuntime, |
@@ -192,28 +194,38 @@ def use_deepspeed(self): |
192 | 194 | return False |
193 | 195 |
|
194 | 196 | def run(self, dsc_job, **kwargs): |
195 | | - """Starts the job runs |
196 | | - """ |
| 197 | + """Starts the job runs""" |
197 | 198 | replicas = self.replica if self.replica else 1 |
198 | 199 | main_run = None |
199 | | - for i in range(replicas): |
200 | | - replica_kwargs = kwargs.copy() |
201 | | - envs = replica_kwargs.get("environment_variables") |
202 | | - if not envs: |
203 | | - envs = {} |
204 | | - # Huggingface accelerate requires machine rank |
205 | | - envs["OCI__NODE_RANK"] = str(i) |
206 | | - if main_run: |
207 | | - envs["MAIN_JOB_RUN_OCID"] = main_run.id |
208 | | - name = replica_kwargs.get("display_name") |
209 | | - if not name: |
210 | | - name = dsc_job.display_name |
211 | | - |
212 | | - replica_kwargs["display_name"] = f"{name}-{str(i)}" |
213 | | - replica_kwargs["environment_variables"] = envs |
214 | | - run = dsc_job.run(**replica_kwargs) |
215 | | - if i == 0: |
216 | | - main_run = run |
| 200 | + job_runs = [] |
| 201 | + try: |
| 202 | + for i in range(replicas): |
| 203 | + replica_kwargs = kwargs.copy() |
| 204 | + envs = replica_kwargs.get("environment_variables") |
| 205 | + if not envs: |
| 206 | + envs = {} |
| 207 | + # Huggingface accelerate requires machine rank |
| 208 | + envs["OCI__NODE_RANK"] = str(i) |
| 209 | + if main_run: |
| 210 | + envs["MAIN_JOB_RUN_OCID"] = main_run.id |
| 211 | + name = replica_kwargs.get("display_name") |
| 212 | + if not name: |
| 213 | + name = dsc_job.display_name |
| 214 | + |
| 215 | + replica_kwargs["display_name"] = f"{name}-{str(i)}" |
| 216 | + replica_kwargs["environment_variables"] = envs |
| 217 | + run = dsc_job.run(**replica_kwargs) |
| 218 | + job_runs.append(run) |
| 219 | + if i == 0: |
| 220 | + main_run = run |
| 221 | + except Exception: |
| 222 | + traceback.print_exc() |
| 223 | + # Wait a few seconds to avoid the job run being in a transient state. |
| 224 | + time.sleep(2) |
| 225 | + # If there is any error when creating the job runs |
| 226 | + # cancel all the job runs. |
| 227 | + for run in job_runs: |
| 228 | + run.cancel() |
217 | 229 | return main_run |
218 | 230 |
|
219 | 231 |
|
|
0 commit comments