|
6 | 6 | import sys |
7 | 7 | import unittest |
8 | 8 | from unittest import mock |
9 | | -from ads.jobs import PyTorchDistributedRuntime, DataScienceJob, DataScienceJobRun |
| 9 | +from ads.jobs import DataScienceJobRun |
10 | 10 | from ads.jobs.builders.infrastructure.dsc_job_runtime import ( |
11 | 11 | PyTorchDistributedRuntimeHandler as Handler, |
12 | 12 | ) |
13 | | -from ads.jobs.builders.runtimes.pytorch_runtime import ( |
14 | | - PyTorchDistributedArtifact, |
15 | | - GitPythonArtifact, |
16 | | -) |
17 | | -from ads.opctl.distributed.common import cluster_config_helper as cluster |
18 | 13 | from ads.jobs.templates import driver_utils as utils |
19 | 14 | from ads.jobs.templates import driver_pytorch as driver |
20 | 15 |
|
@@ -97,6 +92,30 @@ def test_run_torchrun(self, run_command): |
97 | 92 | cmd, |
98 | 93 | ) |
99 | 94 |
|
| 95 | + @mock.patch.dict( |
| 96 | + os.environ, |
| 97 | + { |
| 98 | + utils.CONST_ENV_PIP_PKG: "abc==1.0", |
| 99 | + utils.CONST_ENV_PIP_REQ: "abc/requirements.txt", |
| 100 | + }, |
| 101 | + ) |
| 102 | + @mock.patch("ads.jobs.templates.driver_utils.JobRunner.run_command") |
| 103 | + def test_install_deps(self, run_command): |
| 104 | + runner = self.init_torch_runner() |
| 105 | + runner.install_dependencies() |
| 106 | + cmd_list = [call_args.args[0] for call_args in run_command.call_args_list] |
| 107 | + self.assertEqual( |
| 108 | + cmd_list, |
| 109 | + [ |
| 110 | + "pip install -r abc/requirements.txt", |
| 111 | + "pip install abc==1.0", |
| 112 | + ], |
| 113 | + ) |
| 114 | + |
| 115 | + def test_run_command(self): |
| 116 | + runner = self.init_torch_runner() |
| 117 | + self.assertEqual(runner.run_command("pwd", runner.conda_prefix, check=True), 0) |
| 118 | + |
100 | 119 |
|
101 | 120 | class LazyEvaluateTest(unittest.TestCase): |
102 | 121 | def test_lazy_evaluation(self): |
|
0 commit comments