@@ -121,9 +121,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha
     # src[all_reduce.py:N]: [tile_n.id, world],
     # src[all_reduce.py:N-N]: ...
     helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_1 * 4 + indices_2 * 1), expect=1, update=0, sem='acquire', scope='sys', op='atomic_cas', skip_sync=False)
-    # src[all_reduce.py:N]: acc = hl.zeros(
-    # src[all_reduce.py:N]: [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
-    # src[all_reduce.py:N]: )
+    # src[all_reduce.py:N]: acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)
     acc = tl.full([_BLOCK_SIZE_0], 0.0, tl.bfloat16)
     # src[all_reduce.py:N]: acc += a[tile_n]
     load_1 = tl.load(a_shared_tuple_item_0 + indices_0 * 1, mask_0, other=0)
@@ -154,7 +152,7 @@ def _helion_one_shot_all_reduce_kernel(signal_pad_addrs, local_signal_pad, a_sha
     # src[all_reduce.py:N-N]: ...
     helion.runtime.triton_wait_multiple_signal(addr=local_signal_pad + (tile_id_2 * 4 + indices_3 * 1), expect=1, update=0, sem='relaxed', scope='sys', op='atomic_cas', skip_sync=True)

-def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared_tuple: tuple[torch.Tensor, ...], my_rank: hl.constexpr, *, _launcher=_default_launcher):
+def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad: torch.Tensor, a_shared: torch.Tensor, my_rank: hl.constexpr, group_name: hl.constexpr, *, _launcher=_default_launcher):
     """
     Helion JIT-compiled kernel for one-shot all-reduce operation.

@@ -173,10 +171,12 @@ def one_shot_all_reduce_kernel(signal_pad_addrs: torch.Tensor, local_signal_pad:
173171 """
174172 # src[all_reduce.py:N]: _, world_size = local_signal_pad.size()
175173 _, world_size = local_signal_pad.size()
176- # src[all_reduce.py:N]: out = torch.empty_like(a_shared_tuple[0] )
177- out = torch.empty_like(a_shared_tuple[0] )
174+ # src[all_reduce.py:N]: out = torch.empty_like(a_shared )
175+ out = torch.empty_like(a_shared )
178176 # src[all_reduce.py:N]: N = out.size(0)
179177 N = out.size(0)
178+ # src[all_reduce.py:N]: a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
179+ a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, '0')
180180 # src[all_reduce.py:N]: for tile_n in hl.tile(N):
181181 _BLOCK_SIZE_0 = 8192
182182 _RDIM_SIZE_1 = 4
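For context on what the reworked wrapper still computes: the kernel is a one-shot all-reduce, so each rank reads every peer's shared buffer (now fetched inside the wrapper via torch.ops.symm_mem.get_remote_tensors) and accumulates an elementwise sum into its local output. The snippet below is a minimal single-process reference for that semantics only, not the kernel's launch path; the plain CPU tensors, world_size = 4 (which appears to match the 4-wide signal pad and _RDIM_SIZE_1 in the diff), and N = 16384 are illustrative stand-ins for the real symmetric-memory buffers.

import torch

world_size = 4   # stand-in for the process-group size
N = 16384        # illustrative element count; the real kernel tiles this with _BLOCK_SIZE_0 = 8192

# Stand-ins for the per-rank buffers that get_remote_tensors would return.
a_shared_tuple = tuple(torch.randn(N, dtype=torch.bfloat16) for _ in range(world_size))

# acc = hl.zeros([tile_n], ...) followed by acc += a[tile_n] for every rank
# reduces to an elementwise sum across the world.
acc = torch.zeros(N, dtype=torch.bfloat16)
for a in a_shared_tuple:
    acc = acc + a

# In the real kernel, every rank writes this same sum into its own `out` buffer.
print(acc[:4])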