|
1 | 1 | #!/usr/bin/env python |
2 | 2 |
|
3 | | -# Copyright (c) 2023 Oracle and/or its affiliates. |
| 3 | +# Copyright (c) 2025 Oracle and/or its affiliates. |
4 | 4 | # Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
5 | 5 |
|
6 | 6 | import json |
7 | 7 | import os |
8 | 8 | import unittest |
9 | 9 | import zipfile |
10 | 10 | from unittest import mock |
11 | | -from ads.jobs import PyTorchDistributedRuntime, DataScienceJob, DataScienceJobRun |
| 11 | + |
| 12 | +from ads.jobs import DataScienceJob, DataScienceJobRun, PyTorchDistributedRuntime |
12 | 13 | from ads.jobs.builders.infrastructure.dsc_job_runtime import ( |
13 | 14 | PyTorchDistributedRuntimeHandler as Handler, |
14 | 15 | ) |
15 | 16 | from ads.jobs.builders.runtimes.pytorch_runtime import ( |
16 | | - PyTorchDistributedArtifact, |
17 | 17 | GitPythonArtifact, |
| 18 | + PyTorchDistributedArtifact, |
18 | 19 | ) |
19 | | -from ads.opctl.distributed.common import cluster_config_helper as cluster |
20 | 20 | from ads.jobs.templates import driver_utils as utils |
| 21 | +from ads.opctl.distributed.common import cluster_config_helper as cluster |
21 | 22 |
|
22 | 23 |
|
23 | 24 | class PyTorchRuntimeHandlerTest(unittest.TestCase): |
@@ -77,7 +78,7 @@ def test_translate_env(self): |
77 | 78 | """Tests setting up environment variables""" |
78 | 79 | envs = Handler(DataScienceJob())._translate_env(self.init_runtime()) |
79 | 80 | self.assertIsInstance(envs, dict) |
80 | | - self.assertEqual(envs[Handler.CONST_WORKER_COUNT], str(self.REPLICAS - 1)) |
| 81 | + self.assertEqual(envs[Handler.CONST_NODE_COUNT], str(self.REPLICAS)) |
81 | 82 | self.assertEqual( |
82 | 83 | envs[Handler.CONST_JOB_ENTRYPOINT], |
83 | 84 | PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT, |
|
0 commit comments