
Commit 198f1cb

kwen2501 and yf225 authored
Get remote tensors inside @helion.kernel (#1122)
Co-authored-by: Will Feng <yfeng.us@gmail.com>
1 parent b9925d9 commit 198f1cb
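
In short, this commit moves the gathering of peer buffers from the host wrapper into the Helion kernel itself: rather than building a tuple of remote views with symm_mem_hdl.get_buffer(...) on the host and passing it as a kernel argument, the kernel now takes the local symmetric tensor plus a group_name constexpr and calls torch.ops.symm_mem.get_remote_tensors(a_shared, group_name) inside the kernel body. The sketch below illustrates that call pattern only; it omits the signal-pad synchronization the real example performs, assumes a single-node run with one GPU per rank and an already-initialized process group, and the helper names (sum_across_ranks, demo) are made up for illustration. The authoritative version is examples/all_reduce.py in this diff.

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

import helion
import helion.language as hl


@helion.kernel
def sum_across_ranks(a_shared: torch.Tensor, group_name: hl.constexpr) -> torch.Tensor:
    # Simplified sketch; the real kernel also synchronizes ranks via signal pads.
    out = torch.empty_like(a_shared)
    n = out.size(0)
    # New pattern: fetch remote peer views inside the kernel.
    peers = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
    for tile_n in hl.tile(n):
        acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=a_shared.device)
        for p in peers:
            acc += p[tile_n]
        out[tile_n] = acc
    return out


def demo() -> None:
    # Assumes dist.init_process_group(...) has already run (single node, one GPU per rank).
    symm_mem.set_backend("NVSHMEM")  # only NVSHMEM implements get_remote_tensor today
    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)
    device = torch.device(f"cuda:{dist.get_rank()}")
    a = symm_mem.empty(8192, dtype=torch.bfloat16, device=device).normal_()
    symm_mem.rendezvous(a, group_name)
    result = sum_across_ranks(a, group_name=group_name)
    print(f"[Rank {dist.get_rank()}] first element of sum: {result[0].item()}")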

File tree

3 files changed: +35 −32 lines changed

examples/all_reduce.py

Lines changed: 16 additions & 18 deletions
@@ -92,8 +92,9 @@ def dev_array_to_tensor_short(
 def one_shot_all_reduce_kernel(
     signal_pad_addrs: torch.Tensor,
     local_signal_pad: torch.Tensor,
-    a_shared_tuple: tuple[torch.Tensor, ...],
+    a_shared: torch.Tensor,
     my_rank: hl.constexpr,
+    group_name: hl.constexpr,
 ) -> torch.Tensor:
     """
     Helion JIT-compiled kernel for one-shot all-reduce operation.
@@ -113,8 +114,9 @@ def one_shot_all_reduce_kernel(
     """
     _, world_size = local_signal_pad.size()
     world_size = hl.specialize(world_size)
-    out = torch.empty_like(a_shared_tuple[0])
+    out = torch.empty_like(a_shared)
     N = out.size(0)
+    a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
 
     for tile_n in hl.tile(N):
         # Sync all devices through signal_pad to make sure
@@ -139,9 +141,7 @@ def one_shot_all_reduce_kernel(
             scope="sys",
         )
 
-        acc = hl.zeros(
-            [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
-        )
+        acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)
 
         for a in a_shared_tuple:
            acc += a[tile_n]
@@ -184,15 +184,8 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
         Tensor containing the all-reduced result (sum across all devices)
     """
     assert dist.group.WORLD is not None
-
-    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
-
-    a_shared_tuple = tuple(
-        [
-            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
-            for i in range(symm_mem_hdl.world_size)
-        ]
-    )
+    group_name = dist.group.WORLD.group_name
+    symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)
 
     local_signal_pad = symm_mem_hdl.get_signal_pad(
         symm_mem_hdl.rank, dtype=torch.int32
@@ -208,8 +201,9 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
     return one_shot_all_reduce_kernel(
         signal_pad_addrs,
         local_signal_pad,
-        a_shared_tuple,
+        a_shared,
         my_rank=symm_mem_hdl.rank,
+        group_name=group_name,
     )
 
 
@@ -255,15 +249,17 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
     rank = dist.get_rank()
 
     # Create symmetric memory tensor for Helion implementation
+    symm_mem.enable_symm_mem_for_group(dist_group.group_name)
+    # TODO @kwen2501: no need to divide N
     a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
 
-    print(f"[Rank {rank}] Running Helion all-reduce...")
-    result_helion = helion_one_shot_all_reduce(a_shared)
-
     # Create symmetric memory tensor for reference implementation
     a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
     a_shared_ref.copy_(a_shared)
 
+    print(f"[Rank {rank}] Running Helion all-reduce...")
+    result_helion = helion_one_shot_all_reduce(a_shared)
+
     print(f"[Rank {rank}] Running reference all-reduce...")
     result_ref = reference_one_shot_all_reduce(a_shared_ref)
 
@@ -280,6 +276,8 @@ def main() -> None:
     Sets up the distributed environment, initializes CUDA devices, and runs the
     all-reduce test, and then clean up.
     """
+    # Only NVSHMEM backend implements `get_remote_tensor` for now.
+    symm_mem.set_backend("NVSHMEM")
     rank = int(os.environ["LOCAL_RANK"])
     torch.manual_seed(42 + rank)
     device = torch.device(f"cuda:{rank}")

test/test_examples_dist.expected

Lines changed: 6 additions & 6 deletions
@@ -121,9 +121,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha
     # src[all_reduce.py:N]: [tile_n.id, world],
     # src[all_reduce.py:N-N]: ...
     helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False)
-    # src[all_reduce.py:N]: acc = hl.zeros(
-    # src[all_reduce.py:N]:     [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
-    # src[all_reduce.py:N]: )
+    # src[all_reduce.py:N]: acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)
     acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16)
     # src[all_reduce.py:N]: acc += a[tile_n]
     load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0)
@@ -154,7 +152,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha
     # src[all_reduce.py:N-N]: ...
     helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True)
 
-def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher):
+def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared: torch.Tensor, my_rank: hl.constexpr, group_name: hl.constexpr, *, _launcher=_default_launcher):
     """
     Helion JIT-compiled kernel for one-shot all-reduce operation.
 
@@ -173,10 +171,12 @@ def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad:
     """
     # src[all_reduce.py:N]: _, world_size = local_signal_pad.size()
     _, world_size = local_signal_pad.size()
-    # src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0])
-    out = torch.empty_like(a_shared_tuple[0])
+    # src[all_reduce.py:N]: out = torch.empty_like(a_shared)
+    out = torch.empty_like(a_shared)
     # src[all_reduce.py:N]: N = out.size(0)
     N = out.size(0)
+    # src[all_reduce.py:N]: a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
+    a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, '0')
     # src[all_reduce.py:N]: for tile_n in hl.tile(N):
     _BLOCK_SIZE_0 = 8192
     _RDIM_SIZE_1 = 4

test/test_examples_dist.py

Lines changed: 13 additions & 8 deletions
@@ -109,20 +109,19 @@ def test_all_reduce(self):
 
         mod = import_path(EXAMPLES_DIR / "all_reduce.py")
 
+        # Only NVSHMEM backend implements `get_remote_tensor` for now.
+        symm_mem.set_backend("NVSHMEM")
+        group = dist.group.WORLD
+        symm_mem.enable_symm_mem_for_group(group.group_name)
+
         N = 16384
         dtype = torch.bfloat16
 
         a_shared = symm_mem.empty(
             N // self.world_size, dtype=dtype, device=self.device
         ).normal_()
 
-        symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
-        a_shared_tuple = tuple(
-            [
-                symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
-                for i in range(symm_mem_hdl.world_size)
-            ]
-        )
+        symm_mem_hdl = symm_mem.rendezvous(a_shared, group=group)
         local_signal_pad = symm_mem_hdl.get_signal_pad(
             symm_mem_hdl.rank, dtype=torch.int32
         ).view(-1, symm_mem_hdl.world_size)
@@ -135,7 +134,13 @@ def test_all_reduce(self):
 
         code, result = code_and_output(
             mod.one_shot_all_reduce_kernel,
-            (signal_pad_addrs, local_signal_pad, a_shared_tuple, symm_mem_hdl.rank),
+            (
+                signal_pad_addrs,
+                local_signal_pad,
+                a_shared,
+                symm_mem_hdl.rank,
+                group.group_name,
+            ),
         )
 
         if self.rank == 0:
