diff --git a/examples/all_reduce.py b/examples/all_reduce.py index afb27f326..9178c9298 100644 --- a/examples/all_reduce.py +++ b/examples/all_reduce.py @@ -92,8 +92,9 @@ def dev_array_to_tensor_short( def one_shot_all_reduce_kernel( signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, - a_shared_tuple: tuple[torch.Tensor, ...], + a_shared: torch.Tensor, my_rank: hl.constexpr, + group_name: hl.constexpr, ) -> torch.Tensor: """ Helion JIT-compiled kernel for one-shot all-reduce operation. @@ -113,8 +114,9 @@ def one_shot_all_reduce_kernel( """ _, world_size = local_signal_pad.size() world_size = hl.specialize(world_size) - out = torch.empty_like(a_shared_tuple[0]) + out = torch.empty_like(a_shared) N = out.size(0) + a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name) for tile_n in hl.tile(N): # Sync all devices through signal_pad to make sure @@ -139,9 +141,7 @@ def one_shot_all_reduce_kernel( scope="sys", ) - acc = hl.zeros( - [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device - ) + acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device) for a in a_shared_tuple: acc += a[tile_n] @@ -184,15 +184,8 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor: Tensor containing the all-reduced result (sum across all devices) """ assert dist.group.WORLD is not None - - symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD) - - a_shared_tuple = tuple( - [ - symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype) - for i in range(symm_mem_hdl.world_size) - ] - ) + group_name = dist.group.WORLD.group_name + symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name) local_signal_pad = symm_mem_hdl.get_signal_pad( symm_mem_hdl.rank, dtype=torch.int32 @@ -208,8 +201,9 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor: return one_shot_all_reduce_kernel( signal_pad_addrs, local_signal_pad, - a_shared_tuple, + a_shared, my_rank=symm_mem_hdl.rank, + group_name=group_name, ) @@ -255,15 +249,17 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None: rank = dist.get_rank() # Create symmetric memory tensor for Helion implementation + symm_mem.enable_symm_mem_for_group(dist_group.group_name) + # TODO @kwen2501: no need to divide N a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_() - print(f"[Rank {rank}] Running Helion all-reduce...") - result_helion = helion_one_shot_all_reduce(a_shared) - # Create symmetric memory tensor for reference implementation a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device) a_shared_ref.copy_(a_shared) + print(f"[Rank {rank}] Running Helion all-reduce...") + result_helion = helion_one_shot_all_reduce(a_shared) + print(f"[Rank {rank}] Running reference all-reduce...") result_ref = reference_one_shot_all_reduce(a_shared_ref) @@ -280,6 +276,8 @@ def main() -> None: Sets up the distributed environment, initializes CUDA devices, and runs the all-reduce test, and then clean up. """ + # Only NVSHMEM backend implements `get_remote_tensor` for now. 
+ symm_mem.set_backend("NVSHMEM") rank = int(os.environ["LOCAL_RANK"]) torch.manual_seed(42 + rank) device = torch.device(f"cuda:{rank}") diff --git a/test/test_examples_dist.expected b/test/test_examples_dist.expected index dadbf473a..2c31cac93 100644 --- a/test/test_examples_dist.expected +++ b/test/test_examples_dist.expected @@ -121,9 +121,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha # src[all_reduce.py:N]: [tile_n.id, world], # src[all_reduce.py:N-N]: ... helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False) - # src[all_reduce.py:N]: acc = hl.zeros( - # src[all_reduce.py:N]: [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device - # src[all_reduce.py:N]: ) + # src[all_reduce.py:N]: acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device) acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16) # src[all_reduce.py:N]: acc += a[tile_n] load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0) @@ -154,7 +152,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha # src[all_reduce.py:N-N]: ... helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True) -def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher): +def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared: torch.Tensor, my_rank: hl.constexpr, group_name: hl.constexpr, *, _launcher=_default_launcher): """ Helion JIT-compiled kernel for one-shot all-reduce operation. @@ -173,10 +171,12 @@ def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: """ # src[all_reduce.py:N]: _, world_size = local_signal_pad.size() _, world_size = local_signal_pad.size() - # src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0]) - out = torch.empty_like(a_shared_tuple[0]) + # src[all_reduce.py:N]: out = torch.empty_like(a_shared) + out = torch.empty_like(a_shared) # src[all_reduce.py:N]: N = out.size(0) N = out.size(0) + # src[all_reduce.py:N]: a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name) + a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, '0') # src[all_reduce.py:N]: for tile_n in hl.tile(N): _BLOCK_SIZE_0 = 8192 _RDIM_SIZE_1 = 4 diff --git a/test/test_examples_dist.py b/test/test_examples_dist.py index 54afaa172..dfa00c704 100644 --- a/test/test_examples_dist.py +++ b/test/test_examples_dist.py @@ -109,6 +109,11 @@ def test_all_reduce(self): mod = import_path(EXAMPLES_DIR / "all_reduce.py") + # Only NVSHMEM backend implements `get_remote_tensor` for now. 
+ symm_mem.set_backend("NVSHMEM") + group = dist.group.WORLD + symm_mem.enable_symm_mem_for_group(group.group_name) + N = 16384 dtype = torch.bfloat16 @@ -116,13 +121,7 @@ def test_all_reduce(self): N // self.world_size, dtype=dtype, device=self.device ).normal_() - symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD) - a_shared_tuple = tuple( - [ - symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype) - for i in range(symm_mem_hdl.world_size) - ] - ) + symm_mem_hdl = symm_mem.rendezvous(a_shared, group=group) local_signal_pad = symm_mem_hdl.get_signal_pad( symm_mem_hdl.rank, dtype=torch.int32 ).view(-1, symm_mem_hdl.world_size) @@ -135,7 +134,13 @@ def test_all_reduce(self): code, result = code_and_output( mod.one_shot_all_reduce_kernel, - (signal_pad_addrs, local_signal_pad, a_shared_tuple, symm_mem_hdl.rank), + ( + signal_pad_addrs, + local_signal_pad, + a_shared, + symm_mem_hdl.rank, + group.group_name, + ), ) if self.rank == 0:
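Reviewer note (not part of the patch): a minimal standalone sketch of the call sequence this change relies on, assuming the example's usual `import torch.distributed._symmetric_memory as symm_mem` and a process group launched via torchrun. The file name, buffer size, and variable names below are illustrative only; anything not shown in the diff (e.g. the launch command) is an assumption, not part of this PR.

# sketch.py -- run with something like: torchrun --nproc-per-node=<world_size> sketch.py
# Assumes this PyTorch build provides symm_mem.set_backend and the
# torch.ops.symm_mem.get_remote_tensors op exercised by the patch above.
import os

import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem

rank = int(os.environ["LOCAL_RANK"])
device = torch.device(f"cuda:{rank}")
torch.cuda.set_device(device)
dist.init_process_group("nccl")

# Only the NVSHMEM backend implements get_remote_tensors today, hence the
# explicit backend selection and per-group enablement added by this patch.
symm_mem.set_backend("NVSHMEM")
group_name = dist.group.WORLD.group_name
symm_mem.enable_symm_mem_for_group(group_name)

# Allocate a symmetric buffer and establish the rendezvous, as the example now does.
a_shared = symm_mem.empty(8192, dtype=torch.bfloat16, device=device).normal_()
symm_mem.rendezvous(a_shared, group_name)
dist.barrier()  # ensure every peer has finished writing before anyone reads

# One tensor per rank, each aliasing that peer's copy of a_shared; summing them is the
# same one-shot all-reduce the Helion kernel performs tile by tile with signal pads.
remote_tensors = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
result = torch.stack(list(remote_tensors)).sum(dim=0)

dist.destroy_process_group()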