From c6b1283e8f8c10865151f0d8b6ea21a70a944fcd Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 6 Nov 2025 17:53:09 +0530
Subject: [PATCH 1/3] try to fix qwen edit plus multi images (modular)

---
 .../qwenimage/before_denoise.py               | 40 ++++++++++-
 .../modular_pipelines/qwenimage/encoders.py   | 44 +++++++-----
 .../modular_pipelines/qwenimage/inputs.py     |  8 +++
 .../qwenimage/modular_blocks.py               | 69 ++++++++++++++++---
 4 files changed, 134 insertions(+), 27 deletions(-)

diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
index fdec95dc506e..788556634c89 100644
--- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -610,7 +610,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         block_state = self.get_block_state(state)
 
         # for edit, image size can be different from the target size (height/width)
-
         block_state.img_shapes = [
             [
                 (
@@ -640,6 +639,45 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
         return components, state
 
 
+class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
+    model_name = "qwenimage-edit-plus"
+
+    @property
+    def inputs(self) -> List[InputParam]:
+        existing_inputs = super().inputs
+        current_inputs = [InputParam("vae_image_sizes", type_hint=List[Tuple[int, int]])]
+        return existing_inputs + current_inputs
+
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+        print(f"{block_state=}")
+
+        vae_scale_factor = components.vae_scale_factor
+        print(f"{block_state.vae_image_sizes=}")
+        block_state.img_shapes = [
+            [
+                (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
+                *[
+                    (1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
+                    for vae_width, vae_height in block_state.vae_image_sizes
+                ],
+            ]
+        ] * block_state.batch_size
+
+        block_state.txt_seq_lens = (
+            block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
+        )
+        block_state.negative_txt_seq_lens = (
+            block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
+            if block_state.negative_prompt_embeds_mask is not None
+            else None
+        )
+
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
 ## ControlNet inputs for denoiser
 class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
     model_name = "qwenimage"
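
A note on the shape arithmetic in the step above: the VAE downsamples pixels by `vae_scale_factor` and the transformer then packs 2x2 latent patches, hence the double integer division. A minimal sketch of the computation; the scale factor (8) and the image sizes are assumed values for illustration, not part of the patch:

    vae_scale_factor = 8
    height, width = 1024, 1024                     # target output size
    vae_image_sizes = [(672, 1568), (1024, 1024)]  # assumed per-reference (width, height)
    batch_size = 2

    img_shapes = [
        [
            # first entry: the denoising target
            (1, height // vae_scale_factor // 2, width // vae_scale_factor // 2),
            # then one entry per reference image, each at its own resolution
            *[(1, h // vae_scale_factor // 2, w // vae_scale_factor // 2) for w, h in vae_image_sizes],
        ]
    ] * batch_size
    # -> every batch entry is [(1, 64, 64), (1, 98, 42), (1, 64, 64)]
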
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index 04fb3fdc947b..7df73badf82a 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union
 
 import PIL
 import torch
@@ -803,9 +803,7 @@ def inputs(self) -> List[InputParam]:
 
     @property
     def intermediate_outputs(self) -> List[OutputParam]:
-        return [
-            OutputParam(name="processed_image"),
-        ]
+        return [OutputParam(name="processed_image")]
 
     @staticmethod
     def check_inputs(height, width, vae_scale_factor):
@@ -845,7 +843,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
 
 class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
     model_name = "qwenimage-edit-plus"
-    vae_image_size = 1024 * 1024
+
+    def __init__(self):
+        self.vae_image_size = 1024 * 1024
+        super().__init__()
 
     @property
     def description(self) -> str:
@@ -855,6 +856,12 @@ def description(self) -> str:
     def inputs(self) -> List[InputParam]:
         return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
 
+    @property
+    def intermediate_outputs(self):
+        existing_outputs = super().intermediate_outputs
+        current_outputs = [OutputParam("vae_image_sizes", type_hint=List[Tuple[int, int]])]
+        return existing_outputs + current_outputs
+
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         block_state = self.get_block_state(state)
@@ -862,6 +869,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         if block_state.vae_image is None and block_state.image is None:
             raise ValueError("`vae_image` and `image` cannot be None at the same time")
 
+        vae_image_sizes = None
         if block_state.vae_image is None:
             image = block_state.image
             self.check_inputs(
@@ -873,12 +881,19 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
                 image=image, height=height, width=width
             )
         else:
-            width, height = block_state.vae_image[0].size
-            image = block_state.vae_image
+            processed_images = []
+            vae_image_sizes = []
+            for img in block_state.vae_image:
+                width, height = img.size
+                vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
+                vae_image_sizes.append((vae_width, vae_height))
+                processed_images.append(
+                    components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
+                )
+            block_state.processed_image = torch.stack(processed_images, dim=0).squeeze(1)
 
-            block_state.processed_image = components.image_processor.preprocess(
-                image=image, height=height, width=width
-            )
+        block_state.vae_image_sizes = vae_image_sizes
+        print(f"{block_state.vae_image_sizes=}")
 
         self.set_block_state(state, block_state)
         return components, state
@@ -920,17 +935,12 @@ def description(self) -> str:
 
     @property
     def expected_components(self) -> List[ComponentSpec]:
-        components = [
-            ComponentSpec("vae", AutoencoderKLQwenImage),
-        ]
+        components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
        return components
 
     @property
     def inputs(self) -> List[InputParam]:
-        inputs = [
-            InputParam(self._image_input_name, required=True),
-            InputParam("generator"),
-        ]
+        inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
         return inputs
 
     @property
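
The per-image loop above fits every reference image to roughly `vae_image_size` pixels while preserving its aspect ratio, and records the chosen `(vae_width, vae_height)` pairs for the RoPE step. A rough re-implementation of the helper, for illustration only; the real `calculate_dimensions` lives elsewhere in diffusers, and its exact rounding (assumed here: multiples of 32) and third return value may differ:

    import math

    def calculate_dimensions(target_area, ratio):
        width = math.sqrt(target_area * ratio)
        height = width / ratio
        # assumed: snap to multiples of 32 so the VAE and patchifier divide evenly
        width = round(width / 32) * 32
        height = round(height / 32) * 32
        return int(width), int(height), None  # third value unused by this step

    # A 1536x1024 input fitted to ~1024*1024 pixels keeps its 3:2 aspect ratio:
    print(calculate_dimensions(1024 * 1024, 1536 / 1024))  # (1248, 832, None)
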
diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py
index 2b229c040b89..61b9cd4e80c9 100644
--- a/src/diffusers/modular_pipelines/qwenimage/inputs.py
+++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -228,6 +228,7 @@ def __init__(
         self,
         image_latent_inputs: List[str] = ["image_latents"],
         additional_batch_inputs: List[str] = [],
+        reshape_to_seq_dim: bool = False,
     ):
         """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
 
@@ -245,6 +246,9 @@ def __init__(
                 Names of additional conditional input tensors to expand batch size. These tensors will only have
                 their batch dimensions adjusted to match the final batch size. Can be a single string or list of
                 strings. Defaults to []. Examples: ["processed_mask_image"]
+            reshape_to_seq_dim (bool, optional):
+                Whether the packed output should be reshaped along the sequence dimension. Example: `[2, 4096, 64]`
+                => `[1, 8192, 64]`. This is needed for QwenImage Edit Plus.
 
         Examples:
             # Configure to process image_latents (default behavior)
             QwenImageInputsDynamicStep()
@@ -263,6 +267,7 @@ def __init__(
 
         self._image_latent_inputs = image_latent_inputs
         self._additional_batch_inputs = additional_batch_inputs
+        self.reshape_to_seq_dim = reshape_to_seq_dim
         super().__init__()
 
     @property
@@ -341,6 +346,9 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
 
         # 2. Patchify the image latent tensor
         image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
+        if self.reshape_to_seq_dim:
+            channels = image_latent_tensor.shape[-1]
+            image_latent_tensor = image_latent_tensor.reshape(1, -1, channels)
 
         # 3. Expand batch size
         image_latent_tensor = repeat_tensor_to_batch_size(
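
The `reshape_to_seq_dim` flag folds latents that were packed per reference image (batch dimension = number of images) into one joint sequence, matching the docstring example. A minimal sketch:

    import torch

    packed = torch.randn(2, 4096, 64)          # two packed reference latents
    channels = packed.shape[-1]
    packed = packed.reshape(1, -1, channels)   # -> torch.Size([1, 8192, 64])

Patch 3 of this series removes the flag again in favor of packing each latent separately and concatenating along the sequence dimension.
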
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
index 83bfcb3da4fd..16e3931e2fd5 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -18,6 +18,7 @@
 from .before_denoise import (
     QwenImageControlNetBeforeDenoiserStep,
     QwenImageCreateMaskLatentsStep,
+    QwenImageEditPlusRoPEInputsStep,
     QwenImageEditRoPEInputsStep,
     QwenImagePrepareLatentsStep,
     QwenImagePrepareLatentsWithStrengthStep,
@@ -911,7 +912,7 @@ def description(self) -> str:
 
 
 class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks):
-    model_name = "qwenimage"
+    model_name = "qwenimage-edit-plus"
     block_classes = QwenImageEditPlusVaeEncoderBlocks.values()
     block_names = QwenImageEditPlusVaeEncoderBlocks.keys()
 
@@ -920,25 +921,62 @@ def description(self) -> str:
         return "Vae encoder step that encode the image inputs into their latent representations."
 
 
+#### QwenImage Edit Plus input blocks
+QwenImageEditPlusInputBlocks = InsertableDict(
+    [
+        ("text_inputs", QwenImageTextInputsStep()),  # default step to process text embeddings
+        (
+            "additional_inputs",
+            QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"], reshape_to_seq_dim=True),
+        ),
+    ]
+)
+
+
+class QwenImageEditPlusInputStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit-plus"
+    block_classes = QwenImageEditPlusInputBlocks.values()
+    block_names = QwenImageEditPlusInputBlocks.keys()
+
+
 #### QwenImage Edit Plus presets
 EDIT_PLUS_BLOCKS = InsertableDict(
     [
         ("text_encoder", QwenImageEditPlusVLEncoderStep()),
         ("vae_encoder", QwenImageEditPlusVaeEncoderStep()),
-        ("input", QwenImageEditInputStep()),
+        ("input", QwenImageEditPlusInputStep()),
         ("prepare_latents", QwenImagePrepareLatentsStep()),
         ("set_timesteps", QwenImageSetTimestepsStep()),
-        ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()),
+        ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
         ("denoise", QwenImageEditDenoiseStep()),
         ("decode", QwenImageDecodeStep()),
     ]
 )
 
 
+QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict(
+    [
+        ("prepare_latents", QwenImagePrepareLatentsStep()),
+        ("set_timesteps", QwenImageSetTimestepsStep()),
+        ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()),
+    ]
+)
+
+
+class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks):
+    model_name = "qwenimage-edit-plus"
+    block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values()
+    block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys()
+
+    @property
+    def description(self):
+        return "Before denoise step that prepares the inputs (timesteps, latents, RoPE inputs, etc.) for the denoise step of the edit task."
+
+
 # auto before_denoise step for edit tasks
 class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks):
     model_name = "qwenimage-edit-plus"
-    block_classes = [QwenImageEditBeforeDenoiseStep]
+    block_classes = [QwenImageEditPlusBeforeDenoiseStep]
     block_names = ["edit"]
     block_trigger_inputs = ["image_latents"]
 
@@ -947,7 +985,7 @@ def description(self):
         return (
             "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n"
             + "This is an auto pipeline block that works for edit (img2img) task.\n"
-            + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
+            + " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n"
             + " - if `image_latents` is not provided, step will be skipped."
         )
 
@@ -956,9 +994,7 @@ def description(self):
 
 
 class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks):
-    block_classes = [
-        QwenImageEditPlusVaeEncoderStep,
-    ]
+    block_classes = [QwenImageEditPlusVaeEncoderStep]
     block_names = ["edit"]
     block_trigger_inputs = ["image"]
 
@@ -975,10 +1011,25 @@ def description(self):
 ## 3.3 QwenImage-Edit/auto blocks & presets
 
 
+class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks):
+    block_classes = [QwenImageEditPlusInputStep]
+    block_names = ["edit"]
+    block_trigger_inputs = ["image_latents"]
+
+    @property
+    def description(self):
+        return (
+            "Input step that prepares the inputs for the edit denoising step.\n"
+            + " It is an auto pipeline block that works for the edit task.\n"
+            + " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n"
+            + " - if `image_latents` is not provided, step will be skipped."
+        )
+
+
 class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks):
     model_name = "qwenimage-edit-plus"
     block_classes = [
-        QwenImageEditAutoInputStep,
+        QwenImageEditPlusAutoInputStep,
         QwenImageEditPlusAutoBeforeDenoiseStep,
         QwenImageEditAutoDenoiseStep,
     ]
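
The auto blocks above dispatch on the presence of trigger inputs, per their descriptions: each entry in block_classes pairs with an entry in block_trigger_inputs, and the first block whose trigger is present runs. A toy model of that selection logic; this is an assumption about AutoPipelineBlocks' behavior inferred from the descriptions, not its actual code:

    def select_block(block_names, block_trigger_inputs, state):
        # pick the first sub-block whose trigger input is present in the state
        for name, trigger in zip(block_names, block_trigger_inputs):
            if state.get(trigger) is not None:
                return name
        return None  # no trigger present -> the step is skipped

    print(select_block(["edit"], ["image_latents"], {"image_latents": object()}))  # 'edit'
    print(select_block(["edit"], ["image_latents"], {}))                           # None
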
From e13e3e44753e028f5269579f197c010ea8231682 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Thu, 6 Nov 2025 19:25:22 +0530
Subject: [PATCH 2/3] up

---
 src/diffusers/modular_pipelines/qwenimage/before_denoise.py | 4 +---
 src/diffusers/modular_pipelines/qwenimage/encoders.py       | 3 +--
 2 files changed, 2 insertions(+), 5 deletions(-)

diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
index 788556634c89..86ae32c32705 100644
--- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -645,15 +645,13 @@ class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
     @property
     def inputs(self) -> List[InputParam]:
         existing_inputs = super().inputs
-        current_inputs = [InputParam("vae_image_sizes", type_hint=List[Tuple[int, int]])]
+        current_inputs = [InputParam("vae_image_sizes", type_hint=List[Tuple[int, int]], required=True)]
         return existing_inputs + current_inputs
 
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
-        print(f"{block_state=}")
 
         vae_scale_factor = components.vae_scale_factor
-        print(f"{block_state.vae_image_sizes=}")
         block_state.img_shapes = [
             [
                 (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index 7df73badf82a..1eb07e2d8d73 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -330,7 +330,7 @@ def __init__(
         output_name: str = "resized_image",
         vae_image_output_name: str = "vae_image",
     ):
-        """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
+        """Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.
 
         This block resizes an input image or a list input images and exposes the resized result under configurable
         input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
@@ -893,7 +893,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         block_state.processed_image = torch.stack(processed_images, dim=0).squeeze(1)
 
         block_state.vae_image_sizes = vae_image_sizes
-        print(f"{block_state.vae_image_sizes=}")
 
         self.set_block_state(state, block_state)
         return components, state
From 8f623826986ecb1dcc9ae34cfa4e0a9085a32397 Mon Sep 17 00:00:00 2001
From: sayakpaul
Date: Mon, 10 Nov 2025 15:52:47 +0530
Subject: [PATCH 3/3] up

---
 .../qwenimage/before_denoise.py               |  8 +-
 .../modular_pipelines/qwenimage/encoders.py   | 55 ++++++++++--
 .../modular_pipelines/qwenimage/inputs.py     | 84 ++++++++++++++++---
 .../qwenimage/modular_blocks.py               | 12 ++-
 4 files changed, 128 insertions(+), 31 deletions(-)

diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
index 86ae32c32705..2ef67e79866f 100644
--- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
+++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -642,12 +642,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
 class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
     model_name = "qwenimage-edit-plus"
 
-    @property
-    def inputs(self) -> List[InputParam]:
-        existing_inputs = super().inputs
-        current_inputs = [InputParam("vae_image_sizes", type_hint=List[Tuple[int, int]], required=True)]
-        return existing_inputs + current_inputs
-
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
         block_state = self.get_block_state(state)
 
@@ -657,7 +651,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
             (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
             *[
                 (1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
-                for vae_width, vae_height in block_state.vae_image_sizes
+                for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
             ],
         ]
     ] * block_state.batch_size
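
With `vae_image_sizes` gone, the RoPE step now reads the parallel `image_height`/`image_width` lists that the Edit Plus inputs step records (see the inputs.py changes below). A small sketch of the resulting reference shapes, with assumed values:

    image_height = [1568, 1024]  # hypothetical per-reference pixel heights
    image_width = [672, 1024]
    vae_scale_factor = 8

    ref_shapes = [
        (1, h // vae_scale_factor // 2, w // vae_scale_factor // 2)
        for h, w in zip(image_height, image_width)
    ]
    # -> [(1, 98, 42), (1, 64, 64)]
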
diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py
index 1eb07e2d8d73..278c31d83f27 100644
--- a/src/diffusers/modular_pipelines/qwenimage/encoders.py
+++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Dict, List, Optional, Union
 
 import PIL
 import torch
@@ -856,12 +856,6 @@ def description(self) -> str:
     def inputs(self) -> List[InputParam]:
         return [InputParam("vae_image"), InputParam("image"), InputParam("height"), InputParam("width")]
 
-    @property
-    def intermediate_outputs(self):
-        existing_outputs = super().intermediate_outputs
-        current_outputs = [OutputParam("vae_image_sizes", type_hint=List[Tuple[int, int]])]
-        return existing_outputs + current_outputs
-
     @torch.no_grad()
     def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
         block_state = self.get_block_state(state)
@@ -881,6 +875,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
                 image=image, height=height, width=width
             )
         else:
+            # QwenImage Edit Plus allows multiple input images with varied resolutions
             processed_images = []
             vae_image_sizes = []
             for img in block_state.vae_image:
@@ -890,7 +885,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
                 processed_images.append(
                     components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
                 )
-            block_state.processed_image = torch.stack(processed_images, dim=0).squeeze(1)
+            block_state.processed_image = processed_images
 
         block_state.vae_image_sizes = vae_image_sizes
 
@@ -977,6 +972,50 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
     return components, state
 
 
+class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
+    model_name = "qwenimage-edit-plus"
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        # Each reference image latent can have a different resolution, hence we return them as a list.
+        return [
+            OutputParam(
+                self._image_latents_output_name,
+                type_hint=List[torch.Tensor],
+                description="The latents representing the reference image(s).",
+            )
+        ]
+
+    @torch.no_grad()
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        device = components._execution_device
+        dtype = components.vae.dtype
+
+        image = getattr(block_state, self._image_input_name)
+
+        # Encode each image into its latents
+        image_latents = []
+        for img in image:
+            image_latents.append(
+                encode_vae_image(
+                    image=img,
+                    vae=components.vae,
+                    generator=block_state.generator,
+                    device=device,
+                    dtype=dtype,
+                    latent_channels=components.num_channels_latents,
+                )
+            )
+
+        setattr(block_state, self._image_latents_output_name, image_latents)
+
+        self.set_block_state(state, block_state)
+
+        return components, state
+
+
 class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
     model_name = "qwenimage"
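
Because each reference image is now preprocessed at its own aspect-preserving resolution, the per-image latents generally have different spatial shapes and cannot be stacked into one tensor; a plain list is the natural container. A minimal illustration with hypothetical shapes:

    import torch

    lat_a = torch.randn(1, 16, 1, 104, 78)  # (B, C, T, H, W) for reference A
    lat_b = torch.randn(1, 16, 1, 64, 64)   # reference B at another resolution

    try:
        torch.stack([lat_a, lat_b])          # what a single-tensor design would need
    except RuntimeError as err:
        print(err)                           # stack expects each tensor to be equal size

    image_latents = [lat_a, lat_b]           # keep them separate instead
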
diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py
index 61b9cd4e80c9..6e656e484847 100644
--- a/src/diffusers/modular_pipelines/qwenimage/inputs.py
+++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -224,12 +224,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
 class QwenImageInputsDynamicStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 
-    def __init__(
-        self,
-        image_latent_inputs: List[str] = ["image_latents"],
-        additional_batch_inputs: List[str] = [],
-        reshape_to_seq_dim: bool = False,
-    ):
+    def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
         """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
 
         This step handles multiple common tasks to prepare inputs for the denoising step:
@@ -246,9 +241,6 @@ def __init__(
                 Names of additional conditional input tensors to expand batch size. These tensors will only have
                 their batch dimensions adjusted to match the final batch size. Can be a single string or list of
                 strings. Defaults to []. Examples: ["processed_mask_image"]
-            reshape_to_seq_dim (bool, optional):
-                Whether the packed output should be reshaped along the sequence dimension. Example: `[2, 4096, 64]`
-                => `[1, 8192, 64]`. This is needed for QwenImage Edit Plus.
 
         Examples:
             # Configure to process image_latents (default behavior)
             QwenImageInputsDynamicStep()
@@ -267,7 +259,6 @@ def __init__(
 
         self._image_latent_inputs = image_latent_inputs
         self._additional_batch_inputs = additional_batch_inputs
-        self.reshape_to_seq_dim = reshape_to_seq_dim
         super().__init__()
 
     @property
@@ -346,9 +337,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
 
         # 2. Patchify the image latent tensor
         image_latent_tensor = components.pachifier.pack_latents(image_latent_tensor)
-        if self.reshape_to_seq_dim:
-            channels = image_latent_tensor.shape[-1]
-            image_latent_tensor = image_latent_tensor.reshape(1, -1, channels)
 
         # 3. Expand batch size
         image_latent_tensor = repeat_tensor_to_batch_size(
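
The replacement step below packs each per-image latent on its own and then concatenates the packed sequences along dim=1, so references of different resolutions end up in one joint sequence. A rough sketch of the idea; the `pack_latents` helper here only approximates the real pachifier:

    import torch

    def pack_latents(x):
        # fold 2x2 latent patches into channels: (B, C, T, H, W) -> (B, T*(H/2)*(W/2), C*4)
        b, c, t, h, w = x.shape
        x = x.view(b, c, t, h // 2, 2, w // 2, 2)
        x = x.permute(0, 2, 3, 5, 1, 4, 6)
        return x.reshape(b, t * (h // 2) * (w // 2), c * 4)

    lat_a = torch.randn(1, 16, 1, 104, 78)               # hypothetical reference A
    lat_b = torch.randn(1, 16, 1, 64, 64)                # hypothetical reference B
    packed = [pack_latents(lat_a), pack_latents(lat_b)]  # (1, 2028, 64) and (1, 1024, 64)
    image_latents = torch.cat(packed, dim=1)             # -> (1, 3052, 64), one joint sequence
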
@@ -380,6 +368,76 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
     return components, state
 
 
+class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
+    model_name = "qwenimage-edit-plus"
+
+    @property
+    def intermediate_outputs(self) -> List[OutputParam]:
+        return [
+            OutputParam(name="image_height", type_hint=List[int], description="The heights of the image latents"),
+            OutputParam(name="image_width", type_hint=List[int], description="The widths of the image latents"),
+        ]
+
+    def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
+        block_state = self.get_block_state(state)
+
+        # Process image latent inputs (height/width calculation, patchify, and batch expansion)
+        for image_latent_input_name in self._image_latent_inputs:
+            image_latent_tensor = getattr(block_state, image_latent_input_name)
+            if image_latent_tensor is None:
+                continue
+
+            # Each image latent can have a different size in QwenImage Edit Plus.
+            image_heights = []
+            image_widths = []
+            packed_image_latent_tensors = []
+
+            for img_latent_tensor in image_latent_tensor:
+                # 1. Calculate height/width from latents
+                height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
+                image_heights.append(height)
+                image_widths.append(width)
+
+                # 2. Patchify the image latent tensor
+                img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)
+
+                # 3. Expand batch size
+                img_latent_tensor = repeat_tensor_to_batch_size(
+                    input_name=image_latent_input_name,
+                    input_tensor=img_latent_tensor,
+                    num_images_per_prompt=block_state.num_images_per_prompt,
+                    batch_size=block_state.batch_size,
+                )
+                packed_image_latent_tensors.append(img_latent_tensor)
+
+            packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
+            block_state.image_height = image_heights
+            block_state.image_width = image_widths
+            setattr(block_state, image_latent_input_name, packed_image_latent_tensors)
+
+            block_state.height = block_state.height or image_heights[-1]
+            block_state.width = block_state.width or image_widths[-1]
+
+        # Process additional batch inputs (only batch expansion)
+        for input_name in self._additional_batch_inputs:
+            input_tensor = getattr(block_state, input_name)
+            if input_tensor is None:
+                continue
+
+            # Only expand batch size
+            input_tensor = repeat_tensor_to_batch_size(
+                input_name=input_name,
+                input_tensor=input_tensor,
+                num_images_per_prompt=block_state.num_images_per_prompt,
+                batch_size=block_state.batch_size,
+            )
+
+            setattr(block_state, input_name, input_tensor)
+
+        self.set_block_state(state, block_state)
+        return components, state
+
+
 class QwenImageControlNetInputsStep(ModularPipelineBlocks):
     model_name = "qwenimage"
 
diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
index 16e3931e2fd5..a35106d50f4c 100644
--- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
+++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py
@@ -41,6 +41,7 @@
     QwenImageEditPlusProcessImagesInputStep,
     QwenImageEditPlusResizeDynamicStep,
     QwenImageEditPlusTextEncoderStep,
+    QwenImageEditPlusVaeEncoderDynamicStep,
     QwenImageEditResizeDynamicStep,
     QwenImageEditTextEncoderStep,
     QwenImageInpaintProcessImagesInputStep,
@@ -48,7 +49,12 @@
     QwenImageTextEncoderStep,
     QwenImageVaeEncoderDynamicStep,
 )
-from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep
+from .inputs import (
+    QwenImageControlNetInputsStep,
+    QwenImageEditPlusInputsDynamicStep,
+    QwenImageInputsDynamicStep,
+    QwenImageTextInputsStep,
+)
 
 
 logger = logging.get_logger(__name__)
@@ -906,7 +912,7 @@ def description(self) -> str:
     [
         ("resize", QwenImageEditPlusResizeDynamicStep()),  # edit plus has a different resize step
         ("preprocess", QwenImageEditPlusProcessImagesInputStep()),  # vae_image -> processed_image
-        ("encode", QwenImageVaeEncoderDynamicStep()),  # processed_image -> image_latents
+        ("encode", QwenImageEditPlusVaeEncoderDynamicStep()),  # processed_image -> image_latents
     ]
 )
 
@@ -927,7 +933,7 @@ def description(self) -> str:
         ("text_inputs", QwenImageTextInputsStep()),  # default step to process text embeddings
         (
             "additional_inputs",
-            QwenImageInputsDynamicStep(image_latent_inputs=["image_latents"], reshape_to_seq_dim=True),
+            QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]),
         ),
     ]
 )
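
Taken together, the series swaps the Edit Plus preset's input, VAE-encode, and RoPE steps for variants that carry one latent per reference image. A hedged usage sketch of assembling the preset; the entry points below (from_blocks_dict in particular) are assumptions based on how modular pipelines are wired elsewhere in diffusers, not verified against this branch:

    # Hypothetical wiring; consult the modular pipelines docs for the real API.
    from diffusers.modular_pipelines import SequentialPipelineBlocks
    from diffusers.modular_pipelines.qwenimage.modular_blocks import EDIT_PLUS_BLOCKS

    blocks = SequentialPipelineBlocks.from_blocks_dict(EDIT_PLUS_BLOCKS)
    print(blocks)  # text_encoder -> vae_encoder -> input -> ... -> denoise -> decode
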