@@ -91,8 +91,9 @@ def dev_array_to_tensor_short(
 def one_shot_all_reduce_kernel(
     signal_pad_addrs: torch.Tensor,
     local_signal_pad: torch.Tensor,
-    a_shared_tuple: tuple[torch.Tensor, ...],
+    a_shared: torch.Tensor,
     my_rank: hl.constexpr,
+    group_name: hl.constexpr,
 ) -> torch.Tensor:
     """
     Helion JIT-compiled kernel for one-shot all-reduce operation.
@@ -112,8 +113,9 @@ def one_shot_all_reduce_kernel(
112113 """
113114 _ , world_size = local_signal_pad .size ()
114115 world_size = hl .specialize (world_size )
115- out = torch .empty_like (a_shared_tuple [ 0 ] )
116+ out = torch .empty_like (a_shared )
116117 N = out .size (0 )
118+ a_shared_tuple = torch .ops .symm_mem .get_remote_tensors (a_shared , group_name )
117119
118120 for tile_n in hl .tile (N ):
119121 # Sync all devices through signal_pad to make sure
@@ -138,9 +140,7 @@ def one_shot_all_reduce_kernel(
                 scope="sys",
             )

-            acc = hl.zeros(
-                [tile_n], dtype=a_shared_tuple[0].dtype, device=local_signal_pad.device
-            )
+            acc = hl.zeros([tile_n], dtype=a_shared.dtype, device=local_signal_pad.device)

             for a in a_shared_tuple:
                 acc += a[tile_n]
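
Taken together, the kernel-side change means the host no longer builds the per-rank buffer tuple; the kernel receives only the local symmetric tensor plus the group name and materializes the peer views itself. Below is a minimal sketch of that data flow for illustration only, outside the Helion tiling and the signal/wait synchronization shown above; `a_shared` and `group_name` are the kernel parameters introduced in this change, and `remote_bufs`/`acc` are names invented for the sketch.

# Sketch: what the kernel now does conceptually with the two new arguments.
remote_bufs = torch.ops.symm_mem.get_remote_tensors(a_shared, group_name)  # one view per rank
acc = torch.zeros_like(a_shared)
for buf in remote_bufs:
    acc += buf  # sum the same elements from every rank's symmetric buffer
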
@@ -183,15 +183,8 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
         Tensor containing the all-reduced result (sum across all devices)
     """
     assert dist.group.WORLD is not None
-
-    symm_mem_hdl = symm_mem.rendezvous(a_shared, group=dist.group.WORLD)
-
-    a_shared_tuple = tuple(
-        [
-            symm_mem_hdl.get_buffer(i, tuple(a_shared.shape), a_shared.dtype)
-            for i in range(symm_mem_hdl.world_size)
-        ]
-    )
+    group_name = dist.group.WORLD.group_name
+    symm_mem_hdl = symm_mem.rendezvous(a_shared, group_name)

     local_signal_pad = symm_mem_hdl.get_signal_pad(
         symm_mem_hdl.rank, dtype=torch.int32
@@ -207,8 +200,9 @@ def helion_one_shot_all_reduce(a_shared: torch.Tensor) -> torch.Tensor:
     return one_shot_all_reduce_kernel(
         signal_pad_addrs,
         local_signal_pad,
-        a_shared_tuple,
+        a_shared,
         my_rank=symm_mem_hdl.rank,
+        group_name=group_name,
     )


@@ -254,15 +248,16 @@ def test(N: int, device: torch.device, dtype: torch.dtype) -> None:
     rank = dist.get_rank()

     # Create symmetric memory tensor for Helion implementation
+    symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)
     a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

-    print(f"[Rank {rank}] Running Helion all-reduce...")
-    result_helion = helion_one_shot_all_reduce(a_shared)
-
     # Create symmetric memory tensor for reference implementation
     a_shared_ref = symm_mem.empty(N // world_size, dtype=dtype, device=device)
     a_shared_ref.copy_(a_shared)

+    print(f"[Rank {rank}] Running Helion all-reduce...")
+    result_helion = helion_one_shot_all_reduce(a_shared)
+
     print(f"[Rank {rank}] Running reference all-reduce...")
     result_ref = reference_one_shot_all_reduce(a_shared_ref)

@@ -279,6 +274,8 @@ def main() -> None:
     Sets up the distributed environment, initializes CUDA devices, and runs the
     all-reduce test, and then clean up.
     """
+    # Only NVSHMEM backend implements `get_remote_tensors` for now.
+    symm_mem.set_backend("NVSHMEM")
     rank = int(os.environ["LOCAL_RANK"])
     torch.manual_seed(42 + rank)
     device = torch.device(f"cuda:{rank}")
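
For completeness, here is the host-side setup implied across these hunks, consolidated into one sketch. The symm_mem and dist calls are the ones used in the diff; the import alias, the process-group bootstrap, the function name `run`, and the torchrun launch line are assumptions about how this example is typically driven, not part of the change itself.

# Launch (assumed): torchrun --nproc-per-node=<num_gpus> <example>.py
import os
import torch
import torch.distributed as dist
import torch.distributed._symmetric_memory as symm_mem  # assumed import path

def run(N: int, dtype: torch.dtype) -> None:
    symm_mem.set_backend("NVSHMEM")  # get_remote_tensors currently needs the NVSHMEM backend
    rank = int(os.environ["LOCAL_RANK"])
    device = torch.device(f"cuda:{rank}")
    torch.cuda.set_device(device)
    dist.init_process_group("nccl")
    world_size = dist.get_world_size()

    # Symmetric memory must be enabled for the group before buffers are allocated.
    symm_mem.enable_symm_mem_for_group(dist.group.WORLD.group_name)
    a_shared = symm_mem.empty(N // world_size, dtype=dtype, device=device).normal_()

    result = helion_one_shot_all_reduce(a_shared)  # defined earlier in the example
    print(f"[Rank {rank}] all-reduce output shape: {tuple(result.shape)}")

    dist.destroy_process_group()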