Skip to content

Commit 1fac993

Browse files
committed
Add tests.
1 parent a419b07 commit 1fac993

File tree

1 file changed

+84
-0
lines changed

1 file changed

+84
-0
lines changed
Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
import json
2+
import unittest
3+
import zipfile
4+
from ads.jobs import PyTorchDistributedRuntime, DataScienceJob
5+
from ads.jobs.builders.infrastructure.dsc_job_runtime import (
6+
PyTorchDistributedRuntimeHandler as Handler,
7+
)
8+
from ads.jobs.builders.runtimes.pytorch_runtime import (
9+
PyTorchDistributedArtifact,
10+
GitPythonArtifact,
11+
)
12+
from ads.opctl.distributed.common import cluster_config_helper as Cluster
13+
from ads.jobs.templates import driver_utils as Driver
14+
15+
16+
class PyTorchRuntimeHandlerTest(unittest.TestCase):
17+
INPUT_SRC = "oci://bucket@namespace/path/to/input"
18+
INPUT_DST = "data/input.txt"
19+
TEST_REPO = "https://github.com/pytorch/examples.git"
20+
TEST_COMMIT = "d91085d2181bf6342ac7dafbeee6fc0a1f64dcec"
21+
REPLICAS = 2
22+
PIP_REQ = "distributed/minGPT-ddp/requirements.txt"
23+
PIP_PKG = '"package>1.0"'
24+
TORCHRUN_CMD = (
25+
"torchrun distributed/minGPT-ddp/mingpt/main.py data_config.path=data/input.txt"
26+
)
27+
28+
def init_runtime(self):
29+
return (
30+
PyTorchDistributedRuntime()
31+
.with_replica(self.REPLICAS)
32+
.with_service_conda("pytorch110_p38_gpu_v1")
33+
.with_git(
34+
self.TEST_REPO,
35+
commit=self.TEST_COMMIT,
36+
)
37+
.with_inputs({self.INPUT_SRC: self.INPUT_DST})
38+
.with_dependency(
39+
pip_req=self.PIP_REQ,
40+
pip_pkg=self.PIP_PKG,
41+
)
42+
.with_command(self.TORCHRUN_CMD)
43+
)
44+
45+
def test_translate_artifact(self):
46+
artifact = Handler(DataScienceJob())._translate_artifact(self.init_runtime())
47+
self.assertIsInstance(artifact, PyTorchDistributedArtifact)
48+
self.assertEqual(
49+
artifact.source,
50+
"",
51+
"Artifact source should be empty when using source code from Git.",
52+
)
53+
with artifact:
54+
self.assertTrue(
55+
artifact.path.endswith(
56+
PyTorchDistributedArtifact.DEFAULT_BASENAME + ".zip"
57+
)
58+
)
59+
file_list = zipfile.ZipFile(artifact.path).namelist()
60+
self.assertEqual(len(file_list), 5, f"Expected 5 files. Got: {file_list}")
61+
self.assertIn(PyTorchDistributedArtifact.CONST_DRIVER_UTILS, file_list)
62+
self.assertIn(PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT, file_list)
63+
self.assertIn(PyTorchDistributedArtifact.CONST_LIB_HOSTNAME, file_list)
64+
self.assertIn(PyTorchDistributedArtifact.CONST_OCI_METRICS, file_list)
65+
self.assertIn(GitPythonArtifact.CONST_DRIVER_SCRIPT, file_list)
66+
67+
def test_translate_env(self):
68+
envs = Handler(DataScienceJob())._translate_env(self.init_runtime())
69+
self.assertIsInstance(envs, dict)
70+
self.assertEqual(envs[Handler.CONST_WORKER_COUNT], str(self.REPLICAS - 1))
71+
self.assertEqual(
72+
envs[Handler.CONST_JOB_ENTRYPOINT],
73+
PyTorchDistributedArtifact.CONST_DRIVER_SCRIPT,
74+
)
75+
self.assertEqual(envs[Handler.CONST_COMMAND], self.TORCHRUN_CMD)
76+
self.assertEqual(envs[Cluster.OCI__RUNTIME_URI], self.TEST_REPO)
77+
self.assertEqual(envs[Cluster.OCI__RUNTIME_GIT_COMMIT], self.TEST_COMMIT)
78+
self.assertEqual(envs[Driver.CONST_ENV_PIP_PKG], self.PIP_PKG)
79+
self.assertEqual(envs[Driver.CONST_ENV_PIP_REQ], self.PIP_REQ)
80+
self.assertEqual(
81+
envs[Driver.CONST_ENV_INPUT_MAPPINGS],
82+
json.dumps({self.INPUT_SRC: self.INPUT_DST}),
83+
)
84+
self.assertNotIn(Handler.CONST_DEEPSPEED, envs)

0 commit comments

Comments
 (0)