Skip to content

Commit f7a2756

Browse files
committed
Add test for accelerate runner.
1 parent 1942e0f commit f7a2756

File tree

1 file changed

+60
-0
lines changed

1 file changed

+60
-0
lines changed

tests/unitary/with_extras/jobs/test_pytorch_ddp.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,66 @@ def test_run_command(self):
117117
self.assertEqual(runner.run_command("pwd", runner.conda_prefix, check=True), 0)
118118

119119

120+
class AccelerateRunnerTest(unittest.TestCase):
121+
TEST_IP = "10.0.0.1"
122+
123+
def init_runner(self):
124+
with mock.patch(
125+
"ads.jobs.templates.driver_pytorch.TorchRunner.build_c_library"
126+
), mock.patch("socket.gethostbyname") as GetHostIP, mock.patch(
127+
"ads.jobs.DataScienceJobRun.from_ocid"
128+
) as GetJobRun, mock.patch(
129+
"ads.jobs.templates.driver_utils.JobRunner.run_command"
130+
):
131+
GetHostIP.return_value = self.TEST_IP
132+
GetJobRun.return_value = DataScienceJobRun(id="ocid.abcdefghijk")
133+
return driver.AccelerateRunner()
134+
135+
@mock.patch.dict(
136+
os.environ,
137+
{
138+
driver.CONST_ENV_DEEPSPEED: "1",
139+
driver.OCI__WORKER_COUNT: "1",
140+
driver.CONST_ENV_LAUNCH_CMD: "accelerate launch train.py --data abc",
141+
"RANK": "0",
142+
},
143+
)
144+
@mock.patch("ads.jobs.templates.driver_pytorch.DeepSpeedRunner.run_deepspeed_host")
145+
@mock.patch("ads.jobs.templates.driver_utils.JobRunner.run_command")
146+
@mock.patch("ads.jobs.templates.driver_pytorch.Runner.time_cmd")
147+
def test_run(self, time_cmd, run_command, run_deepspeed):
148+
run_command.return_value = 0
149+
150+
runner = self.init_runner()
151+
runner.run_with_torchrun()
152+
self.assertTrue(
153+
time_cmd.call_args.kwargs["cmd"].endswith(
154+
"libhostname.so.1 OCI__HOSTNAME=10.0.0.1 "
155+
"accelerate launch --num_processes 2 --num_machines 2 --machine_rank 0 --main_process_port 29400 "
156+
"train.py --data abc"
157+
),
158+
time_cmd.call_args.kwargs["cmd"],
159+
)
160+
161+
runner.run()
162+
self.assertEqual(
163+
run_deepspeed.call_args.args[0],
164+
[
165+
"--num_processes",
166+
"2",
167+
"--num_machines",
168+
"2",
169+
"--machine_rank",
170+
"0",
171+
"--main_process_ip",
172+
"10.0.0.1",
173+
"--main_process_port",
174+
"29400",
175+
"--deepspeed_hostfile=/home/datascience/hostfile",
176+
],
177+
)
178+
179+
120180
class LazyEvaluateTest(unittest.TestCase):
121181
def test_lazy_evaluation(self):
122182
def func(a, b):

0 commit comments

Comments
 (0)