@@ -117,6 +117,66 @@ def test_run_command(self):
117117 self .assertEqual (runner .run_command ("pwd" , runner .conda_prefix , check = True ), 0 )
118118
119119
class AccelerateRunnerTest(unittest.TestCase):
    """Exercise ``driver.AccelerateRunner`` with its external calls mocked out."""

    # IP address handed back by the patched ``socket.gethostbyname``.
    TEST_IP = "10.0.0.1"

    def init_runner(self):
        """Create an ``AccelerateRunner`` while construction-time calls are patched.

        The following are patched only for the duration of the constructor:
        ``TorchRunner.build_c_library``, ``socket.gethostbyname`` (returns
        ``TEST_IP``), ``DataScienceJobRun.from_ocid`` (returns a stub job run)
        and ``JobRunner.run_command``.

        Returns:
            The freshly constructed ``driver.AccelerateRunner`` instance.
        """
        build_lib_patch = mock.patch(
            "ads.jobs.templates.driver_pytorch.TorchRunner.build_c_library"
        )
        host_ip_patch = mock.patch("socket.gethostbyname")
        job_run_patch = mock.patch("ads.jobs.DataScienceJobRun.from_ocid")
        run_command_patch = mock.patch(
            "ads.jobs.templates.driver_utils.JobRunner.run_command"
        )
        # Patches are entered left to right, the same order as a chained ``with``.
        with build_lib_patch, host_ip_patch as host_lookup, job_run_patch as job_run_lookup, run_command_patch:
            host_lookup.return_value = self.TEST_IP
            job_run_lookup.return_value = DataScienceJobRun(id="ocid.abcdefghijk")
            return driver.AccelerateRunner()

    @mock.patch.dict(
        os.environ,
        {
            driver.CONST_ENV_DEEPSPEED: "1",
            driver.OCI__WORKER_COUNT: "1",
            driver.CONST_ENV_LAUNCH_CMD: "accelerate launch train.py --data abc",
            "RANK": "0",
        },
    )
    @mock.patch("ads.jobs.templates.driver_pytorch.DeepSpeedRunner.run_deepspeed_host")
    @mock.patch("ads.jobs.templates.driver_utils.JobRunner.run_command")
    @mock.patch("ads.jobs.templates.driver_pytorch.Runner.time_cmd")
    def test_run(self, time_cmd, run_command, run_deepspeed):
        """Check the launch arguments produced by the torchrun and DeepSpeed paths.

        Mock arguments arrive bottom-up relative to the decorators:
        ``time_cmd``, then ``run_command``, then ``run_deepspeed``.
        """
        run_command.return_value = 0
        runner = self.init_runner()

        # Torchrun path: the command handed to ``time_cmd`` must end with the
        # rewritten ``accelerate launch`` invocation.
        # NOTE(review): with OCI__WORKER_COUNT="1" the expected counts are 2,
        # presumably the main node plus one worker - confirm against the driver.
        runner.run_with_torchrun()
        expected_tail = (
            "libhostname.so.1 OCI__HOSTNAME=10.0.0.1 "
            "accelerate launch --num_processes 2 --num_machines 2 --machine_rank 0 --main_process_port 29400 "
            "train.py --data abc"
        )
        launched_cmd = time_cmd.call_args.kwargs["cmd"]
        # The actual command doubles as the failure message to ease debugging.
        self.assertTrue(launched_cmd.endswith(expected_tail), launched_cmd)

        # DeepSpeed path: ``run`` must pass these launch arguments as the first
        # positional argument to the (mocked) ``run_deepspeed_host``.
        runner.run()
        expected_launch_args = [
            "--num_processes",
            "2",
            "--num_machines",
            "2",
            "--machine_rank",
            "0",
            "--main_process_ip",
            "10.0.0.1",
            "--main_process_port",
            "29400",
            "--deepspeed_hostfile=/home/datascience/hostfile",
        ]
        self.assertEqual(run_deepspeed.call_args.args[0], expected_launch_args)
178+
179+
120180class LazyEvaluateTest (unittest .TestCase ):
121181 def test_lazy_evaluation (self ):
122182 def func (a , b ):
0 commit comments