
Commit 93340a2

fix(hooks): Add generic padding to context parallel hook
1 parent baf42db commit 93340a2

4 files changed: +142 additions, -73 deletions

4 files changed

+142
-73
lines changed

src/diffusers/hooks/context_parallel.py

Lines changed: 45 additions & 4 deletions
@@ -154,7 +154,7 @@ def pre_forward(self, module, *args, **kwargs):
             # The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
             # the output instead of input for a particular layer by setting split_output=True
             if isinstance(input_val, torch.Tensor):
-                input_val = self._prepare_cp_input(input_val, cpm)
+                input_val = self._prepare_cp_input(input_val, cpm, name)
             elif isinstance(input_val, (list, tuple)):
                 if len(input_val) != len(cpm):
                     raise ValueError(
@@ -201,12 +201,40 @@ def post_forward(self, module, output):

         return output[0] if is_tensor else tuple(output)

-    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
+    def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput, name: str) -> torch.Tensor:
         if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
             raise ValueError(
                 f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions."
             )
-        return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)
+
+        mesh = self.parallel_config._flattened_mesh
+        world_size = mesh.size()
+        dim = cp_input.split_dim
+        seq_len = x.size(dim)
+
+        if world_size > 1 and seq_len % world_size != 0:
+            pad_len = (world_size - seq_len % world_size) % world_size
+
+            # Determine pad_value based on name convention
+            if "mask" in name.lower():
+                pad_value = 0
+            else:
+                pad_value = 0.0
+
+            pad_width = [0] * (2 * x.dim())
+            # The pad_width tuple is read from last dim to first dim
+            pad_idx = x.dim() - 1 - dim
+            # We want to pad the right side
+            pad_width[2 * pad_idx + 1] = pad_len
+
+            x = torch.nn.functional.pad(x, tuple(pad_width), mode="constant", value=pad_value)
+
+            # Store original size for trimming in the gather hook
+            if "hidden_states" in name.lower():
+                self.module_forward_metadata._cp_original_s = seq_len
+                self.module_forward_metadata._cp_pad_dim = dim
+
+        return EquipartitionSharder.shard(x, dim, mesh)


 class ContextParallelGatherHook(ModelHook):
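
The indexing into pad_width above relies on torch.nn.functional.pad reading its pad tuple from the last dimension backwards. A minimal, standalone sketch of the same arithmetic, assuming a (batch, seq, hidden) tensor and a world size of 3 (illustrative values, not taken from the hook itself):

# Sketch of the padding arithmetic used by the split hook (not the hook code).
import torch
import torch.nn.functional as F

x = torch.randn(1, 7, 16)          # seq_len = 7 along dim=1
world_size, dim = 3, 1
seq_len = x.size(dim)
pad_len = (world_size - seq_len % world_size) % world_size   # (3 - 7 % 3) % 3 == 2

# F.pad reads its pad tuple from the LAST dimension backwards:
# (last_left, last_right, second_last_left, second_last_right, ...)
pad_width = [0] * (2 * x.dim())
pad_idx = x.dim() - 1 - dim            # position of `dim` counted from the end
pad_width[2 * pad_idx + 1] = pad_len   # pad only the right side of `dim`

x_padded = F.pad(x, tuple(pad_width), mode="constant", value=0.0)
assert x_padded.shape == (1, 9, 16)    # 9 is evenly divisible by world_size
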
@@ -231,7 +259,20 @@ def post_forward(self, module, output):
         for i, cpm in enumerate(self.metadata):
             if cpm is None:
                 continue
-            output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)
+
+            x = output[i]
+            x = EquipartitionSharder.unshard(x, cpm.gather_dim, self.parallel_config._flattened_mesh)
+
+            # Trim if padded info exists
+            if hasattr(module, "_forward_metadata") and hasattr(module._forward_metadata, "_cp_original_s"):
+                if cpm.gather_dim == module._forward_metadata._cp_pad_dim:
+                    original_s = module._forward_metadata._cp_original_s
+                    x = x.narrow(cpm.gather_dim, 0, original_s)
+                # Clean up the stored attributes
+                del module._forward_metadata._cp_original_s
+                del module._forward_metadata._cp_pad_dim
+
+            output[i] = x

         return output[0] if is_tensor else tuple(output)

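Taken together, the split hook pads and shards the input while the gather hook unshards the output and trims it back to the recorded length. A rough single-process sketch of that round trip, using torch.chunk/torch.cat as stand-ins for EquipartitionSharder.shard/unshard (no process group or device mesh involved; the names are illustrative only):

# Sketch of the pad -> shard -> unshard -> trim round trip (illustrative, not the hook code).
import torch
import torch.nn.functional as F

world_size, dim = 3, 1
x = torch.randn(1, 7, 16)
original_s = x.size(dim)

pad_len = (world_size - original_s % world_size) % world_size
x_padded = F.pad(x, (0, 0, 0, pad_len))              # (1, 9, 16)

shards = torch.chunk(x_padded, world_size, dim=dim)  # each rank would hold one shard
assert all(s.shape[dim] == 3 for s in shards)

gathered = torch.cat(shards, dim=dim)                # what the gather hook reassembles
trimmed = gathered.narrow(dim, 0, original_s)        # drop the padded tail
assert torch.equal(trimmed, x)
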
src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 1 addition & 30 deletions
@@ -18,7 +18,6 @@

 import numpy as np
 import torch
-import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

@@ -331,19 +330,6 @@ def __call__(
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)

-        if attention_mask is None and encoder_hidden_states_mask is not None:
-            # The joint sequence is [text, image].
-            seq_len_img = hidden_states.shape[1]
-            img_mask = torch.ones(
-                encoder_hidden_states_mask.shape[0], seq_len_img, device=encoder_hidden_states_mask.device
-            )
-            attention_mask = torch.cat([encoder_hidden_states_mask, img_mask], dim=1)
-
-            # Convert the mask to the format expected by SDPA
-            attention_mask = attention_mask[:, None, None, :]
-            attention_mask = attention_mask.to(dtype=joint_query.dtype)
-            attention_mask = (1.0 - attention_mask) * torch.finfo(joint_query.dtype).min
-
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
@@ -614,15 +600,6 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
-        if dist.is_initialized():
-            world_size = dist.get_world_size()
-            if world_size > 1 and encoder_hidden_states is not None:
-                seq_len = encoder_hidden_states.shape[1]
-                pad_len = (world_size - seq_len % world_size) % world_size
-                if pad_len > 0:
-                    encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
-                    if encoder_hidden_states_mask is not None:
-                        encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)

         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
@@ -654,13 +631,7 @@ def forward(
             else self.time_text_embed(timestep, guidance, hidden_states)
         )

-        # use the shape of the padded hidden states to generate the rotary embeddings
-        if encoder_hidden_states is not None:
-            recalculated_txt_seq_lens = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
-        else:
-            recalculated_txt_seq_lens = txt_seq_lens
-
-        image_rotary_emb = self.pos_embed(img_shapes, recalculated_txt_seq_lens, device=hidden_states.device)
+        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)

         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:

tests/hooks/test_hooks.py

Lines changed: 96 additions & 0 deletions
@@ -14,10 +14,12 @@

 import gc
 import unittest
+from unittest.mock import patch

 import torch

 from diffusers.hooks import HookRegistry, ModelHook
+from diffusers.hooks.context_parallel import ContextParallelSplitHook, EquipartitionSharder
 from diffusers.training_utils import free_memory
 from diffusers.utils.logging import get_logger

@@ -62,6 +64,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         return x


+# Small helpers to simulate the parallel_config._flattened_mesh used by the hook
+class _DummyMesh:
+    def __init__(self, size: int):
+        self._size = size
+
+    def size(self):
+        return self._size
+
+
+class _DummyParallelConfig:
+    def __init__(self, mesh_size: int):
+        self._flattened_mesh = _DummyMesh(mesh_size)
+
+
+# Lightweight object that behaves like a ContextParallelInput for testing.
+class _DummyCPInput:
+    def __init__(self, split_dim: int, expected_dims: int = None, split_output: bool = False):
+        self.split_dim = split_dim
+        self.expected_dims = expected_dims
+        self.split_output = split_output
+
+
 class AddHook(ModelHook):
     def __init__(self, value: int):
         super().__init__()
@@ -375,3 +399,75 @@ def test_invocation_order_stateful_last(self):
             .replace("\n", "")
         )
         self.assertEqual(output, expected_invocation_order_log)
+
+
+class ContextParallelHooksTests(unittest.TestCase):
+    def setUp(self):
+        # world_size 3 will force padding for seq_len that isn't divisible by 3
+        self.parallel_config = _DummyParallelConfig(mesh_size=3)
+        # metadata may be empty for our direct call to _prepare_cp_input
+        self.hook = ContextParallelSplitHook(metadata={}, parallel_config=self.parallel_config)
+        self.module = DummyModel(in_features=1, hidden_features=1, out_features=1, num_layers=1)
+        # initialize_hook builds module_forward_metadata inside the hook
+        self.hook.initialize_hook(self.module)
+        # attach forward metadata to the module exactly how HookRegistry would do
+        self.module._forward_metadata = self.hook.module_forward_metadata
+
+    def test_prepare_cp_input_pads_hidden_states_and_stores_original(self):
+        # create a tensor with seq_len = 7 along dim=1 (batch, seq, hidden)
+        x = torch.randn(1, 7, 16)
+
+        cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)
+
+        # Patch shard to identity so we can inspect the padded tensor directly
+        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t) as mock_shard:
+            out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")
+
+        # The hook should have padded seq_len from 7 -> 9 since world_size=3
+        self.assertEqual(out.shape[1], 9)
+
+        # ensure shard was called once with the expected dim and mesh
+        mock_shard.assert_called_once()
+        called_args, _ = mock_shard.call_args
+        # called_args = (tensor, dim, mesh)
+        self.assertEqual(called_args[1], cp_input.split_dim)
+        self.assertIs(called_args[2], self.parallel_config._flattened_mesh)
+
+        # The hook should have recorded the original sequence length and pad dim
+        # on the module's metadata so the gather hook can later trim.
+        self.assertTrue(hasattr(self.module._forward_metadata, "_cp_original_s"))
+        self.assertTrue(hasattr(self.module._forward_metadata, "_cp_pad_dim"))
+        self.assertEqual(self.module._forward_metadata._cp_original_s, 7)
+        self.assertEqual(self.module._forward_metadata._cp_pad_dim, 1)
+
+    def test_prepare_cp_input_pads_attention_mask_with_zeros(self):
+        # attention masks are typically shape (batch, seq)
+        # create seq_len = 7 mask with ones
+        mask = torch.ones(1, 7, dtype=torch.long)
+
+        cp_input = _DummyCPInput(split_dim=1, expected_dims=2, split_output=False)
+
+        # Patch shard to identity
+        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
+            out_mask = self.hook._prepare_cp_input(mask, cp_input, name="encoder_attention_mask")
+
+        # After padding it should be shape (1, 9)
+        self.assertEqual(out_mask.shape[1], 9)
+        # The padded values should be zeros (pad_value used in code for masks)
+        # Check the last two positions are zero
+        padded_portion = out_mask[:, -2:]
+        self.assertTrue(torch.equal(padded_portion, torch.zeros_like(padded_portion)))
+
+    def test_prepare_cp_input_no_pad_when_divisible(self):
+        # seq_len is already divisible by world_size (3), e.g., 6
+        x = torch.randn(1, 6, 16)
+        cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)
+
+        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
+            out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")
+
+        # no padding should be performed
+        self.assertEqual(out.shape[1], 6)
+        # and no _cp_original_s/_cp_pad_dim set because not padded
+        self.assertFalse(hasattr(self.hook.module_forward_metadata, "_cp_original_s"))
+        self.assertFalse(hasattr(self.hook.module_forward_metadata, "_cp_pad_dim"))

tests/pipelines/qwenimage/test_qwenimage.py

Lines changed: 0 additions & 39 deletions
@@ -13,7 +13,6 @@
 # limitations under the License.

 import unittest
-from unittest.mock import patch

 import numpy as np
 import torch
@@ -235,41 +234,3 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             expected_diff_max,
             "VAE tiling should not affect the inference results",
         )
-
-    def test_context_parallelism_padding_fix(self):
-        """
-        Compare pipeline outputs: baseline (normal single-process) vs
-        simulated multi-process (mocked torch.distributed). This verifies
-        padding logic does not change the generated image.
-        """
-        device = "cpu"
-
-        components = self.get_dummy_components()
-        pipe = self.pipeline_class(**components)
-        pipe.to(device)
-        pipe.set_progress_bar_config(disable=None)
-
-        inputs = self.get_dummy_inputs(device)
-
-        # Baseline run (no distributed)
-        baseline_image = pipe(**inputs).images[0]
-
-        # Re-initialize inputs to get a fresh generator with the same seed for a fair comparison
-        inputs = self.get_dummy_inputs(device)
-
-        # Simulate distributed env (world_size = 3) so padding branch runs
-        # NOTE: patch target must match where `dist` is imported in the transformer module.
-        with (
-            patch("diffusers.models.transformers.transformer_qwenimage.dist.is_initialized", return_value=True),
-            patch("diffusers.models.transformers.transformer_qwenimage.dist.get_world_size", return_value=3),
-        ):
-            padded_image = pipe(**inputs).images[0]
-
-        # shape check
-        self.assertEqual(baseline_image.shape, padded_image.shape)
-
-        # Additional check: verify padding didn't introduce extreme values
-        self.assertTrue(torch.isfinite(padded_image).all())
-
-        # Verify numerical equivalence
-        self.assertTrue(torch.allclose(baseline_image, padded_image, atol=1e-2, rtol=1e-2))
