Skip to content

Commit 3885454

Browse files
committed
Update tests
1 parent 47f5294 commit 3885454

File tree

2 files changed

+19
-14
lines changed

2 files changed

+19
-14
lines changed

test/test_examples_dist.expected

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -121,9 +121,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha
121121
# src[all_reduce.py:N]: [tile_n.id, world],
122122
# src[all_reduce.py:N-N]: ...
123123
helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False)
124-
# src[all_reduce.py:N]: acc = hl.zeros(
125-
# src[all_reduce.py:N]: [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
126-
# src[all_reduce.py:N]: )
124+
# src[all_reduce.py:N]: acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)
127125
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16)
128126
# src[all_reduce.py:N]: acc += a[tile_n]
129127
load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0)
@@ -154,7 +152,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha
154152
# src[all_reduce.py:N-N]: ...
155153
helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True)
156154

157-
def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher):
155+
def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared: torch.Tensor, my_rank: hl.constexpr, group_name: hl.constexpr, *, _launcher=_default_launcher):
158156
"""
159157
Helion JIT-compiled kernel for one-shot all-reduce operation.
160158

@@ -173,10 +171,12 @@ def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad:
173171
"""
174172
# src[all_reduce.py:N]: _, world_size = local_signal_pad.size()
175173
_, world_size = local_signal_pad.size()
176-
# src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0])
177-
out = torch.empty_like(a_shared_tuple[0])
174+
# src[all_reduce.py:N]: out = torch.empty_like(a_shared)
175+
out = torch.empty_like(a_shared)
178176
# src[all_reduce.py:N]: N = out.size(0)
179177
N = out.size(0)
178+
# src[all_reduce.py:N]: a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
179+
a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, '0')
180180
# src[all_reduce.py:N]: for tile_n in hl.tile(N):
181181
_BLOCK_SIZE_0 = 8192
182182
_RDIM_SIZE_1 = 4

test/test_examples_dist.py

Lines changed: 13 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -106,20 +106,19 @@ def test_all_reduce(self):
106106

107107
mod = import_path(EXAMPLES_DIR / "all_reduce.py")
108108

109+
# Only NVSHMEM backend implements `get_remote_tensor` for now.
110+
symm_mem.set_backend("NVSHMEM")
111+
group = dist.group.WORLD
112+
symm_mem.enable_symm_mem_for_group(group.group_name)
113+
109114
N = 16384
110115
dtype = torch.bfloat16
111116

112117
a_shared = symm_mem.empty(
113118
N // self.world_size, dtype=dtype, device=self.device
114119
).normal_()
115120

116-
symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
117-
a_shared_tuple = tuple(
118-
[
119-
symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
120-
for i in range(symm_mem_hdl.world_size)
121-
]
122-
)
121+
symm_mem_hdl = symm_mem.rendezvous(a_shared, group=group)
123122
local_signal_pad = symm_mem_hdl.get_signal_pad(
124123
symm_mem_hdl.rank, dtype=torch.int32
125124
).view(-1, symm_mem_hdl.world_size)
@@ -132,7 +131,13 @@ def test_all_reduce(self):
132131

133132
code, result = code_and_output(
134133
mod.one_shot_all_reduce_kernel,
135-
(signal_pad_addrs, local_signal_pad, a_shared_tuple, symm_mem_hdl.rank),
134+
(
135+
signal_pad_addrs,
136+
local_signal_pad,
137+
a_shared,
138+
symm_mem_hdl.rank,
139+
group.group_name,
140+
),
136141
)
137142

138143
if self.rank == 0:

0 commit comments

Comments
 (0)