Skip to content

Commit ba5659f

Browse files
committed
Update test.
1 parent 70e3ce1 commit ba5659f

File tree

1 file changed

+6
-5
lines changed

1 file changed

+6
-5
lines changed

tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,24 @@
11
#!/usr/bin/env python
22

3-
# Copyright (c) 2023 Oracle and/or its affiliates.
3+
# Copyright (c) 2025 Oracle and/or its affiliates.
44
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
55

66
import json
77
import os
88
import unittest
99
import zipfile
1010
from unittest import mock
11-
from ads.jobs import PyTorchDistributedRuntime, DataScienceJob, DataScienceJobRun
11+
12+
from ads.jobs import DataScienceJob, DataScienceJobRun, PyTorchDistributedRuntime
1213
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
1314
PyTorchDistributedRuntimeHandler as Handler,
1415
)
1516
from ads.jobs.builders.runtimes.pytorch_runtime import (
16-
PyTorchDistributedArtifact,
1717
GitPythonArtifact,
18+
PyTorchDistributedArtifact,
1819
)
19-
from ads.opctl.distributed.common import cluster_config_helper as cluster
2020
from ads.jobs.templates import driver_utils as utils
21+
from ads.opctl.distributed.common import cluster_config_helper as cluster
2122

2223

2324
class PyTorchRuntimeHandlerTest(unittest.TestCase):
@@ -77,7 +78,7 @@ def test_translate_env(self):
7778
"""Tests setting up environment variables"""
7879
envs = Handler(DataScienceJob())._translate_env(self.init_runtime())
7980
self.assertIsInstance(envs, dict)
80-
self.assertEqual(envs[Handler.CONST_WORKER_COUNT], str(self.REPLICAS - 1))
81+
self.assertEqual(envs[Handler.CONST_NODE_COUNT], str(self.REPLICAS))
8182
self.assertEqual(
8283
envs[Handler.CONST_JOB_ENTRYPOINT],
8384
PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT,

0 commit comments

Comments
 (0)