|
| 1 | +#!/usr/bin/env python |
| 2 | + |
| 3 | +# Copyright (c) 2023 Oracle and/or its affiliates. |
| 4 | +# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/ |
| 5 | + |
1 | 6 | import json |
2 | | -import os |
3 | | -import sys |
4 | 7 | import unittest |
5 | 8 | import zipfile |
6 | 9 | from unittest import mock |
|
14 | 17 | ) |
15 | 18 | from ads.opctl.distributed.common import cluster_config_helper as cluster |
16 | 19 | from ads.jobs.templates import driver_utils as utils |
17 | | -from ads.jobs.templates import driver_pytorch as driver |
18 | 20 |
|
19 | 21 |
|
20 | 22 | class PyTorchRuntimeHandlerTest(unittest.TestCase): |
@@ -89,78 +91,60 @@ def test_translate_env(self): |
89 | 91 | json.dumps({self.INPUT_SRC: self.INPUT_DST}), |
90 | 92 | ) |
91 | 93 | self.assertNotIn(Handler.CONST_DEEPSPEED, envs) |
| 94 | + # Test deepspeed env var |
| 95 | + envs = Handler(DataScienceJob())._translate_env( |
| 96 | + self.init_runtime().with_command("train.py", use_deepspeed=True) |
| 97 | + ) |
| 98 | + self.assertIn(Handler.CONST_DEEPSPEED, envs) |
92 | 99 |
|
93 | | - |
94 | | -class PyTorchRunnerTest(unittest.TestCase): |
95 | | - TEST_IP = "10.0.0.1" |
96 | | - TEST_HOST_IP = "10.0.0.100" |
97 | | - TEST_HOST_OCID = "ocid_host" |
98 | | - TEST_NODE_OCID = "ocid_node" |
99 | | - |
100 | | - def init_torch_runner(self): |
101 | | - with mock.patch( |
102 | | - "ads.jobs.templates.driver_pytorch.TorchRunner.build_c_library" |
103 | | - ), mock.patch("socket.gethostbyname") as GetHostIP, mock.patch( |
104 | | - "ads.jobs.DataScienceJobRun.from_ocid" |
105 | | - ) as GetJobRun: |
106 | | - GetHostIP.return_value = self.TEST_IP |
107 | | - GetJobRun.return_value = DataScienceJobRun(id="ocid.abcdefghijk") |
108 | | - return driver.TorchRunner() |
109 | | - |
110 | | - @mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID}) |
111 | | - def test_init_torch_runner_at_node(self): |
112 | | - runner = self.init_torch_runner() |
113 | | - self.assertEqual(runner.host_ocid, self.TEST_HOST_OCID) |
114 | | - self.assertEqual(runner.host_ip, None) |
115 | | - |
116 | | - @mock.patch.dict(os.environ, {driver.CONST_ENV_JOB_RUN_OCID: TEST_NODE_OCID}) |
117 | | - def test_init_torch_runner_at_host(self): |
118 | | - runner = self.init_torch_runner() |
119 | | - self.assertEqual(runner.host_ocid, self.TEST_NODE_OCID) |
120 | | - self.assertEqual(runner.host_ip, self.TEST_IP) |
121 | | - |
122 | | - @mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID}) |
123 | | - def test_wait_for_host_ip(self): |
124 | | - with mock.patch("ads.jobs.DataScienceJobRun.logs") as get_logs: |
125 | | - get_logs.return_value = [ |
126 | | - {"message": f"{driver.LOG_PREFIX_HOST_IP} {self.TEST_HOST_IP}"} |
127 | | - ] |
128 | | - runner = self.init_torch_runner() |
129 | | - self.assertEqual(runner.host_ip, None) |
130 | | - runner.wait_for_host_ip_address() |
131 | | - self.assertEqual(runner.host_ip, self.TEST_HOST_IP) |
132 | | - |
133 | | - @mock.patch.dict( |
134 | | - os.environ, {driver.CONST_ENV_LAUNCH_CMD: "torchrun train.py --data abc"} |
135 | | - ) |
136 | | - def test_launch_cmd(self): |
137 | | - runner = self.init_torch_runner() |
138 | | - self.assertTrue(runner.launch_cmd_contains("data")) |
139 | | - self.assertFalse(runner.launch_cmd_contains("data1")) |
| 100 | + @mock.patch("ads.jobs.builders.infrastructure.dsc_job.DSCJob.create") |
| 101 | + def test_extract_env(self, *args): |
| 102 | + """Tests extracting YAML specs from environment variables.""" |
| 103 | + job = DataScienceJob().create(self.init_runtime()) |
| 104 | + spec = Handler(job)._extract_envs(job.dsc_job) |
140 | 105 | self.assertEqual( |
141 | | - runner.prepare_cmd(prefix="A=1"), "A=1 torchrun train.py --data abc" |
| 106 | + spec, |
| 107 | + { |
| 108 | + "conda": {"type": "service", "slug": "pytorch110_p38_gpu_v1"}, |
| 109 | + "command": "torchrun distributed/minGPT-ddp/mingpt/main.py data_config.path=data/input.txt", |
| 110 | + "replicas": 2, |
| 111 | + "git": { |
| 112 | + "url": "https://github.com/pytorch/examples.git", |
| 113 | + "commit": "d91085d2181bf6342ac7dafbeee6fc0a1f64dcec", |
| 114 | + }, |
| 115 | + "inputs": {"oci://bucket@namespace/path/to/input": "data/input.txt"}, |
| 116 | + "dependencies": { |
| 117 | + "pipPackages": '"package>1.0"', |
| 118 | + "pipRequirements": "distributed/minGPT-ddp/requirements.txt", |
| 119 | + }, |
| 120 | + }, |
142 | 121 | ) |
143 | 122 |
|
144 | | - @mock.patch.dict(os.environ, {Handler.CONST_CODE_ENTRYPOINT: "train.py"}) |
145 | | - @mock.patch.object(sys, "argv", ["python", "hello", "--data", "abc"]) |
146 | | - def test_prepare_cmd_with_entrypoint_args(self): |
147 | | - runner = self.init_torch_runner() |
| 123 | + @mock.patch("ads.jobs.builders.infrastructure.dsc_job.DSCJob.create") |
| 124 | + @mock.patch("ads.jobs.builders.infrastructure.dsc_job.DSCJob.run") |
| 125 | + def test_create_job_runs(self, patched_run, *args): |
| 126 | + test_ocid = "ocid-test" |
| 127 | + patched_run.return_value = DataScienceJobRun(id=test_ocid) |
| 128 | + job = DataScienceJob().create(self.init_runtime()) |
| 129 | + runtime = self.init_runtime() |
| 130 | + main_run = runtime.run(job.dsc_job) |
| 131 | + self.assertIsInstance(main_run, DataScienceJobRun) |
| 132 | + self.assertEqual(main_run.id, test_ocid) |
| 133 | + kwarg_list = [call_args.kwargs for call_args in patched_run.call_args_list] |
148 | 134 | self.assertEqual( |
149 | | - runner.prepare_cmd(launch_args=["--key", "val"], prefix="A=1"), |
150 | | - "A=1 torchrun --key val train.py hello --data abc", |
| 135 | + kwarg_list, |
| 136 | + [ |
| 137 | + { |
| 138 | + "display_name": "None-0", |
| 139 | + "environment_variables": {"RANK": "0", "WORLD_SIZE": 2}, |
| 140 | + }, |
| 141 | + { |
| 142 | + "display_name": "None-1", |
| 143 | + "environment_variables": { |
| 144 | + "RANK": "1", |
| 145 | + "WORLD_SIZE": 2, |
| 146 | + "MAIN_JOB_RUN_OCID": test_ocid, |
| 147 | + }, |
| 148 | + }, |
| 149 | + ], |
151 | 150 | ) |
152 | | - |
153 | | - |
154 | | -class LazyEvaluateTest(unittest.TestCase): |
155 | | - def test_lazy_evaluation(self): |
156 | | - def func(a, b): |
157 | | - return a + b |
158 | | - |
159 | | - def func_with_error(): |
160 | | - raise ValueError() |
161 | | - |
162 | | - lazy_val = driver.LazyEvaluate(func, 1, 1) |
163 | | - self.assertEqual(str(lazy_val), "2") |
164 | | - |
165 | | - lazy_val = driver.LazyEvaluate(func_with_error) |
166 | | - self.assertEqual(str(lazy_val), "ERROR: ") |
0 commit comments