diff --git a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py index fdec95dc506e..2ef67e79866f 100644 --- a/src/diffusers/modular_pipelines/qwenimage/before_denoise.py +++ b/src/diffusers/modular_pipelines/qwenimage/before_denoise.py @@ -610,7 +610,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - block_state = self.get_block_state(state) # for edit, image size can be different from the target size (height/width) - block_state.img_shapes = [ [ ( @@ -640,6 +639,37 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep): + model_name = "qwenimage-edit-plus" + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + vae_scale_factor = components.vae_scale_factor + block_state.img_shapes = [ + [ + (1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2), + *[ + (1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2) + for vae_height, vae_width in zip(block_state.image_height, block_state.image_width) + ], + ] + ] * block_state.batch_size + + block_state.txt_seq_lens = ( + block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None + ) + block_state.negative_txt_seq_lens = ( + block_state.negative_prompt_embeds_mask.sum(dim=1).tolist() + if block_state.negative_prompt_embeds_mask is not None + else None + ) + + self.set_block_state(state, block_state) + + return components, state + + ## ControlNet inputs for denoiser class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks): model_name = "qwenimage" diff --git a/src/diffusers/modular_pipelines/qwenimage/encoders.py b/src/diffusers/modular_pipelines/qwenimage/encoders.py index 04fb3fdc947b..278c31d83f27 100644 --- a/src/diffusers/modular_pipelines/qwenimage/encoders.py +++ b/src/diffusers/modular_pipelines/qwenimage/encoders.py @@ -330,7 +330,7 @@ def __init__( output_name: str = "resized_image", vae_image_output_name: str = "vae_image", ): - """Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio. + """Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio. This block resizes an input image or a list input images and exposes the resized result under configurable input and output names. Use this when you need to wire the resize step to different image fields (e.g., @@ -803,9 +803,7 @@ def inputs(self) -> List[InputParam]: @property def intermediate_outputs(self) -> List[OutputParam]: - return [ - OutputParam(name="processed_image"), - ] + return [OutputParam(name="processed_image")] @staticmethod def check_inputs(height, width, vae_scale_factor): @@ -845,7 +843,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep): model_name = "qwenimage-edit-plus" - vae_image_size = 1024 * 1024 + + def __init__(self): + self.vae_image_size = 1024 * 1024 + super().__init__() @property def description(self) -> str: @@ -862,6 +863,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): if block_state.vae_image is None and block_state.image is None: raise ValueError("`vae_image` and `image` cannot be None at the same time") + vae_image_sizes = None if block_state.vae_image is None: image = block_state.image self.check_inputs( @@ -873,12 +875,19 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState): image=image, height=height, width=width ) else: - width, height = block_state.vae_image[0].size - image = block_state.vae_image + # QwenImage Edit Plus can allow multiple input images with varied resolutions + processed_images = [] + vae_image_sizes = [] + for img in block_state.vae_image: + width, height = img.size + vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height) + vae_image_sizes.append((vae_width, vae_height)) + processed_images.append( + components.image_processor.preprocess(image=img, height=vae_height, width=vae_width) + ) + block_state.processed_image = processed_images - block_state.processed_image = components.image_processor.preprocess( - image=image, height=height, width=width - ) + block_state.vae_image_sizes = vae_image_sizes self.set_block_state(state, block_state) return components, state @@ -920,17 +929,12 @@ def description(self) -> str: @property def expected_components(self) -> List[ComponentSpec]: - components = [ - ComponentSpec("vae", AutoencoderKLQwenImage), - ] + components = [ComponentSpec("vae", AutoencoderKLQwenImage)] return components @property def inputs(self) -> List[InputParam]: - inputs = [ - InputParam(self._image_input_name, required=True), - InputParam("generator"), - ] + inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")] return inputs @property @@ -968,6 +972,50 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep): + model_name = "qwenimage-edit-plus" + + @property + def intermediate_outputs(self) -> List[OutputParam]: + # Each reference image latent can have varied resolutions hence we return this as a list. + return [ + OutputParam( + self._image_latents_output_name, + type_hint=List[torch.Tensor], + description="The latents representing the reference image(s).", + ) + ] + + @torch.no_grad() + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + device = components._execution_device + dtype = components.vae.dtype + + image = getattr(block_state, self._image_input_name) + + # Encode image into latents + image_latents = [] + for img in image: + image_latents.append( + encode_vae_image( + image=img, + vae=components.vae, + generator=block_state.generator, + device=device, + dtype=dtype, + latent_channels=components.num_channels_latents, + ) + ) + + setattr(block_state, self._image_latents_output_name, image_latents) + + self.set_block_state(state, block_state) + + return components, state + + class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks): model_name = "qwenimage" diff --git a/src/diffusers/modular_pipelines/qwenimage/inputs.py b/src/diffusers/modular_pipelines/qwenimage/inputs.py index 2b229c040b89..6e656e484847 100644 --- a/src/diffusers/modular_pipelines/qwenimage/inputs.py +++ b/src/diffusers/modular_pipelines/qwenimage/inputs.py @@ -224,11 +224,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - class QwenImageInputsDynamicStep(ModularPipelineBlocks): model_name = "qwenimage" - def __init__( - self, - image_latent_inputs: List[str] = ["image_latents"], - additional_batch_inputs: List[str] = [], - ): + def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []): """Initialize a configurable step that standardizes the inputs for the denoising step. It:\n" This step handles multiple common tasks to prepare inputs for the denoising step: @@ -372,6 +368,76 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) - return components, state +class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep): + model_name = "qwenimage-edit-plus" + + @property + def intermediate_outputs(self) -> List[OutputParam]: + return [ + OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"), + OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"), + ] + + def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState: + block_state = self.get_block_state(state) + + # Process image latent inputs (height/width calculation, patchify, and batch expansion) + for image_latent_input_name in self._image_latent_inputs: + image_latent_tensor = getattr(block_state, image_latent_input_name) + if image_latent_tensor is None: + continue + + # Each image latent can have different size in QwenImage Edit Plus. + image_heights = [] + image_widths = [] + packed_image_latent_tensors = [] + + for img_latent_tensor in image_latent_tensor: + # 1. Calculate height/width from latents + height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor) + image_heights.append(height) + image_widths.append(width) + + # 2. Patchify the image latent tensor + img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor) + + # 3. Expand batch size + img_latent_tensor = repeat_tensor_to_batch_size( + input_name=image_latent_input_name, + input_tensor=img_latent_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + packed_image_latent_tensors.append(img_latent_tensor) + + packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1) + block_state.image_height = image_heights + block_state.image_width = image_widths + setattr(block_state, image_latent_input_name, packed_image_latent_tensors) + + block_state.height = block_state.height or image_heights[-1] + block_state.width = block_state.width or image_widths[-1] + + # Process additional batch inputs (only batch expansion) + for input_name in self._additional_batch_inputs: + input_tensor = getattr(block_state, input_name) + if input_tensor is None: + continue + + # Only expand batch size + input_tensor = repeat_tensor_to_batch_size( + input_name=input_name, + input_tensor=input_tensor, + num_images_per_prompt=block_state.num_images_per_prompt, + batch_size=block_state.batch_size, + ) + + setattr(block_state, input_name, input_tensor) + + self.set_block_state(state, block_state) + return components, state + + class QwenImageControlNetInputsStep(ModularPipelineBlocks): model_name = "qwenimage" diff --git a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py index 83bfcb3da4fd..a35106d50f4c 100644 --- a/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py +++ b/src/diffusers/modular_pipelines/qwenimage/modular_blocks.py @@ -18,6 +18,7 @@ from .before_denoise import ( QwenImageControlNetBeforeDenoiserStep, QwenImageCreateMaskLatentsStep, + QwenImageEditPlusRoPEInputsStep, QwenImageEditRoPEInputsStep, QwenImagePrepareLatentsStep, QwenImagePrepareLatentsWithStrengthStep, @@ -40,6 +41,7 @@ QwenImageEditPlusProcessImagesInputStep, QwenImageEditPlusResizeDynamicStep, QwenImageEditPlusTextEncoderStep, + QwenImageEditPlusVaeEncoderDynamicStep, QwenImageEditResizeDynamicStep, QwenImageEditTextEncoderStep, QwenImageInpaintProcessImagesInputStep, @@ -47,7 +49,12 @@ QwenImageTextEncoderStep, QwenImageVaeEncoderDynamicStep, ) -from .inputs import QwenImageControlNetInputsStep, QwenImageInputsDynamicStep, QwenImageTextInputsStep +from .inputs import ( + QwenImageControlNetInputsStep, + QwenImageEditPlusInputsDynamicStep, + QwenImageInputsDynamicStep, + QwenImageTextInputsStep, +) logger = logging.get_logger(__name__) @@ -905,13 +912,13 @@ def description(self) -> str: [ ("resize", QwenImageEditPlusResizeDynamicStep()), # edit plus has a different resize step ("preprocess", QwenImageEditPlusProcessImagesInputStep()), # vae_image -> processed_image - ("encode", QwenImageVaeEncoderDynamicStep()), # processed_image -> image_latents + ("encode", QwenImageEditPlusVaeEncoderDynamicStep()), # processed_image -> image_latents ] ) class QwenImageEditPlusVaeEncoderStep(SequentialPipelineBlocks): - model_name = "qwenimage" + model_name = "qwenimage-edit-plus" block_classes = QwenImageEditPlusVaeEncoderBlocks.values() block_names = QwenImageEditPlusVaeEncoderBlocks.keys() @@ -920,25 +927,62 @@ def description(self) -> str: return "Vae encoder step that encode the image inputs into their latent representations." +#### QwenImage Edit Plus input blocks +QwenImageEditPlusInputBlocks = InsertableDict( + [ + ("text_inputs", QwenImageTextInputsStep()), # default step to process text embeddings + ( + "additional_inputs", + QwenImageEditPlusInputsDynamicStep(image_latent_inputs=["image_latents"]), + ), + ] +) + + +class QwenImageEditPlusInputStep(SequentialPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = QwenImageEditPlusInputBlocks.values() + block_names = QwenImageEditPlusInputBlocks.keys() + + #### QwenImage Edit Plus presets EDIT_PLUS_BLOCKS = InsertableDict( [ ("text_encoder", QwenImageEditPlusVLEncoderStep()), ("vae_encoder", QwenImageEditPlusVaeEncoderStep()), - ("input", QwenImageEditInputStep()), + ("input", QwenImageEditPlusInputStep()), ("prepare_latents", QwenImagePrepareLatentsStep()), ("set_timesteps", QwenImageSetTimestepsStep()), - ("prepare_rope_inputs", QwenImageEditRoPEInputsStep()), + ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), ("denoise", QwenImageEditDenoiseStep()), ("decode", QwenImageDecodeStep()), ] ) +QwenImageEditPlusBeforeDenoiseBlocks = InsertableDict( + [ + ("prepare_latents", QwenImagePrepareLatentsStep()), + ("set_timesteps", QwenImageSetTimestepsStep()), + ("prepare_rope_inputs", QwenImageEditPlusRoPEInputsStep()), + ] +) + + +class QwenImageEditPlusBeforeDenoiseStep(SequentialPipelineBlocks): + model_name = "qwenimage-edit-plus" + block_classes = QwenImageEditPlusBeforeDenoiseBlocks.values() + block_names = QwenImageEditPlusBeforeDenoiseBlocks.keys() + + @property + def description(self): + return "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step for edit task." + + # auto before_denoise step for edit tasks class QwenImageEditPlusAutoBeforeDenoiseStep(AutoPipelineBlocks): model_name = "qwenimage-edit-plus" - block_classes = [QwenImageEditBeforeDenoiseStep] + block_classes = [QwenImageEditPlusBeforeDenoiseStep] block_names = ["edit"] block_trigger_inputs = ["image_latents"] @@ -947,7 +991,7 @@ def description(self): return ( "Before denoise step that prepare the inputs (timesteps, latents, rope inputs etc.) for the denoise step.\n" + "This is an auto pipeline block that works for edit (img2img) task.\n" - + " - `QwenImageEditBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" + + " - `QwenImageEditPlusBeforeDenoiseStep` (edit) is used when `image_latents` is provided and `processed_mask_image` is not provided.\n" + " - if `image_latents` is not provided, step will be skipped." ) @@ -956,9 +1000,7 @@ def description(self): class QwenImageEditPlusAutoVaeEncoderStep(AutoPipelineBlocks): - block_classes = [ - QwenImageEditPlusVaeEncoderStep, - ] + block_classes = [QwenImageEditPlusVaeEncoderStep] block_names = ["edit"] block_trigger_inputs = ["image"] @@ -975,10 +1017,25 @@ def description(self): ## 3.3 QwenImage-Edit/auto blocks & presets +class QwenImageEditPlusAutoInputStep(AutoPipelineBlocks): + block_classes = [QwenImageEditPlusInputStep] + block_names = ["edit"] + block_trigger_inputs = ["image_latents"] + + @property + def description(self): + return ( + "Input step that prepares the inputs for the edit denoising step.\n" + + " It is an auto pipeline block that works for edit task.\n" + + " - `QwenImageEditPlusInputStep` (edit) is used when `image_latents` is provided.\n" + + " - if `image_latents` is not provided, step will be skipped." + ) + + class QwenImageEditPlusCoreDenoiseStep(SequentialPipelineBlocks): model_name = "qwenimage-edit-plus" block_classes = [ - QwenImageEditAutoInputStep, + QwenImageEditPlusAutoInputStep, QwenImageEditPlusAutoBeforeDenoiseStep, QwenImageEditAutoDenoiseStep, ]