32 changes: 31 additions & 1 deletion src/diffusers/modular_pipelines/qwenimage/before_denoise.py
@@ -610,7 +610,6 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
block_state = self.get_block_state(state)

# for edit, image size can be different from the target size (height/width)

block_state.img_shapes = [
[
(
@@ -640,6 +639,37 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state


class QwenImageEditPlusRoPEInputsStep(QwenImageEditRoPEInputsStep):
model_name = "qwenimage-edit-plus"

def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

vae_scale_factor = components.vae_scale_factor
block_state.img_shapes = [
[
(1, block_state.height // vae_scale_factor // 2, block_state.width // vae_scale_factor // 2),
*[
(1, vae_height // vae_scale_factor // 2, vae_width // vae_scale_factor // 2)
for vae_height, vae_width in zip(block_state.image_height, block_state.image_width)
],
]
] * block_state.batch_size

block_state.txt_seq_lens = (
block_state.prompt_embeds_mask.sum(dim=1).tolist() if block_state.prompt_embeds_mask is not None else None
)
block_state.negative_txt_seq_lens = (
block_state.negative_prompt_embeds_mask.sum(dim=1).tolist()
if block_state.negative_prompt_embeds_mask is not None
else None
)

self.set_block_state(state, block_state)

return components, state
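
For orientation, here is a minimal sketch of the layout this step builds, assuming `vae_scale_factor == 8` and the 2x2 latent patchification implied by the extra `// 2`. The helper below is hypothetical, not part of the PR:

```python
# Hypothetical sketch (not part of the PR) of the img_shapes layout above:
# per batch entry, the target latent grid first, then one grid per reference image.
def sketch_img_shapes(height, width, ref_sizes, batch_size, vae_scale_factor=8):
    return [
        [(1, height // vae_scale_factor // 2, width // vae_scale_factor // 2)]
        + [(1, h // vae_scale_factor // 2, w // vae_scale_factor // 2) for h, w in ref_sizes]
    ] * batch_size

# Example: a 1024x1024 target plus a single 1536x1024 reference image.
assert sketch_img_shapes(1024, 1024, [(1536, 1024)], batch_size=1) == [[(1, 64, 64), (1, 96, 64)]]
```

`txt_seq_lens` and `negative_txt_seq_lens` are simply the per-prompt counts of unmasked text tokens (`mask.sum(dim=1)`), so the rotary embeddings are sized to each prompt's valid text length.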


## ControlNet inputs for denoiser
class QwenImageControlNetBeforeDenoiserStep(ModularPipelineBlocks):
model_name = "qwenimage"
82 changes: 65 additions & 17 deletions src/diffusers/modular_pipelines/qwenimage/encoders.py
@@ -330,7 +330,7 @@ def __init__(
output_name: str = "resized_image",
vae_image_output_name: str = "vae_image",
):
"""Create a configurable step for resizing images to the target area (1024 * 1024) while maintaining the aspect ratio.
"""Create a configurable step for resizing images to the target area (384 * 384) while maintaining the aspect ratio.

This block resizes an input image or a list of input images and exposes the resized result under configurable
input and output names. Use this when you need to wire the resize step to different image fields (e.g.,
@@ -803,9 +803,7 @@ def inputs(self) -> List[InputParam]:

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="processed_image"),
]
return [OutputParam(name="processed_image")]

@staticmethod
def check_inputs(height, width, vae_scale_factor):
@@ -845,7 +843,10 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):

class QwenImageEditPlusProcessImagesInputStep(QwenImageProcessImagesInputStep):
model_name = "qwenimage-edit-plus"
vae_image_size = 1024 * 1024

def __init__(self):
self.vae_image_size = 1024 * 1024
super().__init__()

@property
def description(self) -> str:
@@ -862,6 +863,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
if block_state.vae_image is None and block_state.image is None:
raise ValueError("`vae_image` and `image` cannot be None at the same time")

vae_image_sizes = None
if block_state.vae_image is None:
image = block_state.image
self.check_inputs(
@@ -873,12 +875,19 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState):
image=image, height=height, width=width
)
else:
width, height = block_state.vae_image[0].size
image = block_state.vae_image
# QwenImage Edit Plus allows multiple input images with varied resolutions
processed_images = []
vae_image_sizes = []
for img in block_state.vae_image:
width, height = img.size
vae_width, vae_height, _ = calculate_dimensions(self.vae_image_size, width / height)
vae_image_sizes.append((vae_width, vae_height))
processed_images.append(
components.image_processor.preprocess(image=img, height=vae_height, width=vae_width)
)
block_state.processed_image = processed_images

block_state.processed_image = components.image_processor.preprocess(
image=image, height=height, width=width
)
block_state.vae_image_sizes = vae_image_sizes

self.set_block_state(state, block_state)
return components, state
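
For context on the sizing above: each reference image is mapped to its own VAE target size by `calculate_dimensions`. The stand-in below only illustrates the assumed behavior (aspect-ratio-preserving resize to roughly `target_area` pixels, snapped to a VAE-friendly multiple); it is not the diffusers implementation, which also returns a third value:

```python
import math

# Illustrative stand-in for calculate_dimensions (assumed behavior, not the
# diffusers implementation): keep the aspect ratio, land near target_area
# pixels, and round to a multiple of 16 (vae_scale_factor * 2).
def sketch_calculate_dimensions(target_area, aspect_ratio, multiple=16):
    width = math.sqrt(target_area * aspect_ratio)
    height = width / aspect_ratio
    return round(width / multiple) * multiple, round(height / multiple) * multiple

# A 1536x1024 reference (ratio 1.5) against the 1024 * 1024 budget.
print(sketch_calculate_dimensions(1024 * 1024, 1536 / 1024))  # (1248, 832)
```

Recording one `(vae_width, vae_height)` per image in `vae_image_sizes` is what lets downstream steps handle references with different resolutions.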
@@ -920,17 +929,12 @@ def description(self) -> str:

@property
def expected_components(self) -> List[ComponentSpec]:
components = [
ComponentSpec("vae", AutoencoderKLQwenImage),
]
components = [ComponentSpec("vae", AutoencoderKLQwenImage)]
return components

@property
def inputs(self) -> List[InputParam]:
inputs = [
InputParam(self._image_input_name, required=True),
InputParam("generator"),
]
inputs = [InputParam(self._image_input_name, required=True), InputParam("generator")]
return inputs

@property
@@ -968,6 +972,50 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state


class QwenImageEditPlusVaeEncoderDynamicStep(QwenImageVaeEncoderDynamicStep):
model_name = "qwenimage-edit-plus"

@property
def intermediate_outputs(self) -> List[OutputParam]:
# Each reference image latent can have a different resolution, hence we return a list.
return [
OutputParam(
self._image_latents_output_name,
type_hint=List[torch.Tensor],
description="The latents representing the reference image(s).",
)
]

@torch.no_grad()
def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

device = components._execution_device
dtype = components.vae.dtype

image = getattr(block_state, self._image_input_name)

# Encode image into latents
image_latents = []
for img in image:
image_latents.append(
encode_vae_image(
image=img,
vae=components.vae,
generator=block_state.generator,
device=device,
dtype=dtype,
latent_channels=components.num_channels_latents,
)
)

setattr(block_state, self._image_latents_output_name, image_latents)

self.set_block_state(state, block_state)

return components, state
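
Since the reference latents can differ in shape, the loop above keeps them in a Python list instead of stacking them into one tensor. As a rough sketch of what each `encode_vae_image` call is assumed to do (names and the normalization details are illustrative, not copied from the helper):

```python
import torch

# Hypothetical sketch of a single encode_vae_image call (assumed behavior):
# encode one preprocessed image and normalize with the VAE's channel stats.
@torch.no_grad()
def sketch_encode_vae_image(image, vae, generator, device, dtype, latent_channels):
    image = image.to(device=device, dtype=dtype)
    latents = vae.encode(image).latent_dist.sample(generator=generator)
    latents = latents[:, :latent_channels]
    # Assumed per-channel normalization, as exposed by AutoencoderKLQwenImage.
    mean = torch.tensor(vae.config.latents_mean).view(1, -1, 1, 1, 1).to(device, dtype)
    std = torch.tensor(vae.config.latents_std).view(1, -1, 1, 1, 1).to(device, dtype)
    return (latents - mean) / std
```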


class QwenImageControlNetVaeEncoderStep(ModularPipelineBlocks):
model_name = "qwenimage"

76 changes: 71 additions & 5 deletions src/diffusers/modular_pipelines/qwenimage/inputs.py
@@ -224,11 +224,7 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
class QwenImageInputsDynamicStep(ModularPipelineBlocks):
model_name = "qwenimage"

def __init__(
self,
image_latent_inputs: List[str] = ["image_latents"],
additional_batch_inputs: List[str] = [],
):
def __init__(self, image_latent_inputs: List[str] = ["image_latents"], additional_batch_inputs: List[str] = []):
"""Initialize a configurable step that standardizes the inputs for the denoising step. It:\n"
This step handles multiple common tasks to prepare inputs for the denoising step:
@@ -372,6 +368,76 @@ def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -
return components, state


class QwenImageEditPlusInputsDynamicStep(QwenImageInputsDynamicStep):
model_name = "qwenimage-edit-plus"

@property
def intermediate_outputs(self) -> List[OutputParam]:
return [
OutputParam(name="image_height", type_hint=List[int], description="The height of the image latents"),
OutputParam(name="image_width", type_hint=List[int], description="The width of the image latents"),
]

def __call__(self, components: QwenImageModularPipeline, state: PipelineState) -> PipelineState:
block_state = self.get_block_state(state)

# Process image latent inputs (height/width calculation, patchify, and batch expansion)
for image_latent_input_name in self._image_latent_inputs:
image_latent_tensor = getattr(block_state, image_latent_input_name)
if image_latent_tensor is None:
continue

# Each image latent can have a different size in QwenImage Edit Plus.
image_heights = []
image_widths = []
packed_image_latent_tensors = []

for img_latent_tensor in image_latent_tensor:
# 1. Calculate height/width from latents
height, width = calculate_dimension_from_latents(img_latent_tensor, components.vae_scale_factor)
image_heights.append(height)
image_widths.append(width)

# 2. Patchify the image latent tensor
img_latent_tensor = components.pachifier.pack_latents(img_latent_tensor)

# 3. Expand batch size
img_latent_tensor = repeat_tensor_to_batch_size(
input_name=image_latent_input_name,
input_tensor=img_latent_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)
packed_image_latent_tensors.append(img_latent_tensor)

packed_image_latent_tensors = torch.cat(packed_image_latent_tensors, dim=1)
block_state.image_height = image_heights
block_state.image_width = image_widths
setattr(block_state, image_latent_input_name, packed_image_latent_tensors)

block_state.height = block_state.height or image_heights[-1]
block_state.width = block_state.width or image_widths[-1]

# Process additional batch inputs (only batch expansion)
for input_name in self._additional_batch_inputs:
input_tensor = getattr(block_state, input_name)
if input_tensor is None:
continue

# Only expand batch size
input_tensor = repeat_tensor_to_batch_size(
input_name=input_name,
input_tensor=input_tensor,
num_images_per_prompt=block_state.num_images_per_prompt,
batch_size=block_state.batch_size,
)

setattr(block_state, input_name, input_tensor)

self.set_block_state(state, block_state)
return components, state
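
To make the packing and concatenation above concrete: `pachifier.pack_latents` is assumed to perform the usual 2x2 latent patchification (the same packing the RoPE shapes account for with their `// 2`), which turns every latent grid into a token sequence with a shared channel width. A small illustrative sketch, not the diffusers implementation:

```python
import torch

# Illustrative 2x2 latent packing (assumed to match pachifier.pack_latents):
# (B, C, H, W) -> (B, (H//2)*(W//2), C*4), one token per 2x2 latent patch.
def sketch_pack_latents(latents: torch.Tensor) -> torch.Tensor:
    b, c, h, w = latents.shape
    latents = latents.view(b, c, h // 2, 2, w // 2, 2)
    latents = latents.permute(0, 2, 4, 1, 3, 5)
    return latents.reshape(b, (h // 2) * (w // 2), c * 4)

# Two references with different latent grids pack to different token counts
# but share the channel width, so they can be concatenated along dim=1.
a = sketch_pack_latents(torch.randn(1, 16, 128, 128))  # (1, 4096, 64)
b = sketch_pack_latents(torch.randn(1, 16, 192, 128))  # (1, 6144, 64)
tokens = torch.cat([a, b], dim=1)                      # (1, 10240, 64)
```

This is why `torch.cat(packed_image_latent_tensors, dim=1)` is valid even when the references have different resolutions: only the sequence length differs per image. Note also that when the user does not pass `height`/`width`, they default to the last reference image's dimensions.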


class QwenImageControlNetInputsStep(ModularPipelineBlocks):
model_name = "qwenimage"
