11import json
2+ import os
3+ import sys
24import unittest
35import zipfile
4- from ads .jobs import PyTorchDistributedRuntime , DataScienceJob
6+ from unittest import mock
7+ from ads .jobs import PyTorchDistributedRuntime , DataScienceJob , DataScienceJobRun
58from ads .jobs .builders .infrastructure .dsc_job_runtime import (
69 PyTorchDistributedRuntimeHandler as Handler ,
710)
811from ads .jobs .builders .runtimes .pytorch_runtime import (
912 PyTorchDistributedArtifact ,
1013 GitPythonArtifact ,
1114)
12- from ads .opctl .distributed .common import cluster_config_helper as Cluster
13- from ads .jobs .templates import driver_utils as Driver
15+ from ads .opctl .distributed .common import cluster_config_helper as cluster
16+ from ads .jobs .templates import driver_utils as utils
17+ from ads .jobs .templates import driver_pytorch as driver
1418
1519
1620class PyTorchRuntimeHandlerTest (unittest .TestCase ):
@@ -26,6 +30,7 @@ class PyTorchRuntimeHandlerTest(unittest.TestCase):
2630 )
2731
2832 def init_runtime (self ):
33+ """Initializes a PyTorchDistributedRuntime for testing."""
2934 return (
3035 PyTorchDistributedRuntime ()
3136 .with_replica (self .REPLICAS )
@@ -43,6 +48,7 @@ def init_runtime(self):
4348 )
4449
4550 def test_translate_artifact (self ):
51+ """Tests preparing ADS driver scripts in job artifacts."""
4652 artifact = Handler (DataScienceJob ())._translate_artifact (self .init_runtime ())
4753 self .assertIsInstance (artifact , PyTorchDistributedArtifact )
4854 self .assertEqual (
@@ -65,6 +71,7 @@ def test_translate_artifact(self):
6571 self .assertIn (GitPythonArtifact .CONST_DRIVER_SCRIPT , file_list )
6672
6773 def test_translate_env (self ):
74+ """Tests setting up environment variables"""
6875 envs = Handler (DataScienceJob ())._translate_env (self .init_runtime ())
6976 self .assertIsInstance (envs , dict )
7077 self .assertEqual (envs [Handler .CONST_WORKER_COUNT ], str (self .REPLICAS - 1 ))
@@ -73,12 +80,87 @@ def test_translate_env(self):
7380 PyTorchDistributedArtifact .CONST_DRIVER_SCRIPT ,
7481 )
7582 self .assertEqual (envs [Handler .CONST_COMMAND ], self .TORCHRUN_CMD )
76- self .assertEqual (envs [Cluster .OCI__RUNTIME_URI ], self .TEST_REPO )
77- self .assertEqual (envs [Cluster .OCI__RUNTIME_GIT_COMMIT ], self .TEST_COMMIT )
78- self .assertEqual (envs [Driver .CONST_ENV_PIP_PKG ], self .PIP_PKG )
79- self .assertEqual (envs [Driver .CONST_ENV_PIP_REQ ], self .PIP_REQ )
83+ self .assertEqual (envs [cluster .OCI__RUNTIME_URI ], self .TEST_REPO )
84+ self .assertEqual (envs [cluster .OCI__RUNTIME_GIT_COMMIT ], self .TEST_COMMIT )
85+ self .assertEqual (envs [utils .CONST_ENV_PIP_PKG ], self .PIP_PKG )
86+ self .assertEqual (envs [utils .CONST_ENV_PIP_REQ ], self .PIP_REQ )
8087 self .assertEqual (
81- envs [Driver .CONST_ENV_INPUT_MAPPINGS ],
88+ envs [utils .CONST_ENV_INPUT_MAPPINGS ],
8289 json .dumps ({self .INPUT_SRC : self .INPUT_DST }),
8390 )
8491 self .assertNotIn (Handler .CONST_DEEPSPEED , envs )
92+
93+
class PyTorchRunnerTest(unittest.TestCase):
    """Tests for ``driver.TorchRunner`` with external lookups mocked out."""

    # Value returned by the patched ``socket.gethostbyname``.
    TEST_IP = "10.0.0.1"
    # IP address embedded in the mocked job run log message.
    TEST_HOST_IP = "10.0.0.100"
    TEST_HOST_OCID = "ocid_host"
    TEST_NODE_OCID = "ocid_node"

    def init_torch_runner(self):
        """Creates a ``TorchRunner`` while patching its external dependencies.

        ``TorchRunner.build_c_library``, ``socket.gethostbyname`` and
        ``DataScienceJobRun.from_ocid`` are patched only for the duration of
        the constructor call.

        Returns
        -------
        driver.TorchRunner
            The runner instantiated under the patches.
        """
        with mock.patch(
            "ads.jobs.templates.driver_pytorch.TorchRunner.build_c_library"
        ), mock.patch("socket.gethostbyname") as get_host_ip, mock.patch(
            "ads.jobs.DataScienceJobRun.from_ocid"
        ) as get_job_run:
            get_host_ip.return_value = self.TEST_IP
            get_job_run.return_value = DataScienceJobRun(id="ocid.abcdefghijk")
            return driver.TorchRunner()

    @mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID})
    def test_init_torch_runner_at_node(self):
        """Tests initializing the runner when the host job run OCID is set."""
        runner = self.init_torch_runner()
        self.assertEqual(runner.host_ocid, self.TEST_HOST_OCID)
        # The host IP is not known right after initialization in this case.
        self.assertIsNone(runner.host_ip)

    @mock.patch.dict(os.environ, {driver.CONST_ENV_JOB_RUN_OCID: TEST_NODE_OCID})
    def test_init_torch_runner_at_host(self):
        """Tests initializing the runner when only the job run OCID is set."""
        runner = self.init_torch_runner()
        self.assertEqual(runner.host_ocid, self.TEST_NODE_OCID)
        # The IP comes from the patched ``socket.gethostbyname``.
        self.assertEqual(runner.host_ip, self.TEST_IP)

    @mock.patch.dict(os.environ, {driver.CONST_ENV_HOST_JOB_RUN_OCID: TEST_HOST_OCID})
    def test_wait_for_host_ip(self):
        """Tests obtaining the host IP address from the mocked job run logs."""
        with mock.patch("ads.jobs.DataScienceJobRun.logs") as get_logs:
            # NOTE(review): the literal spaces after each placeholder are kept
            # exactly as they appear in the source; confirm they are intended.
            get_logs.return_value = [
                {"message": f"{driver.LOG_PREFIX_HOST_IP} {self.TEST_HOST_IP} "}
            ]
            runner = self.init_torch_runner()
            self.assertIsNone(runner.host_ip)
            runner.wait_for_host_ip_address()
            self.assertEqual(runner.host_ip, self.TEST_HOST_IP)

    @mock.patch.dict(
        os.environ, {driver.CONST_ENV_LAUNCH_CMD: "torchrun train.py --data abc"}
    )
    def test_launch_cmd(self):
        """Tests inspecting and preparing a launch command set via environment."""
        runner = self.init_torch_runner()
        self.assertTrue(runner.launch_cmd_contains("data"))
        # "data1" is not an argument of the launch command.
        self.assertFalse(runner.launch_cmd_contains("data1"))
        self.assertEqual(
            runner.prepare_cmd(prefix="A=1"), "A=1 torchrun train.py --data abc"
        )

    @mock.patch.dict(os.environ, {Handler.CONST_CODE_ENTRYPOINT: "train.py"})
    @mock.patch.object(sys, "argv", ["python", "hello", "--data", "abc"])
    def test_prepare_cmd_with_entrypoint_args(self):
        """Tests preparing the command from an entrypoint, launch args and argv."""
        runner = self.init_torch_runner()
        self.assertEqual(
            runner.prepare_cmd(launch_args=["--key", "val"], prefix="A=1"),
            "A=1 torchrun --key val train.py hello --data abc",
        )
152+
153+
class LazyEvaluateTest(unittest.TestCase):
    """Tests for the string conversion of ``driver.LazyEvaluate``."""

    def test_lazy_evaluation(self):
        """Tests converting lazy values to strings, with and without errors."""

        def add(a, b):
            return a + b

        def always_fails():
            raise ValueError()

        # A callable that succeeds: the string is the stringified result.
        successful = driver.LazyEvaluate(add, 1, 1)
        self.assertEqual(str(successful), "2")

        # A callable that raises: the string is the error text instead.
        failing = driver.LazyEvaluate(always_fails)
        self.assertEqual(str(failing), "ERROR: ")
0 commit comments