Skip to content

Commit 5dd2d9f

Browse files
committed
Update test_pytorch_ddp.py
1 parent ad845c7 commit 5dd2d9f

File tree

1 file changed

+25
-6
lines changed

1 file changed

+25
-6
lines changed

tests/unitary/with_extras/jobs/test_pytorch_ddp.py

Lines changed: 25 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -6,15 +6,10 @@
66
import sys
77
import unittest
88
from unittest import mock
9-
from ads.jobs import PyTorchDistributedRuntime, DataScienceJob, DataScienceJobRun
9+
from ads.jobs import DataScienceJobRun
1010
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
1111
PyTorchDistributedRuntimeHandler as Handler,
1212
)
13-
from ads.jobs.builders.runtimes.pytorch_runtime import (
14-
PyTorchDistributedArtifact,
15-
GitPythonArtifact,
16-
)
17-
from ads.opctl.distributed.common import cluster_config_helper as cluster
1813
from ads.jobs.templates import driver_utils as utils
1914
from ads.jobs.templates import driver_pytorch as driver
2015

@@ -97,6 +92,30 @@ def test_run_torchrun(self, run_command):
9792
cmd,
9893
)
9994

95+
@mock.patch.dict(
96+
os.environ,
97+
{
98+
utils.CONST_ENV_PIP_PKG: "abc==1.0",
99+
utils.CONST_ENV_PIP_REQ: "abc/requirements.txt",
100+
},
101+
)
102+
@mock.patch("ads.jobs.templates.driver_utils.JobRunner.run_command")
103+
def test_install_deps(self, run_command):
104+
runner = self.init_torch_runner()
105+
runner.install_dependencies()
106+
cmd_list = [call_args.args[0] for call_args in run_command.call_args_list]
107+
self.assertEqual(
108+
cmd_list,
109+
[
110+
"pip install -r abc/requirements.txt",
111+
"pip install abc==1.0",
112+
],
113+
)
114+
115+
def test_run_command(self):
116+
runner = self.init_torch_runner()
117+
self.assertEqual(runner.run_command("pwd", runner.conda_prefix, check=True), 0)
118+
100119

101120
class LazyEvaluateTest(unittest.TestCase):
102121
def test_lazy_evaluation(self):

0 commit comments

Comments
 (0)