Skip to content

Commit 47f5294

Browse files
kwen2501 authored and yf225 committed
Get remote tensors inside Helion kernel
1 parent 913f7c7 commit 47f5294

File tree

1 file changed

+15
-18
lines changed

1 file changed

+15
-18
lines changed

examples/all_reduce.py

Lines changed: 15 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -91,8 +91,9 @@ def dev_array_to_tensor_short(
9191
def one_shot_all_reduce_kernel(
9292
signal_pad_addrs: torch.Tensor,
9393
local_signal_pad: torch.Tensor,
94-
a_shared_tuple: tuple[torch.Tensor, ...],
94+
a_shared: torch.Tensor,
9595
my_rank: hl.constexpr,
96+
group_name: hl.constexpr,
9697
) -> torch.Tensor:
9798
"""
9899
Helion JIT-compiled kernel for one-shot all-reduce operation.
@@ -112,8 +113,9 @@ def one_shot_all_reduce_kernel(
112113
"""
113114
_, world_size = local_signal_pad.size()
114115
world_size = hl.specialize(world_size)
115-
out = torch.empty_like(a_shared_tuple[0])
116+
out = torch.empty_like(a_shared)
116117
N = out.size(0)
118+
a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
117119

118120
for tile_n in hl.tile(N):
119121
# Sync all devices through signal_pad to make sure
@@ -138,9 +140,7 @@ def one_shot_all_reduce_kernel(
138140
scope="sys",
139141
)
140142

141-
acc = hl.zeros(
142-
[tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
143-
)
143+
acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)
144144

145145
for a in a_shared_tuple:
146146
acc += a[tile_n]
@@ -183,15 +183,8 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
183183
Tensor containing the all-reduced result (sum across all devices)
184184
"""
185185
assert dist.group.WORLD is not None
186-
187-
symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
188-
189-
a_shared_tuple = tuple(
190-
[
191-
symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
192-
for i in range(symm_mem_hdl.world_size)
193-
]
194-
)
186+
group_name = dist.group.WORLD.group_name
187+
symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)
195188

196189
local_signal_pad = symm_mem_hdl.get_signal_pad(
197190
symm_mem_hdl.rank, dtype=torch.int32
@@ -207,8 +200,9 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
207200
return one_shot_all_reduce_kernel(
208201
signal_pad_addrs,
209202
local_signal_pad,
210-
a_shared_tuple,
203+
a_shared,
211204
my_rank=symm_mem_hdl.rank,
205+
group_name=group_name,
212206
)
213207

214208

@@ -254,15 +248,16 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
254248
rank = dist.get_rank()
255249

256250
# Create symmetric memory tensor for Helion implementation
251+
symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)
257252
a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
258253

259-
print(f"[Rank {rank}] Running Helion all-reduce...")
260-
result_helion = helion_one_shot_all_reduce(a_shared)
261-
262254
# Create symmetric memory tensor for reference implementation
263255
a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
264256
a_shared_ref.copy_(a_shared)
265257

258+
print(f"[Rank {rank}] Running Helion all-reduce...")
259+
result_helion = helion_one_shot_all_reduce(a_shared)
260+
266261
print(f"[Rank {rank}] Running reference all-reduce...")
267262
result_ref = reference_one_shot_all_reduce(a_shared_ref)
268263

@@ -279,6 +274,8 @@ def main() -> None:
279274
Sets up the distributed environment, initializes CUDA devices, and runs the
280275
all-reduce test, and then clean up.
281276
"""
277+
# Only NVSHMEM backend implements `get_remote_tensor` for now.
278+
symm_mem.set_backend("NVSHMEM")
282279
rank = int(os.environ["LOCAL_RANK"])
283280
torch.manual_seed(42 + rank)
284281
device = torch.device(f"cuda:{rank}")

0 commit comments

Comments (0)