13 changes: 7 additions & 6 deletions src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -132,6 +132,7 @@ def expected_components(self) -> List[ComponentSpec]:
@property
def inputs(self) -> List[InputParam]:
return [
InputParam("latents"),
Review comment (Member Author): Similar to how it's done in the other pipelines.
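A usage sketch of what this enables (hypothetical: pipe stands for the modular pipeline built from the tiny checkpoint used in the tests below, and the latent shape is a placeholder, not the real packed layout):

    import torch

    # Placeholder shape; the real packed shape depends on height, width, the VAE
    # scale factor, and the pachifier patch size.
    latents = torch.randn(1, 16, 16, generator=torch.Generator("cpu").manual_seed(0))

    # With latents exposed as an input, the prepare-latents block reuses them
    # instead of sampling fresh noise (see the "if block_state.latents is None"
    # guard further down in this diff).
    images = pipe(
        prompt="dance monkey",
        latents=latents,
        num_inference_steps=2,
        height=32,
        width=32,
        output="images",
    )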

InputParam(name="height"),
InputParam(name="width"),
InputParam(name="num_images_per_prompt", default=1),
@@ -196,11 +197,11 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
f"You have passed a list of generators of length {len(block_state.generator)}, but requested an effective batch"
f" size of {batch_size}. Make sure the batch size matches the length of the generators."
)

block_state.latents = randn_tensor(
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
)
block_state.latents = components.pachifier.pack_latents(block_state.latents)
if block_state.latents is None:
block_state.latents = randn_tensor(
shape, generator=block_state.generator, device=device, dtype=block_state.dtype
)
block_state.latents = components.pachifier.pack_latents(block_state.latents)

self.set_block_state(state, block_state)
return components, state
@@ -549,7 +550,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
block_state.width // components.vae_scale_factor // 2,
)
]
* block_state.batch_size
for _ in range(block_state.batch_size)
Review comment (@sayakpaul, Member Author, Nov 4, 2025): We have two options:

  1. Either this.
  2. Or how it's done in edit:

    img_shapes = [
        [
            (1, height // self.vae_scale_factor // 2, width // self.vae_scale_factor // 2),
            (1, calculated_height // self.vae_scale_factor // 2, calculated_width // self.vae_scale_factor // 2),
        ]
    ] * batch_size

Regardless, the current implementation isn't exactly the same as how the standard pipeline implements it and would break the batched-input tests we have.
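To make the structural difference concrete, a small self-contained sketch (plain Python with placeholder values; the edit pipeline's inner list actually holds two shape tuples, simplified to one here):

    batch_size, h, w = 2, 4, 4  # placeholder values

    # Old code: "* batch_size" sits inside the outer list, giving ONE inner list
    # with batch_size repeated tuples: [[(1, 4, 4), (1, 4, 4)]]
    old = [[(1, h, w)] * batch_size]

    # Option 1 (this PR): one single-tuple inner list per batch element:
    # [[(1, 4, 4)], [(1, 4, 4)]]
    option_1 = [[(1, h, w)] for _ in range(batch_size)]

    # Option 2 (edit-pipeline style): build the inner list once, then repeat the
    # outer list -> same values as option 1 (the inner lists are shared objects).
    option_2 = [[(1, h, w)]] * batch_size

    assert len(old) == 1 and len(old[0]) == batch_size
    assert len(option_1) == batch_size and all(len(inner) == 1 for inner in option_1)
    assert option_1 == option_2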

]
block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
3 changes: 2 additions & 1 deletion src/diffusers/modular_pipelines/qwenimage/decoders.py
@@ -74,8 +74,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
block_state = self.get_block_state(state)

# YiYi Notes: remove support for output_type = "latents', we can just skip decode/encode step in modular
vae_scale_factor = 2 ** len(components.vae.temperal_downsample)
Review comment (Member Author): Keeping vae_scale_factor fixed to 8, for example, would break the tests as we use a smaller VAE.
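For illustration, the derived factor tracks the VAE depth instead of a hard-coded 8 (the config lists below are assumptions, not the actual checkpoint values; temperal_downsample is the attribute's real spelling in the code above):

    # Assumed configs: a full-size VAE with three temporal-downsample entries vs.
    # a tiny test VAE with two.
    full_size_temperal_downsample = [False, True, True]
    tiny_test_temperal_downsample = [False, True]

    print(2 ** len(full_size_temperal_downsample))  # 8 -> matches the old hard-coded value
    print(2 ** len(tiny_test_temperal_downsample))  # 4 -> what the smaller test VAE needs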

block_state.latents = components.pachifier.unpack_latents(
block_state.latents, block_state.height, block_state.width
block_state.latents, block_state.height, block_state.width, vae_scale_factor=vae_scale_factor
)
block_state.latents = block_state.latents.to(components.vae.dtype)

6 changes: 6 additions & 0 deletions src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -503,6 +503,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
block_state.prompt_embeds = block_state.prompt_embeds[:, : block_state.max_sequence_length]
block_state.prompt_embeds_mask = block_state.prompt_embeds_mask[:, : block_state.max_sequence_length]

block_state.negative_prompt_embeds = None
block_state.negative_prompt_embeds_mask = None
Review comment (Member Author, on lines +506 to +507): Otherwise, no-CFG settings (e.g. guidance_scale=1.0) would break.
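A minimal sketch of the failure mode this guards against (the class and flag below are illustrative stand-ins, not the actual block-state implementation):

    class BlockState:
        pass

    state = BlockState()
    requires_unconditional_embeds = False  # e.g. guidance_scale == 1.0, so no CFG

    if requires_unconditional_embeds:
        state.negative_prompt_embeds = "computed embeds"  # only set on the CFG path

    # Without the None defaults added above, downstream code that reads the
    # attribute on the no-CFG path hits an AttributeError instead of a None.
    print(getattr(state, "negative_prompt_embeds", "missing!"))  # -> missing!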

if components.requires_unconditional_embeds:
negative_prompt = block_state.negative_prompt or ""
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds(
@@ -627,6 +629,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
device=device,
)

block_state.negative_prompt_embeds = None
block_state.negative_prompt_embeds_mask = None
if components.requires_unconditional_embeds:
negative_prompt = block_state.negative_prompt or " "
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = get_qwen_prompt_embeds_edit(
@@ -679,6 +683,8 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
device=device,
)

block_state.negative_prompt_embeds = None
block_state.negative_prompt_embeds_mask = None
if components.requires_unconditional_embeds:
negative_prompt = block_state.negative_prompt or " "
block_state.negative_prompt_embeds, block_state.negative_prompt_embeds_mask = (
@@ -26,10 +26,7 @@ class QwenImagePachifier(ConfigMixin):
config_name = "config.json"

@register_to_config
def __init__(
self,
patch_size: int = 2,
):
def __init__(self, patch_size: int = 2):
super().__init__()

def pack_latents(self, latents):
Empty file.
141 changes: 141 additions & 0 deletions tests/modular_pipelines/qwen/test_modular_pipeline_qwenimage.py
@@ -0,0 +1,141 @@
# coding=utf-8
# Copyright 2025 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.


import numpy as np
import PIL
import pytest

from diffusers import ClassifierFreeGuidance
from diffusers.modular_pipelines import (
QwenImageAutoBlocks,
QwenImageEditAutoBlocks,
QwenImageEditModularPipeline,
QwenImageEditPlusAutoBlocks,
QwenImageEditPlusModularPipeline,
QwenImageModularPipeline,
)

from ...testing_utils import torch_device
from ..test_modular_pipelines_common import ModularPipelineTesterMixin


class QwenImageModularGuiderMixin:
def test_guider_cfg(self, tol=1e-2):
pipe = self.get_pipeline()
pipe = pipe.to(torch_device)

guider = ClassifierFreeGuidance(guidance_scale=1.0)
pipe.update_components(guider=guider)

inputs = self.get_dummy_inputs()
out_no_cfg = pipe(**inputs, output="images")

guider = ClassifierFreeGuidance(guidance_scale=7.5)
pipe.update_components(guider=guider)
inputs = self.get_dummy_inputs()
out_cfg = pipe(**inputs, output="images")

assert out_cfg.shape == out_no_cfg.shape
max_diff = np.abs(out_cfg - out_no_cfg).max()
assert max_diff > tol, "Output with CFG must be different from normal inference"


class TestQwenImageModularPipelineFast(ModularPipelineTesterMixin, QwenImageModularGuiderMixin):
pipeline_class = QwenImageModularPipeline
pipeline_blocks_class = QwenImageAutoBlocks
repo = "hf-internal-testing/tiny-qwenimage-modular"

params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"max_sequence_length": 16,
"output_type": "pt",
}
return inputs


class TestQwenImageEditModularPipelineFast(ModularPipelineTesterMixin, QwenImageModularGuiderMixin):
pipeline_class = QwenImageEditModularPipeline
pipeline_blocks_class = QwenImageEditAutoBlocks
repo = "hf-internal-testing/tiny-qwenimage-edit-modular"

params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image", "mask_image"])
batch_params = frozenset(["prompt", "negative_prompt", "image", "mask_image"])

def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
return inputs

def test_guider_cfg(self):
super().test_guider_cfg(7e-5)


class TestQwenImageEditPlusModularPipelineFast(ModularPipelineTesterMixin, QwenImageModularGuiderMixin):
pipeline_class = QwenImageEditPlusModularPipeline
pipeline_blocks_class = QwenImageEditPlusAutoBlocks
repo = "hf-internal-testing/tiny-qwenimage-edit-plus-modular"

# No `mask_image` yet.
params = frozenset(["prompt", "height", "width", "negative_prompt", "attention_kwargs", "image"])
batch_params = frozenset(["prompt", "negative_prompt", "image"])

def get_dummy_inputs(self):
generator = self.get_generator()
inputs = {
"prompt": "dance monkey",
"negative_prompt": "bad quality",
"generator": generator,
"num_inference_steps": 2,
"height": 32,
"width": 32,
"output_type": "pt",
}
inputs["image"] = PIL.Image.new("RGB", (32, 32), 0)
return inputs

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_num_images_per_prompt(self):
super().test_num_images_per_prompt()

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_inference_batch_consistent(self):
super().test_inference_batch_consistent()

@pytest.mark.xfail(condition=True, reason="Batch of multiple images needs to be revisited", strict=True)
def test_inference_batch_single_identical(self):
super().test_inference_batch_single_identical()
Review comment (Member Author, on lines +128 to +138): These are skipped in the standard pipeline tests, too.


def test_guider_cfg(self):
super().test_guider_cfg(1e-3)
33 changes: 8 additions & 25 deletions tests/modular_pipelines/test_modular_pipelines_common.py
@@ -2,22 +2,16 @@
import tempfile
from typing import Callable, Union

import pytest
import torch

import diffusers
from diffusers import ComponentsManager, ModularPipeline, ModularPipelineBlocks
from diffusers.utils import logging

from ..testing_utils import (
backend_empty_cache,
numpy_cosine_similarity_distance,
require_accelerator,
require_torch,
torch_device,
)
from ..testing_utils import backend_empty_cache, numpy_cosine_similarity_distance, torch_device


@require_torch
class ModularPipelineTesterMixin:
"""
It provides a set of common tests for each modular pipeline,
@@ -32,20 +26,9 @@ class ModularPipelineTesterMixin:
# Canonical parameters that are passed to `__call__` regardless
# of the type of pipeline. They are always optional and have common
# sense default values.
optional_params = frozenset(
[
"num_inference_steps",
"num_images_per_prompt",
"latents",
"output_type",
]
)
optional_params = frozenset(["num_inference_steps", "num_images_per_prompt", "latents", "output_type"])
# this is modular specific: generator needs to be a intermediate input because it's mutable
intermediate_params = frozenset(
[
"generator",
]
)
intermediate_params = frozenset(["generator"])

def get_generator(self, seed=0):
generator = torch.Generator("cpu").manual_seed(seed)
@@ -215,7 +198,7 @@ def test_inference_batch_single_identical(
max_diff = torch.abs(output_batch[0] - output[0]).max()
assert max_diff < expected_max_diff, "Batch inference results different from single inference results"

@require_accelerator
@pytest.mark.skipif(torch_device == "cpu", reason="Test needs an accelerator.")
def test_float16_inference(self, expected_max_diff=5e-2):
pipe = self.get_pipeline()
pipe.to(torch_device, torch.float32)
@@ -244,7 +227,7 @@ def test_float16_inference(self, expected_max_diff=5e-2):
max_diff = numpy_cosine_similarity_distance(output.flatten(), output_fp16.flatten())
assert max_diff < expected_max_diff, "FP16 inference is different from FP32 inference"

@require_accelerator
@pytest.mark.skipif(torch_device == "cpu", reason="Test needs an accelerator.")
def test_to_device(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
@@ -271,7 +254,7 @@ def test_inference_is_not_nan_cpu(self):
output = pipe(**self.get_dummy_inputs(), output="images")
assert torch.isnan(output).sum() == 0, "CPU Inference returns NaN"

@require_accelerator
@pytest.mark.skipif(torch_device == "cpu", reason="Test needs an accelerator.")
def test_inference_is_not_nan(self):
pipe = self.get_pipeline()
pipe.set_progress_bar_config(disable=None)
@@ -304,7 +287,7 @@ def test_num_images_per_prompt(self):

assert images.shape[0] == batch_size * num_images_per_prompt

@require_accelerator
@pytest.mark.skipif(torch_device == "cpu", reason="Test needs an accelerator.")
def test_components_auto_cpu_offload_inference_consistent(self):
base_pipe = self.get_pipeline().to(torch_device)
