Skip to content

Commit 65c5e6f

Browse files
committed
Update test_jobs_pytorch_ddp.py
1 parent 4d49e56 commit 65c5e6f

File tree

1 file changed

+90
-8
lines changed

1 file changed

+90
-8
lines changed

tests/unitary/default_setup/jobs/test_jobs_pytorch_ddp.py

Lines changed: 90 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,20 @@
11
import json
2+
import os
3+
import sys
24
import unittest
35
import zipfile
4-
from ads.jobs import PyTorchDistributedRuntime, DataScienceJob
6+
from unittest import mock
7+
from ads.jobs import PyTorchDistributedRuntime, DataScienceJob, DataScienceJobRun
58
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
69
PyTorchDistributedRuntimeHandler as Handler,
710
)
811
from ads.jobs.builders.runtimes.pytorch_runtime import (
912
PyTorchDistributedArtifact,
1013
GitPythonArtifact,
1114
)
12-
from ads.opctl.distributed.common import cluster_config_helper as Cluster
13-
from ads.jobs.templates import driver_utils as Driver
15+
from ads.opctl.distributed.common import cluster_config_helper as cluster
16+
from ads.jobs.templates import driver_utils as utils
17+
from ads.jobs.templates import driver_pytorch as driver
1418

1519

1620
class PyTorchRuntimeHandlerTest(unittest.TestCase):
@@ -26,6 +30,7 @@ class PyTorchRuntimeHandlerTest(unittest.TestCase):
2630
)
2731

2832
def init_runtime(self):
33+
"""Initializes a PyTorchDistributedRuntime for testing."""
2934
return (
3035
PyTorchDistributedRuntime()
3136
.with_replica(self.REPLICAS)
@@ -43,6 +48,7 @@ def init_runtime(self):
4348
)
4449

4550
def test_translate_artifact(self):
51+
"""Tests preparing ADS driver scripts in job artifacts."""
4652
artifact = Handler(DataScienceJob())._translate_artifact(self.init_runtime())
4753
self.assertIsInstance(artifact, PyTorchDistributedArtifact)
4854
self.assertEqual(
@@ -65,6 +71,7 @@ def test_translate_artifact(self):
6571
self.assertIn(GitPythonArtifact.CONST_DRIVER_SCRIPT, file_list)
6672

6773
def test_translate_env(self):
74+
"""Tests setting up environment variables"""
6875
envs = Handler(DataScienceJob())._translate_env(self.init_runtime())
6976
self.assertIsInstance(envs, dict)
7077
self.assertEqual(envs[Handler.CONST_WORKER_COUNT], str(self.REPLICAS - 1))
@@ -73,12 +80,87 @@ def test_translate_env(self):
7380
PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT,
7481
)
7582
self.assertEqual(envs[Handler.CONST_COMMAND], self.TORCHRUN_CMD)
76-
self.assertEqual(envs[Cluster.OCI__RUNTIME_URI], self.TEST_REPO)
77-
self.assertEqual(envs[Cluster.OCI__RUNTIME_GIT_COMMIT], self.TEST_COMMIT)
78-
self.assertEqual(envs[Driver.CONST_ENV_PIP_PKG], self.PIP_PKG)
79-
self.assertEqual(envs[Driver.CONST_ENV_PIP_REQ], self.PIP_REQ)
83+
self.assertEqual(envs[cluster.OCI__RUNTIME_URI], self.TEST_REPO)
84+
self.assertEqual(envs[cluster.OCI__RUNTIME_GIT_COMMIT], self.TEST_COMMIT)
85+
self.assertEqual(envs[utils.CONST_ENV_PIP_PKG], self.PIP_PKG)
86+
self.assertEqual(envs[utils.CONST_ENV_PIP_REQ], self.PIP_REQ)
8087
self.assertEqual(
81-
envs[Driver.CONST_ENV_INPUT_MAPPINGS],
88+
envs[utils.CONST_ENV_INPUT_MAPPINGS],
8289
json.dumps({self.INPUT_SRC: self.INPUT_DST}),
8390
)
8491
self.assertNotIn(Handler.CONST_DEEPSPEED, envs)
92+
93+
94+
class PyTorchRunnerTest(unittest.TestCase):
    """Tests for ``driver.TorchRunner`` from the PyTorch driver script.

    Each test configures the runner through environment variables (and, in
    one case, ``sys.argv``) patched for the duration of the test.
    """

    TEST_IP = "10.0.0.1"
    TEST_HOST_IP = "10.0.0.100"
    TEST_HOST_OCID = "ocid_host"
    TEST_NODE_OCID = "ocid_node"

    def init_torch_runner(self):
        """Initializes a TorchRunner with its external dependencies mocked.

        ``TorchRunner.build_c_library``, ``socket.gethostbyname`` and
        ``DataScienceJobRun.from_ocid`` are patched only while the runner is
        being constructed; the hostname lookup returns ``TEST_IP``.
        """
        with mock.patch(
            "ads.jobs.templates.driver_pytorch.TorchRunner.build_c_library"
        ), mock.patch("socket.gethostbyname") as get_host_ip, mock.patch(
            "ads.jobs.DataScienceJobRun.from_ocid"
        ) as get_job_run:
            get_host_ip.return_value = self.TEST_IP
            get_job_run.return_value = DataScienceJobRun(id="ocid.abcdefghijk")
            return driver.TorchRunner()

    @mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID})
    def test_init_torch_runner_at_node(self):
        """Tests initializing the runner when the host job run OCID is set."""
        runner = self.init_torch_runner()
        self.assertEqual(runner.host_ocid, self.TEST_HOST_OCID)
        # The host IP is not known at initialization time in this setup.
        self.assertIsNone(runner.host_ip)

    @mock.patch.dict(os.environ, {driver.CONST_ENV_JOB_RUN_OCID: TEST_NODE_OCID})
    def test_init_torch_runner_at_host(self):
        """Tests initializing the runner when only the job run OCID is patched in.

        NOTE(review): assumes the host job run OCID variable is not already
        present in the real environment - confirm for the CI setup.
        """
        runner = self.init_torch_runner()
        self.assertEqual(runner.host_ocid, self.TEST_NODE_OCID)
        # The IP comes from the mocked ``socket.gethostbyname``.
        self.assertEqual(runner.host_ip, self.TEST_IP)

    @mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID})
    def test_wait_for_host_ip(self):
        """Tests obtaining the host IP address from the (mocked) job run logs."""
        with mock.patch("ads.jobs.DataScienceJobRun.logs") as get_logs:
            # A single log record carrying the host IP behind the expected prefix.
            get_logs.return_value = [
                {"message": f"{driver.LOG_PREFIX_HOST_IP} {self.TEST_HOST_IP}"}
            ]
            runner = self.init_torch_runner()
            self.assertIsNone(runner.host_ip)
            runner.wait_for_host_ip_address()
            self.assertEqual(runner.host_ip, self.TEST_HOST_IP)

    @mock.patch.dict(
        os.environ, {driver.CONST_ENV_LAUNCH_CMD: "torchrun train.py --data abc"}
    )
    def test_launch_cmd(self):
        """Tests inspecting and prefixing a launch command given via environment."""
        runner = self.init_torch_runner()
        self.assertTrue(runner.launch_cmd_contains("data"))
        self.assertFalse(runner.launch_cmd_contains("data1"))
        self.assertEqual(
            runner.prepare_cmd(prefix="A=1"), "A=1 torchrun train.py --data abc"
        )

    @mock.patch.dict(os.environ, {Handler.CONST_CODE_ENTRYPOINT: "train.py"})
    @mock.patch.object(sys, "argv", ["python", "hello", "--data", "abc"])
    def test_prepare_cmd_with_entrypoint_args(self):
        """Tests building the command from an entrypoint, launch args and sys.argv."""
        runner = self.init_torch_runner()
        self.assertEqual(
            runner.prepare_cmd(launch_args=["--key", "val"], prefix="A=1"),
            "A=1 torchrun --key val train.py hello --data abc",
        )
152+
153+
154+
class LazyEvaluateTest(unittest.TestCase):
    """Tests for ``driver.LazyEvaluate``."""

    def test_lazy_evaluation(self):
        """Tests the string form of a lazily evaluated call.

        A callable that returns normally is expected to render as the string
        of its result; one raising a message-less ``ValueError`` is expected
        to render as ``"ERROR: "``.
        """

        def add(a, b):
            return a + b

        def always_fails():
            raise ValueError()

        # Successful call: 1 + 1 rendered as text.
        self.assertEqual(str(driver.LazyEvaluate(add, 1, 1)), "2")

        # Failing call: the error is rendered rather than propagated.
        self.assertEqual(str(driver.LazyEvaluate(always_fails)), "ERROR: ")

0 commit comments

Comments
 (0)