Skip to content

Commit c8d07cd

Browse files
committed
Add test for deepspeed.
1 parent 4f7f635 commit c8d07cd

File tree

2 files changed

+36
-0
lines changed

2 files changed

+36
-0
lines changed

ads/jobs/templates/driver_pytorch.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -656,6 +656,8 @@ def run(self):
656656
if self.is_host:
657657
if self.node_count > 1:
658658
launch_args = [f"--hostfile={self.HOST_FILE}"]
659+
else:
660+
launch_args = []
659661
self.run_deepspeed_host(launch_args)
660662
else:
661663
self.run_deepspeed_worker()

tests/unitary/with_extras/jobs/test_pytorch_ddp.py

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -117,6 +117,40 @@ def test_run_command(self):
117117
self.assertEqual(runner.run_command("pwd", runner.conda_prefix, check=True), 0)
118118

119119

120+
class DeepSpeedRunnerTest(unittest.TestCase):
    """Tests for ``driver.DeepSpeedRunner`` with host lookup and shell calls mocked."""

    TEST_IP = "10.0.0.1"

    def init_runner(self):
        """Return a ``DeepSpeedRunner`` constructed under patched dependencies.

        While the runner is being constructed, ``socket.gethostbyname`` returns
        ``TEST_IP``, ``DataScienceJobRun.from_ocid`` returns a stub job run, and
        ``JobRunner.run_command`` is replaced by a mock. The patches are
        released again before this method returns.
        """
        host_lookup_patch = mock.patch("socket.gethostbyname")
        job_run_patch = mock.patch("ads.jobs.DataScienceJobRun.from_ocid")
        shell_patch = mock.patch(
            "ads.jobs.templates.driver_utils.JobRunner.run_command"
        )
        # Patches are entered in the listed order: host lookup, job run, shell.
        with host_lookup_patch as host_lookup, job_run_patch as job_run_lookup, shell_patch:
            host_lookup.return_value = self.TEST_IP
            job_run_lookup.return_value = DataScienceJobRun(id="ocid.abcdefghijk")
            return driver.DeepSpeedRunner()

    @mock.patch.dict(
        os.environ, {driver.CONST_ENV_LAUNCH_CMD: "deepspeed train.py --data abc"}
    )
    @mock.patch("ads.jobs.templates.driver_utils.JobRunner.run_command")
    @mock.patch("ads.jobs.templates.driver_pytorch.Runner.time_cmd")
    def test_run_single_node(self, time_cmd, *args):
        """``run()`` hands the unmodified launch command to ``time_cmd``.

        The launch command comes from the ``CONST_ENV_LAUNCH_CMD`` environment
        variable; this test sets up no multi-node configuration of its own.
        """
        self.init_runner().run()
        # The command is the first positional argument of the last call.
        launched_command = time_cmd.call_args.args[0]
        self.assertEqual(launched_command, "deepspeed train.py --data abc")

    @mock.patch("ads.jobs.templates.driver_utils.JobRunner.run_command")
    def test_touch_file(self, run_command):
        """``touch_file`` runs one ssh ``touch`` command per entry in ``node_ip_list``."""
        runner = self.init_runner()
        runner.node_ip_list = ["10.0.0.2", "10.0.0.3"]
        runner.touch_file("stop")

        # Collect the first positional argument of every recorded call.
        issued_commands = []
        for recorded_call in run_command.call_args_list:
            issued_commands.append(recorded_call.args[0])

        expected_commands = [
            "ssh -v 10.0.0.2 'touch stop'",
            "ssh -v 10.0.0.3 'touch stop'",
        ]
        self.assertEqual(issued_commands, expected_commands)
152+
153+
120154
class AccelerateRunnerTest(unittest.TestCase):
121155
TEST_IP = "10.0.0.1"
122156

0 commit comments

Comments
 (0)