Skip to content

Commit ad845c7

Browse files
committed
Add torchrun test.
1 parent c769c7a commit ad845c7

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

tests/unitary/with_extras/jobs/test_pytorch_ddp.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -78,6 +78,25 @@ def test_prepare_cmd_with_entrypoint_args(self):
7878
"A=1 torchrun --key val train.py hello --data abc",
7979
)
8080

81+
@mock.patch.dict(
    os.environ, {driver.CONST_ENV_LAUNCH_CMD: "torchrun train.py --data abc"}
)
@mock.patch("ads.jobs.templates.driver_utils.JobRunner.run_command")
def test_run_torchrun(self, run_command):
    """Check the command passed to ``run_command`` for a torchrun launch.

    The launch-command environment variable is patched to
    ``"torchrun train.py --data abc"`` and ``JobRunner.run_command`` is
    replaced by a mock, so nothing is actually executed. The test then
    inspects the first positional argument of the most recent
    ``run_command`` call and verifies that it:

    * starts with an ``LD_PRELOAD=`` assignment, and
    * ends with the hostname override followed by the ``torchrun``
      invocation carrying the rendezvous options and the original
      script arguments.
    """
    runner = self.init_torch_runner()
    runner.run()

    # Fail with a readable assertion if the runner never issued a command;
    # otherwise ``call_args`` is None and the lookup below would raise an
    # opaque AttributeError instead of a test failure.
    run_command.assert_called()
    cmd = run_command.call_args.args[0]

    # Pass ``cmd`` as the failure message so a mismatch shows the command
    # that was actually built.
    self.assertTrue(cmd.startswith("LD_PRELOAD="), cmd)

    expected_suffix = (
        "libhostname.so.1 OCI__HOSTNAME=10.0.0.1 "
        "torchrun --nnode=1 --nproc_per_node=1 "
        "--rdzv_backend=c10d --rdzv_endpoint=10.0.0.1:29400 --rdzv_conf=read_timeout=600 "
        "train.py --data abc"
    )
    self.assertTrue(cmd.endswith(expected_suffix), cmd)
99+
81100

82101
class LazyEvaluateTest(unittest.TestCase):
83102
def test_lazy_evaluation(self):

0 commit comments

Comments
 (0)