34 changes: 16 additions & 18 deletions examples/all_reduce.py
@@ -92,8 +92,9 @@ def dev_array_to_tensor_short(
def one_shot_all_reduce_kernel(
signal_pad_addrs: torch.Tensor,
local_signal_pad: torch.Tensor,
a_shared_tuple: tuple[torch.Tensor, ...],
a_shared: torch.Tensor,
my_rank: hl.constexpr,
group_name: hl.constexpr,
) -> torch.Tensor:
"""
Helion JIT-compiled kernel for one-shot all-reduce operation.
@@ -113,8 +114,9 @@ def one_shot_all_reduce_kernel(
"""
_, world_size = local_signal_pad.size()
world_size = hl.specialize(world_size)
out = torch.empty_like(a_shared_tuple[0])
out = torch.empty_like(a_shared)
N = out.size(0)
a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)

for tile_n in hl.tile(N):
# Sync all devices through signal_pad to make sure
@@ -139,9 +141,7 @@ def one_shot_all_reduce_kernel(
scope="sys",
)

acc = hl.zeros(
[tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
)
acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)

for a in a_shared_tuple:
acc += a[tile_n]
@@ -184,15 +184,8 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
Tensor containing the all-reduced result (sum across all devices)
"""
assert dist.group.WORLD is not None

symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)

a_shared_tuple = tuple(
[
symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
for i in range(symm_mem_hdl.world_size)
]
)
group_name = dist.group.WORLD.group_name
symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)

local_signal_pad = symm_mem_hdl.get_signal_pad(
symm_mem_hdl.rank, dtype=torch.int32
@@ -208,8 +201,9 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
return one_shot_all_reduce_kernel(
signal_pad_addrs,
local_signal_pad,
a_shared_tuple,
a_shared,
my_rank=symm_mem_hdl.rank,
group_name=group_name,
)


@@ -255,15 +249,17 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
rank = dist.get_rank()

# Create symmetric memory tensor for Helion implementation
symm_mem.enable_symm_mem_for_group(dist_group.group_name)
# TODO @kwen2501: no need to divide N
a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

print(f"[Rank {rank}] Running Helion all-reduce...")
result_helion = helion_one_shot_all_reduce(a_shared)

# Create symmetric memory tensor for reference implementation
a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
a_shared_ref.copy_(a_shared)

print(f"[Rank {rank}] Running Helion all-reduce...")
result_helion = helion_one_shot_all_reduce(a_shared)

print(f"[Rank {rank}] Running reference all-reduce...")
result_ref = reference_one_shot_all_reduce(a_shared_ref)

@@ -280,6 +276,8 @@ def main() -> None:
Sets up the distributed environment, initializes CUDA devices, and runs the
    all-reduce test, and then cleans up.
"""
# Only NVSHMEM backend implements `get_remote_tensor` for now.
symm_mem.set_backend("NVSHMEM")
rank = int(os.environ["LOCAL_RANK"])
torch.manual_seed(42 + rank)
device = torch.device(f"cuda:{rank}")
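For orientation, the net effect of the change to examples/all_reduce.py is that the kernel now takes the local symmetric tensor plus the process-group name and resolves the per-rank views itself via torch.ops.symm_mem.get_remote_tensors, rather than receiving a pre-built tuple of remote buffers. Below is a minimal driver sketch distilled from the test() function in this file; the world size, tensor size, dtype, tolerances, the symm_mem import path, and the final check against dist.all_reduce (used here as a loose sanity check instead of reference_one_shot_all_reduce) are all illustrative assumptions.

# Sketch only: run with one process per GPU (e.g. via torchrun); assumes
# helion_one_shot_all_reduce from examples/all_reduce.py is importable.
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem  # assumed import path

from all_reduce import helion_one_shot_all_reduce  # hypothetical import

symm_mem.set_backend("NVSHMEM")  # only NVSHMEM implements get_remote_tensor for now

rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl")

group = dist.group.WORLD
symm_mem.enable_symm_mem_for_group(group.group_name)

# Each rank contributes one shard of the symmetric buffer.
world_size = dist.get_world_size()
a_shared = symm_mem.empty(
    16384 // world_size, dtype=torch.bfloat16, device=device
).normal_()

result = helion_one_shot_all_reduce(a_shared)

# Loose sanity check against the regular collective (summation order differs,
# so tolerances are generous for bfloat16).
expected = a_shared.clone()
dist.all_reduce(expected, group=group)
torch.testing.assert_close(result, expected, rtol=1e-1, atol=1e-1)

dist.destroy_process_group()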
12 changes: 6 additions & 6 deletions test/test_examples_dist.expected
@@ -121,9 +121,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha
# src[all_reduce.py:N]: [tile_n.id, world],
# src[all_reduce.py:N-N]: ...
helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False)
# src[all_reduce.py:N]: acc = hl.zeros(
# src[all_reduce.py:N]: [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
# src[all_reduce.py:N]: )
# src[all_reduce.py:N]: acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)
acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16)
# src[all_reduce.py:N]: acc += a[tile_n]
load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0)
@@ -154,7 +152,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha
# src[all_reduce.py:N-N]: ...
helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True)

def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher):
def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared: torch.Tensor, my_rank: hl.constexpr, group_name: hl.constexpr, *, _launcher=_default_launcher):
"""
Helion JIT-compiled kernel for one-shot all-reduce operation.

@@ -173,10 +171,12 @@ def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad:
"""
# src[all_reduce.py:N]: _, world_size = local_signal_pad.size()
_, world_size = local_signal_pad.size()
# src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0])
out = torch.empty_like(a_shared_tuple[0])
# src[all_reduce.py:N]: out = torch.empty_like(a_shared)
out = torch.empty_like(a_shared)
# src[all_reduce.py:N]: N = out.size(0)
N = out.size(0)
# src[all_reduce.py:N]: a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, '0')
# src[all_reduce.py:N]: for tile_n in hl.tile(N):
_BLOCK_SIZE_0 = 8192
_RDIM_SIZE_1 = 4
21 changes: 13 additions & 8 deletions test/test_examples_dist.py
@@ -109,20 +109,19 @@ def test_all_reduce(self):

mod = import_path(EXAMPLES_DIR / "all_reduce.py")

# Only NVSHMEM backend implements `get_remote_tensor` for now.
symm_mem.set_backend("NVSHMEM")
group = dist.group.WORLD
symm_mem.enable_symm_mem_for_group(group.group_name)

N = 16384
dtype = torch.bfloat16

a_shared = symm_mem.empty(
N // self.world_size, dtype=dtype, device=self.device
).normal_()

symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
a_shared_tuple = tuple(
[
symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
for i in range(symm_mem_hdl.world_size)
]
)
symm_mem_hdl = symm_mem.rendezvous(a_shared, group=group)
local_signal_pad = symm_mem_hdl.get_signal_pad(
symm_mem_hdl.rank, dtype=torch.int32
).view(-1, symm_mem_hdl.world_size)
@@ -135,7 +134,13 @@ def test_all_reduce(self):

code, result = code_and_output(
mod.one_shot_all_reduce_kernel,
(signal_pad_addrs, local_signal_pad, a_shared_tuple, symm_mem_hdl.rank),
(
signal_pad_addrs,
local_signal_pad,
a_shared,
symm_mem_hdl.rank,
group.group_name,
),
)

if self.rank == 0:
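The test change mirrors the example: the host-side get_buffer loop is gone, and the kernel is handed the local tensor plus the group name. For readers unfamiliar with the op, the sketch below contrasts the two ways of obtaining per-rank views of a symmetric tensor; both call forms come from this diff, while the surrounding setup (tensor size, import path, group handling) is an illustrative assumption.

# Sketch: two ways to obtain per-rank views of a symmetric-memory tensor.
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem  # assumed import path

group = dist.group.WORLD
a_shared = symm_mem.empty(1024, dtype=torch.bfloat16, device="cuda")  # illustrative size

# Before this PR: build the tuple on the host from the rendezvous handle.
hdl = symm_mem.rendezvous(a_shared, group=group)
remote_old = tuple(
    hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
    for i in range(hdl.world_size)
)

# After this PR: a single op, callable from inside the Helion kernel, is
# intended to yield the same per-rank views (NVSHMEM backend only for now).
remote_new = torch.ops.symm_mem.get_remote_tensors(a_shared, group.group_name)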