@@ -547,14 +547,14 @@ def supports_mnnvl() -> bool:
547547
548548class McastDeviceMemory :
549549 """Python port of McastDeviceMemory from TensorRT-LLM"""
550-
551550 def __init__ (
552551 self ,
553552 buf_size : int ,
554553 group_size : int ,
555554 group_rank : int ,
556555 device_idx : int ,
557556 is_multi_node : bool = True ,
557+ comm : Optional [CommBackend ] = None ,
558558 ):
559559 cu_device = checkCudaErrors (cuda .cuDeviceGet (device_idx ))
560560
@@ -631,7 +631,7 @@ def __init__(
631631 "[McastDeviceMemory] Device does not support fabric handle."
632632 )
633633
634- self ._alloc_mn_mcast_mem (buf_size )
634+ self ._alloc_mn_mcast_mem (buf_size , comm )
635635 else :
636636 # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
637637 raise NotImplementedError ("Single-node NVLS allocation not implemented yet" )
@@ -753,7 +753,7 @@ def get_world_size(self) -> int:
753753 """Get the total number of devices in the group"""
754754 return self .group_size
755755
756- def _alloc_mn_mcast_mem (self , buf_size : int ):
756+ def _alloc_mn_mcast_mem (self , buf_size : int , comm : Any = MpiComm () ):
757757 """Allocate multi-node multicast memory using MNNVL"""
758758
759759 # Verify CUDA context
@@ -767,9 +767,6 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
767767 except Exception as e :
768768 print (f"Error checking CUDA context: { e } " )
769769
770- # Get MPI communicator
771- comm = MpiComm ()
772-
773770 # Set up allocation properties
774771 handle_type = cuda .CUmemAllocationHandleType .CU_MEM_HANDLE_TYPE_FABRIC
775772
@@ -969,6 +966,7 @@ def __init__(
969966 group_rank : int ,
970967 device : torch .device ,
971968 mn_nvlink : bool = True ,
969+ comm : Optional [CommBackend ] = None ,
972970 ):
973971 """
974972 Constructor for McastGpuBuffer.
@@ -981,7 +979,7 @@ def __init__(
981979 mn_nvlink: Flag indicating if multi-node NVLink is used
982980 """
983981 self .mcast_device_memory = McastDeviceMemory (
984- buf_size , group_size , group_rank , device .index , mn_nvlink
982+ buf_size , group_size , group_rank , device .index , mn_nvlink , comm
985983 )
986984 self .buf_size = buf_size
987985 self .local_device = device
0 commit comments