@@ -547,14 +547,15 @@ def supports_mnnvl() -> bool:
 
 class McastDeviceMemory:
     """Python port of McastDeviceMemory from TensorRT-LLM"""
 
     def __init__(
         self,
         buf_size: int,
         group_size: int,
         group_rank: int,
         device_idx: int,
         is_multi_node: bool = True,
+        comm: Optional[CommBackend] = None,
     ):
         cu_device = checkCudaErrors(cuda.cuDeviceGet(device_idx))
 
@@ -631,7 +632,7 @@ def __init__(
                     "[McastDeviceMemory] Device does not support fabric handle."
                 )
 
-            self._alloc_mn_mcast_mem(buf_size)
+            self._alloc_mn_mcast_mem(buf_size, comm)
         else:
             # For single-node NVLS, would need to implement _alloc_nvls_mcast_mem
             raise NotImplementedError("Single-node NVLS allocation not implemented yet")
@@ -753,7 +754,7 @@ def get_world_size(self) -> int:
         """Get the total number of devices in the group"""
         return self.group_size
 
-    def _alloc_mn_mcast_mem(self, buf_size: int):
+    def _alloc_mn_mcast_mem(self, buf_size: int, comm: Optional[CommBackend] = None):
         """Allocate multi-node multicast memory using MNNVL"""
 
         # Verify CUDA context
@@ -768,7 +769,7 @@ def _alloc_mn_mcast_mem(self, buf_size: int):
             print(f"Error checking CUDA context: {e}")
 
-        # Get MPI communicator
-        comm = MpiComm()
+        # Use the injected communication backend, falling back to the MPI communicator
+        comm = comm if comm is not None else MpiComm()
 
         # Set up allocation properties
         handle_type = cuda.CUmemAllocationHandleType.CU_MEM_HANDLE_TYPE_FABRIC
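
A note on the design choice in the two hunks above: an earlier draft gave _alloc_mn_mcast_mem a default of comm: Any = MpiComm(). Python evaluates default arguments once, at function definition time, so that default would construct an MPI communicator at import even for callers that inject their own backend; defaulting to None and resolving the backend at call time avoids this. A minimal, self-contained sketch of the pitfall (DefaultComm is a hypothetical stand-in for MpiComm, not part of the patch):

# Illustration only: why `comm=None` plus a runtime fallback is preferred over
# constructing a communicator in the default argument itself.
class DefaultComm:
    instances = 0

    def __init__(self):
        DefaultComm.instances += 1


def alloc_eager(buf_size: int, comm=DefaultComm()):
    # The default above was built once, when the function was defined.
    return comm


def alloc_lazy(buf_size: int, comm=None):
    # The fallback is only constructed when no backend is injected.
    comm = comm if comm is not None else DefaultComm()
    return comm


print(DefaultComm.instances)                 # already 1: built at definition time
assert alloc_eager(1) is alloc_eager(2)      # every call shares that one instance
assert alloc_lazy(1) is not alloc_lazy(2)    # fresh fallback per call; none if a backend is injected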
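
In the surrounding allocation code the backend is exercised through collectives such as all_fabric_handles = comm.allgather(my_fabric_handle.data), which exchanges each rank's CUDA fabric handle with the rest of the group. A rough sketch of what the CommBackend interface could look like, assuming a typing.Protocol formulation; the Protocol itself and the torch.distributed adapter below are illustrations, not part of the patch, and a real backend may need more methods than allgather:

from typing import Any, List, Protocol, runtime_checkable


@runtime_checkable
class CommBackend(Protocol):
    """Duck-typed communicator: anything with an MPI-style allgather."""

    def allgather(self, obj: Any) -> List[Any]:
        """Gather `obj` from every rank; return the results ordered by rank."""
        ...


class TorchDistCommBackend:
    """Hypothetical adapter that satisfies the protocol via torch.distributed."""

    def allgather(self, obj: Any) -> List[Any]:
        import torch.distributed as dist

        out: List[Any] = [None] * dist.get_world_size()
        dist.all_gather_object(out, obj)
        return out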
@@ -969,6 +970,7 @@ def __init__(
         group_rank: int,
         device: torch.device,
         mn_nvlink: bool = True,
+        comm: Optional[CommBackend] = None,
     ):
         """
         Constructor for McastGpuBuffer.
@@ -981,7 +983,7 @@ def __init__(
             mn_nvlink: Flag indicating if multi-node NVLink is used
         """
         self.mcast_device_memory = McastDeviceMemory(
-            buf_size, group_size, group_rank, device.index, mn_nvlink
+            buf_size, group_size, group_rank, device.index, mn_nvlink, comm
         )
         self.buf_size = buf_size
         self.local_device = device
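
Finally, a usage sketch of the extended McastGpuBuffer constructor. The keyword names for the first two arguments are inferred from the forwarded call above, and TorchDistCommBackend is the hypothetical adapter from the earlier sketch; omitting comm keeps the previous MPI-based behavior:

import torch

# Assumes McastGpuBuffer is importable from the module this patch touches.
device = torch.device("cuda", 0)

# Default path: no backend injected, so allocation falls back to MpiComm().
buf = McastGpuBuffer(
    buf_size=1 << 20,        # 1 MiB of multicast buffer space (illustrative)
    group_size=4,
    group_rank=0,
    device=device,
    mn_nvlink=True,
)

# Injected path: any object exposing allgather(obj) -> list can be passed.
buf_dist = McastGpuBuffer(
    buf_size=1 << 20,
    group_size=4,
    group_rank=0,
    device=device,
    mn_nvlink=True,
    comm=TorchDistCommBackend(),
)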