@@ -663,14 +663,29 @@ def send_tensor_dict(
         tensor_dict: dict[str, Union[torch.Tensor, Any]],
         dst: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Send the input tensor dictionary.
         NOTE: `dst` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return tensor_dict
-
         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else
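
For context, here is a hedged sketch of how a caller might use the new argument. The names `pp_group`, `tp_group`, `send_intermediate`, `hidden_states`, and `residual` are illustrative assumptions, not part of the patch; the point is simply disabling the slice-and-all-gather path for a tensor that is sharded under sequence parallelism.

```python
import torch

def send_intermediate(pp_group, tp_group, hidden_states: torch.Tensor,
                      residual: torch.Tensor) -> None:
    # Hypothetical caller: send to the next pipeline rank, using the
    # tensor-parallel group for the slice-and-all-gather optimization.
    tensors = {
        "hidden_states": hidden_states,  # replicated across TP ranks
        "residual": residual,  # sharded when sequence parallelism is on
    }
    pp_group.send_tensor_dict(
        tensors,
        all_gather_group=tp_group,
        # Opt the non-replicated tensor out of the optimization; all other
        # tensors keep the default size-divisibility behavior.
        all_gather_tensors={"residual": False},
    )
```
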
@@ -699,14 +714,23 @@ def send_tensor_dict(
         # `send_object_list` has serialization & deserialization,
         # all happening on CPU. Therefore, we can use the CPU group.
         self.send_object(metadata_list, dst=dst)
-        for tensor in tensor_list:
+
+        tensor_keys = [
+            k for k, v in tensor_dict.items() if isinstance(v, torch.Tensor)
+        ]
+        assert len(tensor_keys) == len(tensor_list)
+
+        for key, tensor in zip(tensor_keys, tensor_list):
             if tensor.numel() == 0:
                 # Skip sending empty tensors.
                 continue
 
             # send-allgather: send only a slice, then do allgather.
-            if (all_gather_group is not None
-                    and tensor.numel() % all_gather_size == 0):
+            use_all_gather = (all_gather_group is not None
+                              and tensor.numel() % all_gather_size == 0)
+            use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                if all_gather_tensors else use_all_gather
+            if use_all_gather:
                 tensor = tensor.reshape(all_gather_size, -1)[all_gather_rank]
 
             if tensor.is_cpu:
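
To make the divisibility-based default concrete, the following single-process sketch (no `torch.distributed` required) mimics the slice-then-all-gather round trip that the loop above performs across ranks; it is an illustration of the mechanism, not code from the patch.

```python
import torch

world_size = 4
full = torch.arange(16, dtype=torch.float32)  # same value on every rank

# Sender side: each rank keeps only its own slice, as in
# `tensor.reshape(all_gather_size, -1)[all_gather_rank]`.
slices = [full.reshape(world_size, -1)[rank] for rank in range(world_size)]

# Receiver side: an all-gather over the group restores the original tensor.
rebuilt = torch.cat(slices).reshape(full.shape)
assert torch.equal(rebuilt, full)

# A sequence-parallel residual holds different data on each rank, so the
# same reconstruction would interleave unrelated shards; hence the new
# per-tensor opt-out via `all_gather_tensors`.
```
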
@@ -725,14 +749,29 @@ def recv_tensor_dict(
         self,
         src: Optional[int] = None,
         all_gather_group: Optional["GroupCoordinator"] = None,
+        all_gather_tensors: Optional[dict[str, bool]] = None,
     ) -> Optional[dict[str, Union[torch.Tensor, Any]]]:
         """Recv the input tensor dictionary.
         NOTE: `src` is the local rank of the source rank.
+
+        all_gather_group: The group for the all-gather operation. If provided,
+            an optimization is enabled where each rank in the group sends a
+            slice of a tensor and the receiver reconstructs it using an
+            all-gather, which can improve performance. This is typically the
+            tensor-parallel group.
+        all_gather_tensors: A dictionary to specify which tensors should use
+            the all-gather optimization, which is only effective when
+            `all_gather_group` is provided. By default, this optimization is
+            on for any tensor whose size is divisible by the
+            `all_gather_group`'s world size. However, it should be disabled
+            for tensors that are not fully replicated across the group (e.g.,
+            the residual tensor when sequence parallelism is enabled). This
+            dictionary allows overriding the default behavior on a per-tensor
+            basis.
         """
         # Bypass the function if we are using only 1 GPU.
         if not torch.distributed.is_initialized() or self.world_size == 1:
             return None
-
         all_gather_size = (1 if all_gather_group is None else
                            all_gather_group.world_size)
         all_gather_rank = (0 if all_gather_group is None else
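
On the receiving side, the overrides have to mirror what the sender used; otherwise the two ranks disagree about whether a full tensor or only a slice is on the wire. A hypothetical counterpart to the earlier send sketch, reusing the same assumed `pp_group`/`tp_group` names:

```python
def recv_intermediate(pp_group, tp_group):
    # Pass the same all_gather_tensors overrides as the sender so that both
    # sides agree on the shape that is actually transferred.
    return pp_group.recv_tensor_dict(
        all_gather_group=tp_group,
        all_gather_tensors={"residual": False},
    )
```
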
@@ -766,6 +805,8 @@ def recv_tensor_dict(
                 # send-allgather: send only a slice, then do allgather.
                 use_all_gather = (all_gather_group is not None
                                   and tensor.numel() % all_gather_size == 0)
+                use_all_gather = all_gather_tensors.get(key, use_all_gather) \
+                    if all_gather_tensors else use_all_gather
 
                 if use_all_gather:
                     orig_shape = tensor.shape
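
The override resolution applied in both methods can be restated in isolation as below; this is a paraphrase for clarity, not a helper that exists in the codebase.

```python
from typing import Optional

def resolve_use_all_gather(key: str, numel: int, all_gather_size: int,
                           has_all_gather_group: bool,
                           all_gather_tensors: Optional[dict[str, bool]] = None
                           ) -> bool:
    # Default: optimize any tensor whose element count divides evenly
    # across the all-gather group.
    default = has_all_gather_group and numel % all_gather_size == 0
    # A per-key override wins whenever the caller supplies one.
    if all_gather_tensors:
        return all_gather_tensors.get(key, default)
    return default

assert resolve_use_all_gather("hidden_states", 4096, 4, True)
assert not resolve_use_all_gather("residual", 4096, 4, True,
                                  {"residual": False})
```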