diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 6491d17b4f46..adf35982b602 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -154,7 +154,7 @@ def pre_forward(self, module, *args, **kwargs):
             # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
             # the output instead of input for a particular layer by setting split_output=True
             if isinstance(input_val, torch.Tensor):
-                input_val = self._prepare_cp_input(input_val, cpm)
+                input_val = self._prepare_cp_input(input_val, cpm, name)
             elif isinstance(input_val, (list, tuple)):
                 if len(input_val) != len(cpm):
                     raise ValueError(
@@ -201,14 +201,41 @@ def post_forward(self, module, output):
 
         return output[0] if is_tensor else tuple(output)
 
-    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
+    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput, name: str) -> torch.Tensor:
         if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
             logger.warning_once(
                 f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
             )
             return x
-        else:
-            return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
+
+        mesh = self.parallel_config._flattened_mesh
+        world_size = mesh.size()
+        dim = cp_input.split_dim
+        seq_len = x.size(dim)
+
+        if world_size > 1 and seq_len % world_size != 0:
+            pad_len = (world_size - seq_len % world_size) % world_size
+
+            # Determine pad_value based on name convention
+            if "mask" in name.lower():
+                pad_value = 0
+            else:
+                pad_value = 0.0
+
+            pad_width = [0] * (2 * x.dim())
+            # The pad_width tuple is read from last dim to first dim
+            pad_idx = x.dim() - 1 - dim
+            # We want to pad the right side only
+            pad_width[2 * pad_idx + 1] = pad_len
+
+            x = torch.nn.functional.pad(x, tuple(pad_width), mode="constant", value=pad_value)
+
+            # Store the original size for trimming in the gather hook
+            if "hidden_states" in name.lower():
+                self.module_forward_metadata._cp_original_s = seq_len
+                self.module_forward_metadata._cp_pad_dim = dim
+
+        return EquipartitionSharder.shard(x, dim, mesh)
 
 
 class ContextParallelGatherHook(ModelHook):
@@ -233,7 +260,20 @@ def post_forward(self, module, output):
         for i, cpm in enumerate(self.metadata):
             if cpm is None:
                 continue
-            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
+
+            x = output[i]
+            x = EquipartitionSharder.unshard(x, cpm.gather_dim, self.parallel_config._flattened_mesh)
+
+            # Trim back to the original length if padding info was stored
+            if hasattr(module, "_forward_metadata") and hasattr(module._forward_metadata, "_cp_original_s"):
+                if cpm.gather_dim == module._forward_metadata._cp_pad_dim:
+                    original_s = module._forward_metadata._cp_original_s
+                    x = x.narrow(cpm.gather_dim, 0, original_s)
+                    # Clean up the stored attributes
+                    del module._forward_metadata._cp_original_s
+                    del module._forward_metadata._cp_pad_dim
+
+            output[i] = x
 
         return output[0] if is_tensor else tuple(output)
 
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index c0fa031b9faf..dec78e4a818d 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -600,6 +600,7 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py
index 8a83f60ff278..5ed92e8180a2 100644
--- a/tests/hooks/test_hooks.py
+++ b/tests/hooks/test_hooks.py
@@ -14,10 +14,12 @@
 
 import gc
 import unittest
+from unittest.mock import patch
 
 import torch
 
 from diffusers.hooks import HookRegistry, ModelHook
+from diffusers.hooks.context_parallel import ContextParallelSplitHook, EquipartitionSharder
 from diffusers.training_utils import free_memory
 from diffusers.utils.logging import get_logger
 
@@ -62,6 +64,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x
 
 
+# Small helpers to simulate the parallel_config._flattened_mesh used by the hook
+class _DummyMesh:
+    def __init__(self, size: int):
+        self._size = size
+
+    def size(self):
+        return self._size
+
+
+class _DummyParallelConfig:
+    def __init__(self, mesh_size: int):
+        self._flattened_mesh = _DummyMesh(mesh_size)
+
+
+# Lightweight object that behaves like a ContextParallelInput for testing.
+class _DummyCPInput:
+    def __init__(self, split_dim: int, expected_dims: int = None, split_output: bool = False):
+        self.split_dim = split_dim
+        self.expected_dims = expected_dims
+        self.split_output = split_output
+
+
 class AddHook(ModelHook):
     def __init__(self, value: int):
         super().__init__()
@@ -375,3 +399,75 @@ def test_invocation_order_stateful_last(self):
             .replace("\n", "")
         )
         self.assertEqual(output, expected_invocation_order_log)
+
+
+class ContextParallelHooksTests(unittest.TestCase):
+    def setUp(self):
+        # world_size 3 will force padding for any seq_len that isn't divisible by 3
+        self.parallel_config = _DummyParallelConfig(mesh_size=3)
+        # metadata may be empty for our direct calls to _prepare_cp_input
+        self.hook = ContextParallelSplitHook(metadata={}, parallel_config=self.parallel_config)
+        self.module = DummyModel(in_features=1, hidden_features=1, out_features=1, num_layers=1)
+        # initialize_hook builds module_forward_metadata inside the hook
+        self.hook.initialize_hook(self.module)
+        # attach forward metadata to the module exactly as HookRegistry would
+        self.module._forward_metadata = self.hook.module_forward_metadata
+
+    def test_prepare_cp_input_pads_hidden_states_and_stores_original(self):
+        # create a tensor with seq_len = 7 along dim=1 (batch, seq, hidden)
+        x = torch.randn(1, 7, 16)
+
+        cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)
+
+        # Patch shard to identity so we can inspect the padded tensor directly
+        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t) as mock_shard:
+            out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")
+
+        # The hook should have padded seq_len from 7 -> 9 since world_size=3
+        self.assertEqual(out.shape[1], 9)
+
+        # ensure shard was called once with the expected dim and mesh
+        mock_shard.assert_called_once()
+        called_args, _ = mock_shard.call_args
+        # called_args = (tensor, dim, mesh)
+        self.assertEqual(called_args[1], cp_input.split_dim)
+        self.assertIs(called_args[2], self.parallel_config._flattened_mesh)
+
+        # The hook should have recorded the original sequence length and pad dim
+        # on the module's metadata so the gather hook can later trim.
+        self.assertTrue(hasattr(self.module._forward_metadata, "_cp_original_s"))
+        self.assertTrue(hasattr(self.module._forward_metadata, "_cp_pad_dim"))
+        self.assertEqual(self.module._forward_metadata._cp_original_s, 7)
+        self.assertEqual(self.module._forward_metadata._cp_pad_dim, 1)
+
+    def test_prepare_cp_input_pads_attention_mask_with_zeros(self):
+        # attention masks are typically shape (batch, seq);
+        # create a seq_len = 7 mask filled with ones
+        mask = torch.ones(1, 7, dtype=torch.long)
+
+        cp_input = _DummyCPInput(split_dim=1, expected_dims=2, split_output=False)
+
+        # Patch shard to identity
+        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
+            out_mask = self.hook._prepare_cp_input(mask, cp_input, name="encoder_attention_mask")
+
+        # After padding it should have shape (1, 9)
+        self.assertEqual(out_mask.shape[1], 9)
+        # The padded values should be zeros (the pad_value used for masks),
+        # so check that the last two positions are zero
+        padded_portion = out_mask[:, -2:]
+        self.assertTrue(torch.equal(padded_portion, torch.zeros_like(padded_portion)))
+
+    def test_prepare_cp_input_no_pad_when_divisible(self):
+        # seq_len is already divisible by world_size (3), e.g., 6
+        x = torch.randn(1, 6, 16)
+        cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)
+
+        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
+            out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")
+
+        # no padding should be performed
+        self.assertEqual(out.shape[1], 6)
+        # and no _cp_original_s/_cp_pad_dim set because nothing was padded
+        self.assertFalse(hasattr(self.hook.module_forward_metadata, "_cp_original_s"))
+        self.assertFalse(hasattr(self.hook.module_forward_metadata, "_cp_pad_dim"))
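For reference, the pad -> shard -> unshard -> trim round trip that the two hooks implement can be exercised without a process group or device mesh. The following is a minimal, standalone sketch (not part of the patch) that emulates EquipartitionSharder's shard/unshard with plain chunk/cat on a single process, assuming world_size=3 and a sequence length of 7; because padding is applied only on the right, a single narrow restores the original tensor exactly.

import torch
import torch.nn.functional as F

world_size = 3
dim = 1                             # sequence dimension
x = torch.randn(1, 7, 16)           # (batch, seq_len=7, hidden); 7 is not divisible by 3
seq_len = x.size(dim)

# Pad the sequence dim on the right so it divides evenly across ranks
# (same pad_width construction as in _prepare_cp_input).
pad_len = (world_size - seq_len % world_size) % world_size
pad_width = [0] * (2 * x.dim())
pad_width[2 * (x.dim() - 1 - dim) + 1] = pad_len
padded = F.pad(x, tuple(pad_width), mode="constant", value=0.0)

# "Shard": each rank would keep one contiguous chunk along the sequence dim.
shards = padded.chunk(world_size, dim=dim)
assert all(s.size(dim) == padded.size(dim) // world_size for s in shards)

# "Unshard" (the single-process stand-in for an all-gather), then trim back
# to the stored original length, as the gather hook does.
gathered = torch.cat(shards, dim=dim)
restored = gathered.narrow(dim, 0, seq_len)
assert torch.equal(restored, x)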