@@ -92,8 +92,9 @@ def dev_array_to_tensor_short(
 def one_shot_all_reduce_kernel(
     signal_pad_addrs: torch.Tensor,
     local_signal_pad: torch.Tensor,
-    a_shared_tuple: tuple[torch.Tensor, ...],
+    a_shared: torch.Tensor,
     my_rank: hl.constexpr,
+    group_name: hl.constexpr,
 ) -> torch.Tensor:
     """
     Helion JIT-compiled kernel for one-shot all-reduce operation.
@@ -113,8 +114,9 @@ def one_shot_all_reduce_kernel(
     """
     _, world_size = local_signal_pad.size()
     world_size = hl.specialize(world_size)
-    out = torch.empty_like(a_shared_tuple[0])
+    out = torch.empty_like(a_shared)
     N = out.size(0)
+    a_shared_tuple = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)
 
     for tile_n in hl.tile(N):
         # Sync all devices through signal_pad to make sure
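
For context: `torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)` hands the kernel one view of the symmetric buffer per rank in the group, so the per-rank tuple no longer has to be assembled on the host. A rough sketch of the host-side equivalent it replaces, using the handle API removed later in this diff (hedged; exact semantics may differ):

    # Illustrative only: the old pattern built the tuple of per-rank views by hand.
    symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)
    a_shared_tuple = tuple(
        symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
        for i in range(symm_mem_hdl.world_size)
    )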
@@ -139,9 +141,7 @@ def one_shot_all_reduce_kernel(
             scope="sys",
         )
 
-        acc = hl.zeros(
-            [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
-        )
+        acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)
 
         for a in a_shared_tuple:
             acc += a[tile_n]
@@ -184,15 +184,8 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
         Tensor containing the all-reduced result (sum across all devices)
     """
     assert dist.group.WORLD is not None
-
-    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
-
-    a_shared_tuple = tuple(
-        [
-            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
-            for i in range(symm_mem_hdl.world_size)
-        ]
-    )
+    group_name = dist.group.WORLD.group_name
+    symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)
 
     local_signal_pad = symm_mem_hdl.get_signal_pad(
         symm_mem_hdl.rank, dtype=torch.int32
@@ -208,8 +201,9 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
     return one_shot_all_reduce_kernel(
         signal_pad_addrs,
         local_signal_pad,
-        a_shared_tuple,
+        a_shared,
         my_rank=symm_mem_hdl.rank,
+        group_name=group_name,
     )
 
 
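Taken together with the test changes below, the calling pattern for the updated wrapper looks roughly like the following. This is a minimal sketch, not part of the commit: it assumes the default process group is already initialized (e.g. `dist.init_process_group` under torchrun) and that `symm_mem` refers to `torch.distributed._symmetric_memory`, as in the rest of this file.

    import torch
    import torch.distributed as dist
    import torch.distributed._symmetric_memory as symm_mem  # assumed import, as in this example

    symm_mem.set_backend("NVSHMEM")                 # see main() below
    group_name = dist.group.WORLD.group_name
    symm_mem.enable_symm_mem_for_group(group_name)  # see test() below
    a_shared = symm_mem.empty(4096, dtype=torch.bfloat16, device="cuda").normal_()
    out = helion_one_shot_all_reduce(a_shared)      # rendezvous + kernel launch happen inside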
@@ -255,15 +249,17 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
     rank = dist.get_rank()
 
     # Create symmetric memory tensor for Helion implementation
+    symm_mem.enable_symm_mem_for_group(dist_group.group_name)
+    # TODO @kwen2501: no need to divide N
     a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()
 
-    print(f"[Rank {rank}] Running Helion all-reduce...")
-    result_helion = helion_one_shot_all_reduce(a_shared)
-
     # Create symmetric memory tensor for reference implementation
     a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
     a_shared_ref.copy_(a_shared)
 
+    print(f"[Rank {rank}] Running Helion all-reduce...")
+    result_helion = helion_one_shot_all_reduce(a_shared)
+
     print(f"[Rank {rank}] Running reference all-reduce...")
     result_ref = reference_one_shot_all_reduce(a_shared_ref)
 
@@ -280,6 +276,8 @@ def main() -> None:
     Sets up the distributed environment, initializes CUDA devices, and runs the
     all-reduce test, and then clean up.
     """
+    # Only NVSHMEM backend implements `get_remote_tensor` for now.
+    symm_mem.set_backend("NVSHMEM")
     rank = int(os.environ["LOCAL_RANK"])
     torch.manual_seed(42 + rank)
     device = torch.device(f"cuda:{rank}")