@@ -117,6 +117,40 @@ def test_run_command(self):
117117 self .assertEqual (runner .run_command ("pwd" , runner .conda_prefix , check = True ), 0 )
118118
119119
120+ class DeepSpeedRunnerTest (unittest .TestCase ):
121+ TEST_IP = "10.0.0.1"
122+
123+ def init_runner (self ):
124+ with mock .patch ("socket.gethostbyname" ) as GetHostIP , mock .patch (
125+ "ads.jobs.DataScienceJobRun.from_ocid"
126+ ) as GetJobRun , mock .patch (
127+ "ads.jobs.templates.driver_utils.JobRunner.run_command"
128+ ):
129+ GetHostIP .return_value = self .TEST_IP
130+ GetJobRun .return_value = DataScienceJobRun (id = "ocid.abcdefghijk" )
131+ return driver .DeepSpeedRunner ()
132+
133+ @mock .patch .dict (
134+ os .environ , {driver .CONST_ENV_LAUNCH_CMD : "deepspeed train.py --data abc" }
135+ )
136+ @mock .patch ("ads.jobs.templates.driver_utils.JobRunner.run_command" )
137+ @mock .patch ("ads.jobs.templates.driver_pytorch.Runner.time_cmd" )
138+ def test_run_single_node (self , time_cmd , * args ):
139+ runner = self .init_runner ()
140+ runner .run ()
141+ self .assertEqual (time_cmd .call_args .args [0 ], "deepspeed train.py --data abc" )
142+
143+ @mock .patch ("ads.jobs.templates.driver_utils.JobRunner.run_command" )
144+ def test_touch_file (self , run_command ):
145+ runner = self .init_runner ()
146+ runner .node_ip_list = ["10.0.0.2" , "10.0.0.3" ]
147+ runner .touch_file ("stop" )
148+ commasnds = [call_args .args [0 ] for call_args in run_command .call_args_list ]
149+ self .assertEqual (
150+ commasnds , ["ssh -v 10.0.0.2 'touch stop'" , "ssh -v 10.0.0.3 'touch stop'" ]
151+ )
152+
153+
120154class AccelerateRunnerTest (unittest .TestCase ):
121155 TEST_IP = "10.0.0.1"
122156
0 commit comments