From 5560eb206c4e7e72f18695f1266bf96c7decff96 Mon Sep 17 00:00:00 2001
From: Ratish1
Date: Wed, 5 Nov 2025 18:25:11 +0400
Subject: [PATCH 1/2] fix(qwenimage): Correct context parallelism padding

---
 .../transformers/transformer_qwenimage.py   | 32 ++++++++++++++-
 tests/pipelines/qwenimage/test_qwenimage.py | 39 +++++++++++++++++++
 2 files changed, 70 insertions(+), 1 deletion(-)

diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index c0fa031b9faf..992abddb72a9 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -18,6 +18,7 @@
 
 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -330,6 +331,19 @@ def __call__(
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)
 
+        if attention_mask is None and encoder_hidden_states_mask is not None:
+            # The joint sequence is [text, image].
+            seq_len_img = hidden_states.shape[1]
+            img_mask = torch.ones(
+                encoder_hidden_states_mask.shape[0], seq_len_img, device=encoder_hidden_states_mask.device
+            )
+            attention_mask = torch.cat([encoder_hidden_states_mask, img_mask], dim=1)
+
+            # Convert the mask to the format expected by SDPA
+            attention_mask = attention_mask[:, None, None, :]
+            attention_mask = attention_mask.to(dtype=joint_query.dtype)
+            attention_mask = (1.0 - attention_mask) * torch.finfo(joint_query.dtype).min
+
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -600,6 +614,16 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+            if world_size > 1 and encoder_hidden_states is not None:
+                seq_len = encoder_hidden_states.shape[1]
+                pad_len = (world_size - seq_len % world_size) % world_size
+                if pad_len > 0:
+                    encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
+                    if encoder_hidden_states_mask is not None:
+                        encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
@@ -630,7 +654,13 @@ def forward(
             else self.time_text_embed(timestep, guidance, hidden_states)
         )
 
-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # use the shape of the padded hidden states to generate the rotary embeddings
+        if encoder_hidden_states is not None:
+            recalculated_txt_seq_lens = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+        else:
+            recalculated_txt_seq_lens = txt_seq_lens
+
+        image_rotary_emb = self.pos_embed(img_shapes, recalculated_txt_seq_lens, device=hidden_states.device)
 
         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
diff --git a/tests/pipelines/qwenimage/test_qwenimage.py b/tests/pipelines/qwenimage/test_qwenimage.py
index 8ebfe7d08bc1..6a574a66d442 100644
--- a/tests/pipelines/qwenimage/test_qwenimage.py
+++ b/tests/pipelines/qwenimage/test_qwenimage.py
@@ -13,6 +13,7 @@
 # limitations under the License.
 
 import unittest
+from unittest.mock import patch
 
 import numpy as np
 import torch
@@ -234,3 +235,41 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             expected_diff_max,
             "VAE tiling should not affect the inference results",
         )
+
+    def test_context_parallelism_padding_fix(self):
+        """
+        Compare pipeline outputs: baseline (normal single-process) vs
+        simulated multi-process (mocked torch.distributed). This verifies
+        padding logic does not change the generated image.
+        """
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+
+        # Baseline run (no distributed)
+        baseline_image = pipe(**inputs).images[0]
+
+        # Re-initialize inputs to get a fresh generator with the same seed for a fair comparison
+        inputs = self.get_dummy_inputs(device)
+
+        # Simulate distributed env (world_size = 3) so padding branch runs
+        # NOTE: patch target must match where `dist` is imported in the transformer module.
+        with (
+            patch("diffusers.models.transformers.transformer_qwenimage.dist.is_initialized", return_value=True),
+            patch("diffusers.models.transformers.transformer_qwenimage.dist.get_world_size", return_value=3),
+        ):
+            padded_image = pipe(**inputs).images[0]
+
+        # shape check
+        self.assertEqual(baseline_image.shape, padded_image.shape)
+
+        # Additional check: verify padding didn't introduce extreme values
+        self.assertTrue(torch.isfinite(padded_image).all())
+
+        # Verify numerical equivalence
+        self.assertTrue(torch.allclose(baseline_image, padded_image, atol=1e-2, rtol=1e-2))

From afc18a1d90cf9941cc60d6db73449ec0f82ebf2a Mon Sep 17 00:00:00 2001
From: Ratish1
Date: Sat, 8 Nov 2025 23:08:05 +0400
Subject: [PATCH 2/2] fix(hooks): Add generic padding to context parallel hook

---
 src/diffusers/hooks/context_parallel.py     | 50 +++++++++-
 .../transformers/transformer_qwenimage.py   | 31 +-----
 tests/hooks/test_hooks.py                   | 96 +++++++++++++++++++
 tests/pipelines/qwenimage/test_qwenimage.py | 39 --------
 4 files changed, 142 insertions(+), 74 deletions(-)

diff --git a/src/diffusers/hooks/context_parallel.py b/src/diffusers/hooks/context_parallel.py
index 6491d17b4f46..adf35982b602 100644
--- a/src/diffusers/hooks/context_parallel.py
+++ b/src/diffusers/hooks/context_parallel.py
@@ -154,7 +154,7 @@ def pre_forward(self, module, *args, **kwargs):
             # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
             # the output instead of input for a particular layer by setting split_output=True
             if isinstance(input_val, torch.Tensor):
-                input_val = self._prepare_cp_input(input_val, cpm)
+                input_val = self._prepare_cp_input(input_val, cpm, name)
             elif isinstance(input_val, (list, tuple)):
                 if len(input_val) != len(cpm):
                     raise ValueError(
@@ -201,14 +201,41 @@ def post_forward(self, module, output):
 
         return output[0] if is_tensor else tuple(output)
 
-    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
+    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput, name: str) -> torch.Tensor:
         if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
             logger.warning_once(
                 f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
             )
             return x
-        else:
-            return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
+
+        mesh = self.parallel_config._flattened_mesh
+        world_size = mesh.size()
+        dim = cp_input.split_dim
+        seq_len = x.size(dim)
+
+        if world_size > 1 and seq_len % world_size != 0:
+            pad_len = (world_size - seq_len % world_size) % world_size
+
+            # Determine the pad value based on the argument name convention
+            if "mask" in name.lower():
+                pad_value = 0
+            else:
+                pad_value = 0.0
+
+            pad_width = [0] * (2 * x.dim())
+            # The pad_width tuple is read from the last dim to the first dim
+            pad_idx = x.dim() - 1 - dim
+            # We want to pad the right side
+            pad_width[2 * pad_idx + 1] = pad_len
+
+            x = torch.nn.functional.pad(x, tuple(pad_width), mode="constant", value=pad_value)
+
+            # Store the original size for trimming in the gather hook
+            if "hidden_states" in name.lower():
+                self.module_forward_metadata._cp_original_s = seq_len
+                self.module_forward_metadata._cp_pad_dim = dim
+
+        return EquipartitionSharder.shard(x, dim, mesh)
 
 
 class ContextParallelGatherHook(ModelHook):
@@ -233,7 +260,20 @@ def post_forward(self, module, output):
 
         for i, cpm in enumerate(self.metadata):
             if cpm is None:
                 continue
-            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
+
+            x = output[i]
+            x = EquipartitionSharder.unshard(x, cpm.gather_dim, self.parallel_config._flattened_mesh)
+
+            # Trim the gathered tensor back to its original length if padding info was stored
+            if hasattr(module, "_forward_metadata") and hasattr(module._forward_metadata, "_cp_original_s"):
+                if cpm.gather_dim == module._forward_metadata._cp_pad_dim:
+                    original_s = module._forward_metadata._cp_original_s
+                    x = x.narrow(cpm.gather_dim, 0, original_s)
+                    # Clean up the stored attributes
+                    del module._forward_metadata._cp_original_s
+                    del module._forward_metadata._cp_pad_dim
+
+            output[i] = x
 
         return output[0] if is_tensor else tuple(output)
diff --git a/src/diffusers/models/transformers/transformer_qwenimage.py b/src/diffusers/models/transformers/transformer_qwenimage.py
index 992abddb72a9..dec78e4a818d 100644
--- a/src/diffusers/models/transformers/transformer_qwenimage.py
+++ b/src/diffusers/models/transformers/transformer_qwenimage.py
@@ -18,7 +18,6 @@
 
 import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F
 
@@ -331,19 +330,6 @@ def __call__(
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)
 
-        if attention_mask is None and encoder_hidden_states_mask is not None:
-            # The joint sequence is [text, image].
-            seq_len_img = hidden_states.shape[1]
-            img_mask = torch.ones(
-                encoder_hidden_states_mask.shape[0], seq_len_img, device=encoder_hidden_states_mask.device
-            )
-            attention_mask = torch.cat([encoder_hidden_states_mask, img_mask], dim=1)
-
-            # Convert the mask to the format expected by SDPA
-            attention_mask = attention_mask[:, None, None, :]
-            attention_mask = attention_mask.to(dtype=joint_query.dtype)
-            attention_mask = (1.0 - attention_mask) * torch.finfo(joint_query.dtype).min
-
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -614,15 +600,6 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
""" - if dist.is_initialized(): - world_size = dist.get_world_size() - if world_size > 1 and encoder_hidden_states is not None: - seq_len = encoder_hidden_states.shape[1] - pad_len = (world_size - seq_len % world_size) % world_size - if pad_len > 0: - encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len)) - if encoder_hidden_states_mask is not None: - encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0) if attention_kwargs is not None: attention_kwargs = attention_kwargs.copy() @@ -654,13 +631,7 @@ def forward( else self.time_text_embed(timestep, guidance, hidden_states) ) - # use the shape of the padded hidden states to generate the rotary embeddings - if encoder_hidden_states is not None: - recalculated_txt_seq_lens = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0] - else: - recalculated_txt_seq_lens = txt_seq_lens - - image_rotary_emb = self.pos_embed(img_shapes, recalculated_txt_seq_lens, device=hidden_states.device) + image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device) for index_block, block in enumerate(self.transformer_blocks): if torch.is_grad_enabled() and self.gradient_checkpointing: diff --git a/tests/hooks/test_hooks.py b/tests/hooks/test_hooks.py index 8a83f60ff278..5ed92e8180a2 100644 --- a/tests/hooks/test_hooks.py +++ b/tests/hooks/test_hooks.py @@ -14,10 +14,12 @@ import gc import unittest +from unittest.mock import patch import torch from diffusers.hooks import HookRegistry, ModelHook +from diffusers.hooks.context_parallel import ContextParallelSplitHook, EquipartitionSharder from diffusers.training_utils import free_memory from diffusers.utils.logging import get_logger @@ -62,6 +64,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: return x +# Small helpers to simulate the parallel_config._flattened_mesh used by the hook +class _DummyMesh: + def __init__(self, size: int): + self._size = size + + def size(self): + return self._size + + +class _DummyParallelConfig: + def __init__(self, mesh_size: int): + self._flattened_mesh = _DummyMesh(mesh_size) + + +# Lightweight object that behaves like a ContextParallelInput for testing. 
+class _DummyCPInput:
+    def __init__(self, split_dim: int, expected_dims: int = None, split_output: bool = False):
+        self.split_dim = split_dim
+        self.expected_dims = expected_dims
+        self.split_output = split_output
+
+
 class AddHook(ModelHook):
     def __init__(self, value: int):
         super().__init__()
@@ -375,3 +399,75 @@ def test_invocation_order_stateful_last(self):
             .replace("\n", "")
         )
         self.assertEqual(output, expected_invocation_order_log)
+
+
+class ContextParallelHooksTests(unittest.TestCase):
+    def setUp(self):
+        # world_size 3 will force padding for any seq_len that isn't divisible by 3
+        self.parallel_config = _DummyParallelConfig(mesh_size=3)
+        # metadata may be empty for our direct call to _prepare_cp_input
+        self.hook = ContextParallelSplitHook(metadata={}, parallel_config=self.parallel_config)
+        self.module = DummyModel(in_features=1, hidden_features=1, out_features=1, num_layers=1)
+        # initialize_hook builds module_forward_metadata inside the hook
+        self.hook.initialize_hook(self.module)
+        # attach forward metadata to the module exactly as HookRegistry would
+        self.module._forward_metadata = self.hook.module_forward_metadata
+
+    def test_prepare_cp_input_pads_hidden_states_and_stores_original(self):
+        # create a tensor with seq_len = 7 along dim=1 (batch, seq, hidden)
+        x = torch.randn(1, 7, 16)
+
+        cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)
+
+        # Patch shard to identity so we can inspect the padded tensor directly
+        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t) as mock_shard:
+            out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")
+
+        # The hook should have padded seq_len from 7 -> 9 since world_size=3
+        self.assertEqual(out.shape[1], 9)
+
+        # ensure shard was called once with the expected dim and mesh
+        mock_shard.assert_called_once()
+        called_args, _ = mock_shard.call_args
+        # called_args = (tensor, dim, mesh)
+        self.assertEqual(called_args[1], cp_input.split_dim)
+        self.assertIs(called_args[2], self.parallel_config._flattened_mesh)
+
+        # The hook should have recorded the original sequence length and pad dim
+        # on the module's metadata so the gather hook can later trim.
+        self.assertTrue(hasattr(self.module._forward_metadata, "_cp_original_s"))
+        self.assertTrue(hasattr(self.module._forward_metadata, "_cp_pad_dim"))
+        self.assertEqual(self.module._forward_metadata._cp_original_s, 7)
+        self.assertEqual(self.module._forward_metadata._cp_pad_dim, 1)
+
+    def test_prepare_cp_input_pads_attention_mask_with_zeros(self):
+        # attention masks are typically shape (batch, seq)
+        # create a seq_len = 7 mask filled with ones
+        mask = torch.ones(1, 7, dtype=torch.long)
+
+        cp_input = _DummyCPInput(split_dim=1, expected_dims=2, split_output=False)
+
+        # Patch shard to identity
+        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
+            out_mask = self.hook._prepare_cp_input(mask, cp_input, name="encoder_attention_mask")
+
+        # After padding it should be shape (1, 9)
+        self.assertEqual(out_mask.shape[1], 9)
+        # The padded values should be zeros (the pad_value the hook uses for masks)
+        # Check the last two positions are zero
+        padded_portion = out_mask[:, -2:]
+        self.assertTrue(torch.equal(padded_portion, torch.zeros_like(padded_portion)))
+
+    def test_prepare_cp_input_no_pad_when_divisible(self):
+        # seq_len is already divisible by world_size (3), e.g., 6
+        x = torch.randn(1, 6, 16)
+        cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)
+
+        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
+            out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")
+
+        # no padding should be performed
+        self.assertEqual(out.shape[1], 6)
+        # and _cp_original_s/_cp_pad_dim should not be set because nothing was padded
+        self.assertFalse(hasattr(self.hook.module_forward_metadata, "_cp_original_s"))
+        self.assertFalse(hasattr(self.hook.module_forward_metadata, "_cp_pad_dim"))
diff --git a/tests/pipelines/qwenimage/test_qwenimage.py b/tests/pipelines/qwenimage/test_qwenimage.py
index 6a574a66d442..8ebfe7d08bc1 100644
--- a/tests/pipelines/qwenimage/test_qwenimage.py
+++ b/tests/pipelines/qwenimage/test_qwenimage.py
@@ -13,7 +13,6 @@
 # limitations under the License.
 
 import unittest
-from unittest.mock import patch
 
 import numpy as np
 import torch
@@ -235,41 +234,3 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             expected_diff_max,
             "VAE tiling should not affect the inference results",
         )
-
-    def test_context_parallelism_padding_fix(self):
-        """
-        Compare pipeline outputs: baseline (normal single-process) vs
-        simulated multi-process (mocked torch.distributed). This verifies
-        padding logic does not change the generated image.
-        """
-        device = "cpu"
-
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(device)
-
-        # Baseline run (no distributed)
-        baseline_image = pipe(**inputs).images[0]
-
-        # Re-initialize inputs to get a fresh generator with the same seed for a fair comparison
-        inputs = self.get_dummy_inputs(device)
-
-        # Simulate distributed env (world_size = 3) so padding branch runs
-        # NOTE: patch target must match where `dist` is imported in the transformer module.
-        with (
-            patch("diffusers.models.transformers.transformer_qwenimage.dist.is_initialized", return_value=True),
-            patch("diffusers.models.transformers.transformer_qwenimage.dist.get_world_size", return_value=3),
-        ):
-            padded_image = pipe(**inputs).images[0]
-
-        # shape check
-        self.assertEqual(baseline_image.shape, padded_image.shape)
-
-        # Additional check: verify padding didn't introduce extreme values
-        self.assertTrue(torch.isfinite(padded_image).all())
-
-        # Verify numerical equivalence
-        self.assertTrue(torch.allclose(baseline_image, padded_image, atol=1e-2, rtol=1e-2))
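Illustrative sketch (not part of either patch): a minimal, standalone demonstration of the pad -> shard -> gather -> trim round trip that the hook changes above rely on. The helper name pad_to_multiple, the world size of 3, and the use of torch.chunk / torch.cat in place of EquipartitionSharder and the distributed all-gather are assumptions made purely for illustration.

# Standalone sketch of the pad -> shard -> gather -> trim round trip (assumed setup,
# not the diffusers API): world_size = 3, no torch.distributed process group needed.
import torch
import torch.nn.functional as F


def pad_to_multiple(x: torch.Tensor, dim: int, world_size: int, value: float = 0.0):
    """Right-pad `x` along `dim` so its length is divisible by `world_size`."""
    seq_len = x.size(dim)
    pad_len = (world_size - seq_len % world_size) % world_size
    if pad_len == 0:
        return x, seq_len
    # F.pad reads the pad widths from the last dimension backwards.
    pad_width = [0] * (2 * x.dim())
    pad_idx = x.dim() - 1 - dim
    pad_width[2 * pad_idx + 1] = pad_len  # pad only on the right side
    return F.pad(x, tuple(pad_width), mode="constant", value=value), seq_len


world_size = 3
hidden_states = torch.randn(1, 7, 16)  # seq_len = 7 is not divisible by 3

padded, original_len = pad_to_multiple(hidden_states, dim=1, world_size=world_size)
shards = torch.chunk(padded, world_size, dim=1)  # three equal 3-token shards
gathered = torch.cat(shards, dim=1)              # stand-in for the all-gather
trimmed = gathered.narrow(1, 0, original_len)    # drop the padded tail again

assert padded.shape[1] == 9
assert all(s.shape[1] == 3 for s in shards)
assert torch.equal(trimmed, hidden_states)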