
Commit 6058747

harikodali authored and Silv3S committed
add device generalization support for distributed tests (pytorch#165067)
## MOTIVATION

To generalize Distributed test cases for non-CUDA devices

## CHANGES

- Replaced hard coded device/backends with torch.accelerator.current_accelerator() and dist.get_default_backend_for_device
- Use DistributedTestBase instead of MultiProcessTestCase to use common utilities
- Remove instantiate_device_tests and make use of torch.accelerator.current_accelerator for test/distributed/test_c10d_object_collectives.py
- Fix deterministic context issue for non-CUDA devices in test/distributed/optim/test_zero_redundancy_optimizer.py
- Use torch.accelerator.device_count() for multi-gpu check in torch/testing/_internal/distributed/_tensor/common_dtensor.py

Pull Request resolved: pytorch#165067
Approved by: https://github.com/guangyey, https://github.com/albanD
1 parent d7c5131 commit 6058747

File tree: 9 files changed, +102 -144 lines changed
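
The common pattern these diffs introduce can be summarized in a small standalone sketch (a minimal illustration assembled from the APIs this commit switches to, not code taken from the commit itself; it assumes a recent PyTorch build with the torch.accelerator API):

import torch
import torch.distributed as dist

# Resolve the active accelerator type ("cuda", "xpu", ...) with a CPU fallback.
device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
# Look up the default backend for that device instead of hard-coding "nccl"/"xccl".
curr_backend = dist.get_default_backend_for_device(device_type)
# Device-count checks go through the accelerator API rather than torch.cuda.
num_devices = torch.accelerator.device_count()
print(device_type, curr_backend, num_devices)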


test/distributed/_composable/test_composability/test_2d_composability.py

Lines changed: 14 additions & 13 deletions
@@ -65,6 +65,7 @@


 device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+curr_backend = dist.get_default_backend_for_device(device_type)


 class SimpleModel(nn.Module):
@@ -422,10 +423,10 @@ class TestFullyShard2DStateDict(DTensorTestBase):
     @property
     def backend(self):
         # need to specify gloo backend for testing cpu offload
-        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
+        return f"cpu:gloo,{device_type}:{curr_backend}"

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     def test_fully_shard_tp_2d_set_full_state_dict(self):
         dummy_model = SimpleModel().to(device_type)
         mesh_2d = init_device_mesh(
@@ -514,8 +515,8 @@ def _check_module(self, m1, m2, check_grad=False):
         ).to_local()
         self.assertEqual(param_m2, param_m1)

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     def test_2d_ddp_integration_functionality(self) -> None:
         model, twod_model, dp_pg = self.init_model(self.device_type)
         optim = torch.optim.Adam(model.parameters(), lr=3e-5)
@@ -566,8 +567,8 @@ def _compare_params(self, m1, m2):
             p2 = p2.redistribute(p2.device_mesh, [Replicate()]).to_local()
             self.assertTrue(torch.allclose(p1, p2), f"{p1} vs {p2}")

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     def test_2d_fsdp_state_enable_extension(self):
         mesh_2d = init_device_mesh(
             self.device_type, (2, self.world_size // 2), mesh_dim_names=("dp", "tp")
@@ -642,18 +643,18 @@ def _test_2d_e2e_training(
         # Ensure all params are still the same after optimizer update.
         self._compare_params(model, model_2d)

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     def test_2d_e2e_training_default(self):
         self._test_2d_e2e_training()

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     def test_2d_e2e_training_use_orig_params(self):
         self._test_2d_e2e_training(use_orig_params=True)

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     def test_2d_e2e_training_not_use_orig_params(self):
         # TODO: need to revisit input_reshard API about why it failed multi-gpu tests.
         # self._test_2d_e2e_training(recompute_activation=True)
@@ -666,10 +667,10 @@ class TestNew2dParallelStateDict(DTensorTestBase):
     @property
     def backend(self):
         # need to specify gloo backend for testing cpu offload
-        return "cpu:gloo,xpu:xccl" if TEST_XPU else "cpu:gloo,cuda:nccl"
+        return f"cpu:gloo,{device_type}:{curr_backend}"

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     def test_fsdp_2d_extension(self):
         """
         Test whether _fsdp_extension from FSDPstate has been set correctly.
@@ -700,8 +701,8 @@ def test_fsdp_2d_extension(self):
         model_1d_fsdp_state = _get_module_fsdp_state(model_1d)
         self.assertEqual(model_1d_fsdp_state._fsdp_extension, None)

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     @parametrize("is_even_sharded_model", [True, False])
     def test_2d_state_dict(self, is_even_sharded_model):
         simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@@ -756,8 +757,8 @@ def test_2d_state_dict(self, is_even_sharded_model):
             torch.allclose(no_wrap_v, all_gather_two_d_v.to_local()), True
         )

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     @parametrize("is_even_sharded_model", [True, False])
     def test_2d_load_state_dict(self, is_even_sharded_model):
         simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@@ -811,8 +812,8 @@ def test_2d_load_state_dict(self, is_even_sharded_model):
             self.assertEqual(v1.device_mesh, v2.device_mesh)
             self.assertEqual(v1.placements, v2.placements)

-    @with_comms
     @skip_if_lt_x_gpu(4)
+    @with_comms
     @parametrize("is_even_sharded_model", [True, False])
     def test_2d_optim_state_dict(self, is_even_sharded_model):
         simple_model = SimpleModel if is_even_sharded_model else SimpleModelUneven
@@ -899,9 +900,9 @@ def test_2d_optim_state_dict(self, is_even_sharded_model):
         else:
             self.assertEqual(new_state, state)

+    @skip_if_lt_x_gpu(4)
     @with_comms
     @with_temp_dir
-    @skip_if_lt_x_gpu(4)
     def test_fsdp1_tp_2d_set_full_state_dict(self):
         """
         This is a workaround for loading full state dict into a FSDP1+TP 2D model.
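
For reference, the new backend property evaluates to the same strings the deleted hard-coded branch produced, while extending to any accelerator (a sketch assuming the module-level device_type and curr_backend defined in the first hunk above):

# On a CUDA machine: device_type == "cuda", curr_backend == "nccl"
#   -> "cpu:gloo,cuda:nccl"   (the old non-XPU branch)
# On an XPU machine:  device_type == "xpu",  curr_backend == "xccl"
#   -> "cpu:gloo,xpu:xccl"    (the old TEST_XPU branch)
backend = f"cpu:gloo,{device_type}:{curr_backend}"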

test/distributed/_composable/test_composability/test_pp_composability.py

Lines changed: 9 additions & 18 deletions
@@ -29,8 +29,8 @@
     parallelize_module,
     RowwiseParallel,
 )
-from torch.testing._internal.common_cuda import TEST_MULTIGPU
 from torch.testing._internal.common_distributed import (
+    at_least_x_gpu,
     MultiProcessTestCase,
     requires_accelerator_dist_backend,
     skip_if_lt_x_gpu,
@@ -40,7 +40,6 @@
     parametrize,
     run_tests,
     skip_but_pass_in_sandcastle_if,
-    TEST_XPU,
 )
 from torch.testing._internal.distributed.checkpoint_utils import with_temp_dir

@@ -107,11 +106,9 @@ def world_size(self):
     def device(self):
         return self.rank

-    @requires_accelerator_dist_backend(["nccl", "xccl"])
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(8)
-    @skip_but_pass_in_sandcastle_if(
-        not TEST_MULTIGPU and not TEST_XPU, "Test requires 4+ GPUs"
-    )
+    @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
     def test_pp_and_dcp(self):
         """
         Test that pipeline parallelism and distributed checkpointing can be used together and
@@ -201,11 +198,9 @@ def _dcp_test(self):

         _dcp_test(self)

-    @requires_accelerator_dist_backend(["nccl", "xccl"])
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(8)
-    @skip_but_pass_in_sandcastle_if(
-        not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
-    )
+    @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
     @parametrize(
         "ScheduleClass",
         [
@@ -355,11 +350,9 @@ def apply_tp(

         torch.distributed.destroy_process_group()

-    @requires_accelerator_dist_backend(["nccl", "xccl"])
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(8)
-    @skip_but_pass_in_sandcastle_if(
-        not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
-    )
+    @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
     @parametrize(
         "ScheduleClass",
         [
@@ -550,11 +543,9 @@ def apply_same_precision(partial_model):

         torch.distributed.destroy_process_group()

-    @requires_accelerator_dist_backend(["nccl", "xccl"])
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(8)
-    @skip_but_pass_in_sandcastle_if(
-        not TEST_MULTIGPU and not TEST_XPU, "Test requires 8+ GPUs"
-    )
+    @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
     @parametrize(
         "ScheduleClass",
         [
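
The gating pattern now applied to every multi-accelerator test in this file can be written generically as below (a sketch with a hypothetical test class and method; the decorators are the ones imported in the hunks above, and skip_but_pass_in_sandcastle_if is assumed to come from torch.testing._internal.common_utils as in this file's import block):

from torch.testing._internal.common_distributed import (
    at_least_x_gpu,
    MultiProcessTestCase,
    requires_accelerator_dist_backend,
    skip_if_lt_x_gpu,
)
from torch.testing._internal.common_utils import skip_but_pass_in_sandcastle_if


class ExamplePPTest(MultiProcessTestCase):  # hypothetical example class
    @requires_accelerator_dist_backend()  # any accelerator backend, not just nccl/xccl
    @skip_if_lt_x_gpu(8)                  # hard requirement on visible device count
    @skip_but_pass_in_sandcastle_if(not at_least_x_gpu(8), "Test requires 8+ GPUs")
    def test_example(self):
        pass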

test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py

Lines changed: 20 additions & 38 deletions
@@ -1,6 +1,5 @@
 # Owner(s): ["oncall: distributed"]

-import os
 import sys

 import torch
@@ -18,8 +17,8 @@
 )
 from torch.nn.parallel import DistributedDataParallel
 from torch.testing._internal.common_distributed import (
-    MultiProcessTestCase,
-    requires_nccl,
+    DistributedTestBase,
+    requires_accelerator_dist_backend,
     skip_if_lt_x_gpu,
 )
 from torch.testing._internal.common_utils import run_tests, TEST_WITH_DEV_DBG_ASAN
@@ -30,9 +29,12 @@
     sys.exit(0)


+device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"
+
+
 def gpus_for_rank(world_size):
-    visible_devices = list(range(torch.cuda.device_count()))
-    gpus_per_process = torch.cuda.device_count() // world_size
+    visible_devices = list(range(torch.accelerator.device_count()))
+    gpus_per_process = torch.accelerator.device_count() // world_size
     gpus_for_rank = []
     for rank in range(world_size):
         gpus_for_rank.append(
@@ -60,27 +62,7 @@ def forward(self, x, rank):
         return self.t0(x ** (1 + rank))


-class DistributedDataParallelCommHookTest(MultiProcessTestCase):
-    def setUp(self):
-        super().setUp()
-        self._spawn_processes()
-
-    def tearDown(self):
-        try:
-            os.remove(self.file_name)
-        except OSError:
-            pass
-
-    def _get_process_group_nccl(self):
-        store = dist.FileStore(self.file_name, self.world_size)
-        dist.init_process_group(
-            backend="nccl",
-            world_size=self.world_size,
-            rank=self.rank,
-            store=store,
-        )
-        return dist.distributed_c10d._get_default_group()
-
+class DistributedDataParallelCommHookTest(DistributedTestBase):
     @property
     def world_size(self):
         return 2
@@ -119,14 +101,14 @@ def _run_and_get_grads(self, model):
         param = next(model.parameters())
         return param.grad

-    @requires_nccl()
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
     def test_ddp_comm_hook_allreduce_hook(self):
         """
         This unit test verifies the ``allreduce`` hook registered case gives same result
         with no hook registered case.
         """
-        process_group = self._get_process_group_nccl()
+        process_group = self.create_pg(device_type)

         # No hook registered case, get the reference grads.
         reference_grads = self._get_grads(process_group, None)
@@ -135,14 +117,14 @@ def test_ddp_comm_hook_allreduce_hook(self):

         torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)

-    @requires_nccl()
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
     def test_ddp_comm_hook_fp16compress_hook(self):
         """
         This unit test verifies the ``fp16 compress`` hook registered case
         gives close result with no hook registered case.
         """
-        process_group = self._get_process_group_nccl()
+        process_group = self.create_pg(device_type)

         # No hook registered case, get the reference grads.
         reference_grads = self._get_grads(process_group, None)
@@ -151,14 +133,14 @@ def test_ddp_comm_hook_fp16compress_hook(self):

         torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

-    @requires_nccl()
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
     def test_ddp_comm_hook_quantize_per_tensor_hook(self):
         """
         This unit test verifies the ``quantize per tensor`` hook registered case
         gives close result with no hook registered case.
         """
-        process_group = self._get_process_group_nccl()
+        process_group = self.create_pg(device_type)

         # No hook registered case, get the reference grads.
         reference_grads = self._get_grads(process_group, None)
@@ -167,14 +149,14 @@ def test_ddp_comm_hook_quantize_per_tensor_hook(self):

         torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

-    @requires_nccl()
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
     def test_ddp_comm_hook_quantize_per_channel_hook(self):
         """
         This unit test verifies the ``quantize per channel`` hook registered case
         gives close result with no hook registered case.
         """
-        process_group = self._get_process_group_nccl()
+        process_group = self.create_pg(device_type)

         # No hook registered case, get the reference grads.
         reference_grads = self._get_grads(process_group, None)
@@ -185,14 +167,14 @@ def test_ddp_comm_hook_quantize_per_channel_hook(self):

         torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=1e-4)

-    @requires_nccl()
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
     def test_ddp_comm_hook_noop_hook(self):
         """
         This unit test verifies the ``noop`` hook registered case and a subsequent allreduce
         gives same result with no hook registered case.
         """
-        process_group = self._get_process_group_nccl()
+        process_group = self.create_pg(device_type)

         # No hook registered case, get the reference grads.
         reference_grads = self._get_grads(process_group, None)
@@ -204,10 +186,10 @@ def test_ddp_comm_hook_noop_hook(self):

         torch.testing.assert_close(hook_grads, reference_grads, rtol=1e-5, atol=0)

-    @requires_nccl()
+    @requires_accelerator_dist_backend()
     @skip_if_lt_x_gpu(2)
     def test_is_last_hook(self):
-        process_group = self._get_process_group_nccl()
+        process_group = self.create_pg(device_type)

         def hook(flags, bucket):
             flags.append(bucket.is_last())
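
The setUp/tearDown/process-group boilerplate removed above is what DistributedTestBase provides. A minimal sketch of the resulting shape of such a test follows (the test class and method are hypothetical, not part of this commit; create_pg is used exactly as in the hunks above, and the all_reduce check is only an illustrative assumption):

import torch
import torch.distributed as dist
from torch.testing._internal.common_distributed import DistributedTestBase

device_type = acc.type if (acc := torch.accelerator.current_accelerator()) else "cpu"


class ExampleCommHookTest(DistributedTestBase):  # hypothetical example class
    @property
    def world_size(self):
        return 2

    def test_allreduce_smoke(self):
        # Replaces the old FileStore + dist.init_process_group(backend="nccl", ...) setup.
        process_group = self.create_pg(device_type)
        # One tensor per rank on the rank's device; gloo is used on CPU-only machines.
        tensor = torch.ones(2, device=torch.device(device_type, self.rank))
        dist.all_reduce(tensor, group=process_group)
        self.assertEqual(tensor, torch.full_like(tensor, self.world_size))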
