@@ -155,6 +155,9 @@ def Get_size(self) -> int: ...
     @abstractmethod
     def allgather(self, data: int) -> List[int]: ...
 
+    @abstractmethod
+    def barrier(self) -> None: ...
+
     @abstractmethod
     def Split(self, color: int, key: int) -> "CommBackend": ...
 
@@ -209,6 +212,9 @@ def Get_size(self) -> int:
     def allgather(self, data: int) -> List[int]:
         return self._mpicomm.allgather(data)
 
+    def barrier(self):
+        self._mpicomm.Barrier()
+
     def Split(self, color: int, key: int) -> CommBackend:
         self._mpicomm = self._mpicomm.Split(color, key)
         return MPIBackend()  # Returns new adapter
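For context, the new `barrier()` hook only asks a backend to block until every rank in the group has reached the call; the MPI adapter above maps it onto `Barrier()`. Below is a minimal, hypothetical sketch of how a non-MPI backend could satisfy the same interface, assuming an already-initialized `torch.distributed` process group and that `CommBackend` declares only the methods visible in this diff plus `Get_rank`; the `TorchDistBackend` name is not part of this commit.

```python
from typing import List, Optional

import torch.distributed as dist


class TorchDistBackend(CommBackend):
    """Hypothetical adapter mapping CommBackend onto torch.distributed."""

    def __init__(self, group: Optional[dist.ProcessGroup] = None):
        self._group = group  # None means the default (world) process group

    def Get_rank(self) -> int:
        return dist.get_rank(self._group)

    def Get_size(self) -> int:
        return dist.get_world_size(self._group)

    def allgather(self, data: int) -> List[int]:
        out: List[int] = [0] * self.Get_size()
        dist.all_gather_object(out, data, group=self._group)
        return out  # one entry per rank, in rank order

    def barrier(self) -> None:
        # Block until every rank in the group reaches this point.
        dist.barrier(group=self._group)

    def Split(self, color: int, key: int) -> "CommBackend":
        # A faithful color/key split needs rank-wide coordination; omitted here.
        raise NotImplementedError
```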
@@ -547,14 +553,15 @@ def supports_mnnvl() -> bool:
 
 class McastDeviceMemory:
     """Python port of McastDeviceMemory from TensorRT-LLM"""
+
     def __init__(
         self,
         buf_size: int,
         group_size: int,
         group_rank: int,
         device_idx: int,
         is_multi_node: bool = True,
-        comm: Optional[CommBackend] = None,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     ):
         cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))
 
@@ -631,7 +638,7 @@ def __init__(
                     "[McastDeviceMemory] Device does not support fabric handle."
                 )
 
-            self._alloc_mn_mcast_mem(buf_size, comm)
+            self._alloc_mn_mcast_mem(buf_size, comm_backend_for_handle_transfer)
         else:
             # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
             raise NotImplementedError("Single-node NVLS allocation not implemented yet")
@@ -753,7 +760,9 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size
 
-    def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any = MpiComm()):
+    def _alloc_mn_mcast_mem(
+        self, buf_size: int, comm_backend_for_handle_transfer: Any = None
+    ):
         """Allocate multi-node multicast memory using MNNVL"""
 
         # Verify CUDA context
@@ -766,7 +775,10 @@ def _alloc_mn_mcast_mem(self, buf_size: int, comm: Any = MpiComm()):
             )
         except Exception as e:
             print(f"Error checking CUDA context: {e}")
-
+        if comm_backend_for_handle_transfer is None:
+            comm = MpiComm()
+        else:
+            comm = comm_backend_for_handle_transfer
         # Set up allocation properties
         handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
 
@@ -966,7 +978,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
-        comm: Optional[CommBackend] = None,
+        comm_backend_for_handle_transfer: Optional[CommBackend] = None,
     ):
         """
         Constructor for McastGpuBuffer.
@@ -977,9 +989,15 @@ def __init__(
             group_rank: The rank of the local process within the group
             device: The CUDA device for buffer allocation
             mn_nvlink: Flag indicating if multi-node NVLink is used
+            comm_backend_for_handle_transfer: Communication backend for handle transfer
         """
         self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink, comm
+            buf_size,
+            group_size,
+            group_rank,
+            device.index,
+            mn_nvlink,
+            comm_backend_for_handle_transfer,
         )
         self.buf_size = buf_size
         self.local_device = device
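Taken together, the plumbing above lets callers inject the backend used for the fabric-handle exchange instead of always going through MPI; when `comm_backend_for_handle_transfer` is left as `None`, `_alloc_mn_mcast_mem` still falls back to `MpiComm()`. A hypothetical usage sketch, assuming the constructor keywords shown in this diff and an `MPIBackend` exposing `Get_rank`/`Get_size`:

```python
import torch

comm = MPIBackend()  # adapter shown earlier in this diff

buf = McastGpuBuffer(
    buf_size=1 << 20,                       # 1 MiB multicast buffer
    group_size=comm.Get_size(),             # total ranks in the group
    group_rank=comm.Get_rank(),             # this process's rank
    device=torch.device("cuda", 0),
    mn_nvlink=True,                         # multi-node NVLink path
    comm_backend_for_handle_transfer=comm,  # omit (None) to fall back to MpiComm()
)
```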