
Commit 5560eb2

fix(qwenimage): Correct context parallelism padding

1 parent a9cb08a · commit 5560eb2

File tree

2 files changed: +70 -1 lines changed

src/diffusers/models/transformers/transformer_qwenimage.py

Lines changed: 31 additions & 1 deletion
@@ -18,6 +18,7 @@

 import numpy as np
 import torch
+import torch.distributed as dist
 import torch.nn as nn
 import torch.nn.functional as F

@@ -330,6 +331,19 @@ def __call__(
         joint_key = torch.cat([txt_key, img_key], dim=1)
         joint_value = torch.cat([txt_value, img_value], dim=1)

+        if attention_mask is None and encoder_hidden_states_mask is not None:
+            # The joint sequence is [text, image].
+            seq_len_img = hidden_states.shape[1]
+            img_mask = torch.ones(
+                encoder_hidden_states_mask.shape[0], seq_len_img, device=encoder_hidden_states_mask.device
+            )
+            attention_mask = torch.cat([encoder_hidden_states_mask, img_mask], dim=1)
+
+            # Convert the mask to the format expected by SDPA
+            attention_mask = attention_mask[:, None, None, :]
+            attention_mask = attention_mask.to(dtype=joint_query.dtype)
+            attention_mask = (1.0 - attention_mask) * torch.finfo(joint_query.dtype).min
+
         # Compute joint attention
         joint_hidden_states = dispatch_attention_fn(
             joint_query,
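
For orientation, here is a minimal standalone sketch (not part of the commit; shapes and names are toy values) of the same additive-mask conversion: a 0/1 padding mask over the joint [text, image] sequence is broadcast to (batch, 1, 1, seq) and the zeros become a large negative bias, so padded keys receive effectively no attention weight inside SDPA.

import torch
import torch.nn.functional as F

# Toy sizes: batch=1, heads=2, text_len=3 (last text token is padding), img_len=4, head_dim=8
text_mask = torch.tensor([[1.0, 1.0, 0.0]])           # 1 = real token, 0 = padding
img_mask = torch.ones(1, 4)                           # image tokens are never padded
joint_mask = torch.cat([text_mask, img_mask], dim=1)  # (batch, 7), matching the [text, image] order

q = torch.randn(1, 2, 7, 8)
k = torch.randn(1, 2, 7, 8)
v = torch.randn(1, 2, 7, 8)

# Broadcast to (batch, 1, 1, seq) and map 0 -> finfo.min so padded keys
# get ~zero weight after the softmax inside scaled_dot_product_attention.
bias = (1.0 - joint_mask)[:, None, None, :].to(q.dtype) * torch.finfo(q.dtype).min
out = F.scaled_dot_product_attention(q, k, v, attn_mask=bias)
print(out.shape)  # torch.Size([1, 2, 7, 8])
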
@@ -600,6 +614,16 @@ def forward(
             If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
             `tuple` where the first element is the sample tensor.
         """
+        if dist.is_initialized():
+            world_size = dist.get_world_size()
+            if world_size > 1 and encoder_hidden_states is not None:
+                seq_len = encoder_hidden_states.shape[1]
+                pad_len = (world_size - seq_len % world_size) % world_size
+                if pad_len > 0:
+                    encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
+                    if encoder_hidden_states_mask is not None:
+                        encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)
+
         if attention_kwargs is not None:
             attention_kwargs = attention_kwargs.copy()
             lora_scale = attention_kwargs.pop("scale", 1.0)
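
As a standalone illustration (not from the commit; the sizes are made up), the padding arithmetic above rounds the text sequence length up to the next multiple of the world size and masks out the appended positions:

import torch
import torch.nn.functional as F

world_size = 3                                     # e.g. three context-parallel ranks (illustrative)
encoder_hidden_states = torch.randn(2, 77, 64)     # (batch, txt_seq_len, dim), toy sizes
encoder_hidden_states_mask = torch.ones(2, 77)

seq_len = encoder_hidden_states.shape[1]
pad_len = (world_size - seq_len % world_size) % world_size   # 77 % 3 = 2, so pad_len = 1

if pad_len > 0:
    # F.pad pads the trailing dims first: (0, 0) leaves the feature dim alone,
    # (0, pad_len) appends pad_len positions to the sequence dim.
    encoder_hidden_states = F.pad(encoder_hidden_states, (0, 0, 0, pad_len))
    encoder_hidden_states_mask = F.pad(encoder_hidden_states_mask, (0, pad_len), value=0)

print(encoder_hidden_states.shape)        # torch.Size([2, 78, 64]) -> now divisible by 3
print(encoder_hidden_states_mask.sum(1))  # tensor([77., 77.]) -> padded tokens are masked out
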
@@ -630,7 +654,13 @@ def forward(
             else self.time_text_embed(timestep, guidance, hidden_states)
         )

-        image_rotary_emb = self.pos_embed(img_shapes, txt_seq_lens, device=hidden_states.device)
+        # use the shape of the padded hidden states to generate the rotary embeddings
+        if encoder_hidden_states is not None:
+            recalculated_txt_seq_lens = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
+        else:
+            recalculated_txt_seq_lens = txt_seq_lens
+
+        image_rotary_emb = self.pos_embed(img_shapes, recalculated_txt_seq_lens, device=hidden_states.device)

         for index_block, block in enumerate(self.transformer_blocks):
             if torch.is_grad_enabled() and self.gradient_checkpointing:
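
The recalculated lengths simply repeat the (possibly padded) text length once per batch item, so the rotary embeddings cover the padded positions as well. An illustrative toy example (not part of the commit, values made up):

import torch

# Suppose batch size 2 and the text sequence was padded from 77 to 78 tokens.
encoder_hidden_states = torch.randn(2, 78, 64)
recalculated_txt_seq_lens = [encoder_hidden_states.shape[1]] * encoder_hidden_states.shape[0]
print(recalculated_txt_seq_lens)  # [78, 78] instead of the original [77, 77]
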

tests/pipelines/qwenimage/test_qwenimage.py

Lines changed: 39 additions & 0 deletions
@@ -13,6 +13,7 @@
 # limitations under the License.

 import unittest
+from unittest.mock import patch

 import numpy as np
 import torch
@@ -234,3 +235,41 @@ def test_vae_tiling(self, expected_diff_max: float = 0.2):
             expected_diff_max,
             "VAE tiling should not affect the inference results",
         )
+
+    def test_context_parallelism_padding_fix(self):
+        """
+        Compare pipeline outputs: baseline (normal single-process) vs
+        simulated multi-process (mocked torch.distributed). This verifies
+        padding logic does not change the generated image.
+        """
+        device = "cpu"
+
+        components = self.get_dummy_components()
+        pipe = self.pipeline_class(**components)
+        pipe.to(device)
+        pipe.set_progress_bar_config(disable=None)
+
+        inputs = self.get_dummy_inputs(device)
+
+        # Baseline run (no distributed)
+        baseline_image = pipe(**inputs).images[0]
+
+        # Re-initialize inputs to get a fresh generator with the same seed for a fair comparison
+        inputs = self.get_dummy_inputs(device)
+
+        # Simulate distributed env (world_size = 3) so padding branch runs
+        # NOTE: patch target must match where `dist` is imported in the transformer module.
+        with (
+            patch("diffusers.models.transformers.transformer_qwenimage.dist.is_initialized", return_value=True),
+            patch("diffusers.models.transformers.transformer_qwenimage.dist.get_world_size", return_value=3),
+        ):
+            padded_image = pipe(**inputs).images[0]
+
+        # shape check
+        self.assertEqual(baseline_image.shape, padded_image.shape)
+
+        # Additional check: verify padding didn't introduce extreme values
+        self.assertTrue(torch.isfinite(padded_image).all())
+
+        # Verify numerical equivalence
+        self.assertTrue(torch.allclose(baseline_image, padded_image, atol=1e-2, rtol=1e-2))
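
Because torch.distributed is only mocked, the test exercises the padding branch on a single CPU process; it should be runnable in isolation with something like:

python -m pytest tests/pipelines/qwenimage/test_qwenimage.py -k test_context_parallelism_padding_fix
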
