Skip to content

Commit 65bc43d

Browse files
committed
Add tests.
1 parent 65c5e6f commit 65bc43d

File tree

3 files changed

+154
-72
lines changed

3 files changed

+154
-72
lines changed

tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py

Lines changed: 56 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,9 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
16
import json
2-
import os
3-
import sys
47
import unittest
58
import zipfile
69
from unittest import mock
@@ -14,7 +17,6 @@
1417
)
1518
from ads.opctl.distributed.common import cluster_config_helper as cluster
1619
from ads.jobs.templates import driver_utils as utils
17-
from ads.jobs.templates import driver_pytorch as driver
1820

1921

2022
class PyTorchRuntimeHandlerTest(unittest.TestCase):
@@ -89,78 +91,60 @@ def test_translate_env(self):
8991
json.dumps({self.INPUT_SRC: self.INPUT_DST}),
9092
)
9193
self.assertNotIn(Handler.CONST_DEEPSPEED, envs)
94+
# Test deepspeed env var
95+
envs = Handler(DataScienceJob())._translate_env(
96+
self.init_runtime().with_command("train.py", use_deepspeed=True)
97+
)
98+
self.assertIn(Handler.CONST_DEEPSPEED, envs)
9299

93-
94-
class PyTorchRunnerTest(unittest.TestCase):
95-
TEST_IP = "10.0.0.1"
96-
TEST_HOST_IP = "10.0.0.100"
97-
TEST_HOST_OCID = "ocid_host"
98-
TEST_NODE_OCID = "ocid_node"
99-
100-
def init_torch_runner(self):
101-
with mock.patch(
102-
"ads.jobs.templates.driver_pytorch.TorchRunner.build_c_library"
103-
), mock.patch("socket.gethostbyname") as GetHostIP, mock.patch(
104-
"ads.jobs.DataScienceJobRun.from_ocid"
105-
) as GetJobRun:
106-
GetHostIP.return_value = self.TEST_IP
107-
GetJobRun.return_value = DataScienceJobRun(id="ocid.abcdefghijk")
108-
return driver.TorchRunner()
109-
110-
@mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID})
111-
def test_init_torch_runner_at_node(self):
112-
runner = self.init_torch_runner()
113-
self.assertEqual(runner.host_ocid, self.TEST_HOST_OCID)
114-
self.assertEqual(runner.host_ip, None)
115-
116-
@mock.patch.dict(os.environ, {driver.CONST_ENV_JOB_RUN_OCID: TEST_NODE_OCID})
117-
def test_init_torch_runner_at_host(self):
118-
runner = self.init_torch_runner()
119-
self.assertEqual(runner.host_ocid, self.TEST_NODE_OCID)
120-
self.assertEqual(runner.host_ip, self.TEST_IP)
121-
122-
@mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID})
123-
def test_wait_for_host_ip(self):
124-
with mock.patch("ads.jobs.DataScienceJobRun.logs") as get_logs:
125-
get_logs.return_value = [
126-
{"message": f"{driver.LOG_PREFIX_HOST_IP} {self.TEST_HOST_IP}"}
127-
]
128-
runner = self.init_torch_runner()
129-
self.assertEqual(runner.host_ip, None)
130-
runner.wait_for_host_ip_address()
131-
self.assertEqual(runner.host_ip, self.TEST_HOST_IP)
132-
133-
@mock.patch.dict(
134-
os.environ, {driver.CONST_ENV_LAUNCH_CMD: "torchrun train.py --data abc"}
135-
)
136-
def test_launch_cmd(self):
137-
runner = self.init_torch_runner()
138-
self.assertTrue(runner.launch_cmd_contains("data"))
139-
self.assertFalse(runner.launch_cmd_contains("data1"))
100+
@mock.patch("ads.jobs.builders.infrastructure.dsc_job.DSCJob.create")
def test_extract_env(self, *args):
    """Check the spec that the handler extracts back from a created job.

    ``DSCJob.create`` is patched, so the job is only built locally from
    ``self.init_runtime()`` before ``Handler._extract_envs`` is applied
    to its ``dsc_job``.
    """
    expected_spec = {
        "conda": {"type": "service", "slug": "pytorch110_p38_gpu_v1"},
        "command": "torchrun distributed/minGPT-ddp/mingpt/main.py data_config.path=data/input.txt",
        "replicas": 2,
        "git": {
            "url": "https://github.com/pytorch/examples.git",
            "commit": "d91085d2181bf6342ac7dafbeee6fc0a1f64dcec",
        },
        "inputs": {"oci://bucket@namespace/path/to/input": "data/input.txt"},
        "dependencies": {
            "pipPackages": '"package>1.0"',
            "pipRequirements": "distributed/minGPT-ddp/requirements.txt",
        },
    }

    created_job = DataScienceJob().create(self.init_runtime())
    handler = Handler(created_job)
    extracted_spec = handler._extract_envs(created_job.dsc_job)

    self.assertEqual(extracted_spec, expected_spec)
143122

144-
@mock.patch.dict(os.environ, {Handler.CONST_CODE_ENTRYPOINT: "train.py"})
145-
@mock.patch.object(sys, "argv", ["python", "hello", "--data", "abc"])
146-
def test_prepare_cmd_with_entrypoint_args(self):
147-
runner = self.init_torch_runner()
123+
@mock.patch("ads.jobs.builders.infrastructure.dsc_job.DSCJob.create")
@mock.patch("ads.jobs.builders.infrastructure.dsc_job.DSCJob.run")
def test_create_job_runs(self, patched_run, *args):
    """Check the keyword arguments passed to ``DSCJob.run`` by ``runtime.run``.

    ``DSCJob.run`` is patched to return a job run with a known OCID. The
    test expects two calls: the first with RANK "0", the second with
    RANK "1" plus ``MAIN_JOB_RUN_OCID`` set to that known OCID.
    """
    main_ocid = "ocid-test"
    patched_run.return_value = DataScienceJobRun(id=main_ocid)

    created_job = DataScienceJob().create(self.init_runtime())
    # A fresh runtime instance is used to launch the runs, separate from
    # the one the job was created with.
    launcher = self.init_runtime()
    first_run = launcher.run(created_job.dsc_job)

    self.assertIsInstance(first_run, DataScienceJobRun)
    self.assertEqual(first_run.id, main_ocid)

    expected_kwargs = [
        {
            "display_name": "None-0",
            "environment_variables": {"RANK": "0", "WORLD_SIZE": 2},
        },
        {
            "display_name": "None-1",
            "environment_variables": {
                "RANK": "1",
                "WORLD_SIZE": 2,
                "MAIN_JOB_RUN_OCID": main_ocid,
            },
        },
    ]
    observed_kwargs = []
    for recorded_call in patched_run.call_args_list:
        observed_kwargs.append(recorded_call.kwargs)

    self.assertEqual(observed_kwargs, expected_kwargs)
152-
153-
154-
class LazyEvaluateTest(unittest.TestCase):
155-
def test_lazy_evaluation(self):
156-
def func(a, b):
157-
return a + b
158-
159-
def func_with_error():
160-
raise ValueError()
161-
162-
lazy_val = driver.LazyEvaluate(func, 1, 1)
163-
self.assertEqual(str(lazy_val), "2")
164-
165-
lazy_val = driver.LazyEvaluate(func_with_error)
166-
self.assertEqual(str(lazy_val), "ERROR: ")
Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,4 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
Lines changed: 94 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,94 @@
1+
#!/usr/bin/env python
2+
3+
# Copyright (c) 2023 Oracle and/or its affiliates.
4+
# Licensed under the Universal Permissive License v 1.0 as shown at https://oss.oracle.com/licenses/upl/
5+
import os
6+
import sys
7+
import unittest
8+
from unittest import mock
9+
from ads.jobs import PyTorchDistributedRuntime, DataScienceJob, DataScienceJobRun
10+
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
11+
PyTorchDistributedRuntimeHandler as Handler,
12+
)
13+
from ads.jobs.builders.runtimes.pytorch_runtime import (
14+
PyTorchDistributedArtifact,
15+
GitPythonArtifact,
16+
)
17+
from ads.opctl.distributed.common import cluster_config_helper as cluster
18+
from ads.jobs.templates import driver_utils as utils
19+
from ads.jobs.templates import driver_pytorch as driver
20+
21+
22+
class PyTorchRunnerTest(unittest.TestCase):
    """Unit tests for ``driver.TorchRunner`` in the PyTorch job driver template."""

    # IP returned by the patched ``socket.gethostbyname``.
    TEST_IP = "10.0.0.1"
    # IP embedded in the mocked job run log message.
    TEST_HOST_IP = "10.0.0.100"
    TEST_HOST_OCID = "ocid_host"
    TEST_NODE_OCID = "ocid_node"

    def init_torch_runner(self):
        """Return a ``TorchRunner`` created with its external lookups mocked.

        ``TorchRunner.build_c_library`` is patched out,
        ``socket.gethostbyname`` returns ``TEST_IP`` and
        ``DataScienceJobRun.from_ocid`` returns a stub job run. The patches
        are only active while the runner is being constructed.
        """
        with mock.patch(
            "ads.jobs.templates.driver_pytorch.TorchRunner.build_c_library"
        ), mock.patch("socket.gethostbyname") as get_host_ip, mock.patch(
            "ads.jobs.DataScienceJobRun.from_ocid"
        ) as get_job_run:
            get_host_ip.return_value = self.TEST_IP
            get_job_run.return_value = DataScienceJobRun(id="ocid.abcdefghijk")
            return driver.TorchRunner()

    @mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID})
    def test_init_torch_runner_at_node(self):
        """With the host job run OCID env var set, the host IP starts unset."""
        runner = self.init_torch_runner()
        self.assertEqual(runner.host_ocid, self.TEST_HOST_OCID)
        self.assertIsNone(runner.host_ip)

    @mock.patch.dict(os.environ, {driver.CONST_ENV_JOB_RUN_OCID: TEST_NODE_OCID})
    def test_init_torch_runner_at_host(self):
        """With only the job run OCID env var patched in, the runner uses its own OCID and IP."""
        runner = self.init_torch_runner()
        self.assertEqual(runner.host_ocid, self.TEST_NODE_OCID)
        self.assertEqual(runner.host_ip, self.TEST_IP)

    @mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID})
    def test_wait_for_host_ip(self):
        """The host IP is picked up from a job run log message with the host-IP prefix."""
        with mock.patch("ads.jobs.DataScienceJobRun.logs") as get_logs:
            get_logs.return_value = [
                {"message": f"{driver.LOG_PREFIX_HOST_IP} {self.TEST_HOST_IP}"}
            ]
            runner = self.init_torch_runner()
            self.assertIsNone(runner.host_ip)
            runner.wait_for_host_ip_address()
            self.assertEqual(runner.host_ip, self.TEST_HOST_IP)

    @mock.patch.dict(
        os.environ, {driver.CONST_ENV_LAUNCH_CMD: "torchrun train.py --data abc"}
    )
    def test_launch_cmd(self):
        """Check ``launch_cmd_contains`` and ``prepare_cmd`` against a launch command env var."""
        runner = self.init_torch_runner()
        self.assertTrue(runner.launch_cmd_contains("data"))
        # "data1" is not an argument of the launch command above.
        self.assertFalse(runner.launch_cmd_contains("data1"))
        self.assertEqual(
            runner.prepare_cmd(prefix="A=1"), "A=1 torchrun train.py --data abc"
        )

    @mock.patch.dict(os.environ, {Handler.CONST_CODE_ENTRYPOINT: "train.py"})
    @mock.patch.object(sys, "argv", ["python", "hello", "--data", "abc"])
    def test_prepare_cmd_with_entrypoint_args(self):
        """Check the command built from an entrypoint env var, launch args and ``sys.argv``."""
        runner = self.init_torch_runner()
        self.assertEqual(
            runner.prepare_cmd(launch_args=["--key", "val"], prefix="A=1"),
            "A=1 torchrun --key val train.py hello --data abc",
        )
80+
81+
82+
class LazyEvaluateTest(unittest.TestCase):
    """Unit tests for ``driver.LazyEvaluate``."""

    def test_lazy_evaluation(self):
        """Check the string form of a lazy value for a successful and a failing callable."""

        def add(a, b):
            return a + b

        def always_fails():
            raise ValueError()

        successful = driver.LazyEvaluate(add, 1, 1)
        rendered_success = str(successful)
        self.assertEqual(rendered_success, "2")

        # The ValueError carries no message, so nothing follows the prefix.
        failing = driver.LazyEvaluate(always_fails)
        rendered_failure = str(failing)
        self.assertEqual(rendered_failure, "ERROR: ")

0 commit comments

Comments
 (0)