50 changes: 45 additions & 5 deletions src/diffusers/hooks/context_parallel.py
@@ -154,7 +154,7 @@ def pre_forward(self, module, *args, **kwargs):
# The input_val may be a tensor or list/tuple of tensors. In certain cases, user may specify to shard
# the output instead of input for a particular layer by setting split_output=True
if isinstance(input_val, torch.Tensor):
input_val = self._prepare_cp_input(input_val, cpm)
input_val = self._prepare_cp_input(input_val, cpm, name)
elif isinstance(input_val, (list, tuple)):
if len(input_val) != len(cpm):
raise ValueError(
@@ -201,14 +201,41 @@ def post_forward(self, module, output):

return output[0] if is_tensor else tuple(output)

def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput) -> torch.Tensor:
def _prepare_cp_input(self, x: torch.Tensor, cp_input: ContextParallelInput, name: str) -> torch.Tensor:
if cp_input.expected_dims is not None and x.dim() != cp_input.expected_dims:
logger.warning_once(
f"Expected input tensor to have {cp_input.expected_dims} dimensions, but got {x.dim()} dimensions, split will not be applied."
)
return x
else:
return EquipartitionSharder.shard(x, cp_input.split_dim, self.parallel_config._flattened_mesh)

mesh = self.parallel_config._flattened_mesh
world_size = mesh.size()
dim = cp_input.split_dim
seq_len = x.size(dim)

if world_size > 1 and seq_len % world_size != 0:
pad_len = (world_size - seq_len % world_size) % world_size

# Determine pad_value based on name convention
if "mask" in name.lower():
pad_value = 0
else:
pad_value = 0.0

pad_width = [0] * (2 * x.dim())
# The pad_width tuple is read from last dim to first dim
pad_idx = x.dim() - 1 - dim
# We want to pad the right side
pad_width[2 * pad_idx + 1] = pad_len

x = torch.nn.functional.pad(x, tuple(pad_width), mode="constant", value=pad_value)

# Store original size for trimming in the gather hook
if "hidden_states" in name.lower():
self.module_forward_metadata._cp_original_s = seq_len
self.module_forward_metadata._cp_pad_dim = dim

return EquipartitionSharder.shard(x, dim, mesh)


class ContextParallelGatherHook(ModelHook):
@@ -233,7 +260,20 @@ def post_forward(self, module, output):
for i, cpm in enumerate(self.metadata):
if cpm is None:
continue
output[i] = EquipartitionSharder.unshard(output[i], cpm.gather_dim, self.parallel_config._flattened_mesh)

x = output[i]
x = EquipartitionSharder.unshard(x, cpm.gather_dim, self.parallel_config._flattened_mesh)

# Trim back to the original length if the split hook recorded padding info
if hasattr(module, "_forward_metadata") and hasattr(module._forward_metadata, "_cp_original_s"):
if cpm.gather_dim == module._forward_metadata._cp_pad_dim:
original_s = module._forward_metadata._cp_original_s
x = x.narrow(cpm.gather_dim, 0, original_s)
# Clean up the stored attributes
del module._forward_metadata._cp_original_s
del module._forward_metadata._cp_pad_dim

output[i] = x

return output[0] if is_tensor else tuple(output)

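For reference, the padding arithmetic introduced above can be exercised outside the hook. The following is a minimal standalone sketch (not part of this diff) that mirrors the right-side padding in _prepare_cp_input and the narrow()-based trim in ContextParallelGatherHook, assuming a hypothetical world size of 3 and a (batch, seq, hidden) tensor split along dim=1:

import torch
import torch.nn.functional as F

world_size = 3                        # assumed mesh size for illustration
x = torch.randn(1, 7, 16)             # (batch, seq, hidden)
dim = 1                               # split/gather dimension
seq_len = x.size(dim)

# Right-pad seq_len up to the next multiple of world_size (7 -> 9).
pad_len = (world_size - seq_len % world_size) % world_size
pad_width = [0] * (2 * x.dim())       # F.pad reads (left, right) pairs from the last dim backwards
pad_idx = x.dim() - 1 - dim
pad_width[2 * pad_idx + 1] = pad_len  # pad only the right side of `dim`
padded = F.pad(x, tuple(pad_width), mode="constant", value=0.0)
assert padded.shape[dim] == 9

# After the all-gather, the gather hook narrows back to the recorded original length.
trimmed = padded.narrow(dim, 0, seq_len)
assert torch.equal(trimmed, x)
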
1 change: 1 addition & 0 deletions src/diffusers/models/transformers/transformer_qwenimage.py
@@ -600,6 +600,7 @@ def forward(
If `return_dict` is True, an [`~models.transformer_2d.Transformer2DModelOutput`] is returned, otherwise a
`tuple` where the first element is the sample tensor.
"""

if attention_kwargs is not None:
attention_kwargs = attention_kwargs.copy()
lora_scale = attention_kwargs.pop("scale", 1.0)
96 changes: 96 additions & 0 deletions tests/hooks/test_hooks.py
@@ -14,10 +14,12 @@

import gc
import unittest
from unittest.mock import patch

import torch

from diffusers.hooks import HookRegistry, ModelHook
from diffusers.hooks.context_parallel import ContextParallelSplitHook, EquipartitionSharder
from diffusers.training_utils import free_memory
from diffusers.utils.logging import get_logger

Expand Down Expand Up @@ -62,6 +64,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
return x


# Small helpers to simulate the parallel_config._flattened_mesh used by the hook
class _DummyMesh:
def __init__(self, size: int):
self._size = size

def size(self):
return self._size


class _DummyParallelConfig:
def __init__(self, mesh_size: int):
self._flattened_mesh = _DummyMesh(mesh_size)


# Lightweight object that behaves like a ContextParallelInput for testing.
class _DummyCPInput:
def __init__(self, split_dim: int, expected_dims: int = None, split_output: bool = False):
self.split_dim = split_dim
self.expected_dims = expected_dims
self.split_output = split_output


class AddHook(ModelHook):
def __init__(self, value: int):
super().__init__()
@@ -375,3 +399,75 @@ def test_invocation_order_stateful_last(self):
.replace("\n", "")
)
self.assertEqual(output, expected_invocation_order_log)


class ContextParallelHooksTests(unittest.TestCase):
def setUp(self):
# world_size 3 will force padding for seq_len that isn't divisible by 3
self.parallel_config = _DummyParallelConfig(mesh_size=3)
# metadata may be empty for our direct call to _prepare_cp_input
self.hook = ContextParallelSplitHook(metadata={}, parallel_config=self.parallel_config)
self.module = DummyModel(in_features=1, hidden_features=1, out_features=1, num_layers=1)
# initialize_hook builds module_forward_metadata inside the hook
self.hook.initialize_hook(self.module)
# attach forward metadata to the module exactly as HookRegistry would
self.module._forward_metadata = self.hook.module_forward_metadata

def test_prepare_cp_input_pads_hidden_states_and_stores_original(self):
# create a tensor with seq_len = 7 along dim=1 (batch, seq, hidden)
x = torch.randn(1, 7, 16)

cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)

# Patch shard to identity so we can inspect the padded tensor directly
with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t) as mock_shard:
out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")

# The hook should have padded seq_len from 7 -> 9 since world_size=3
self.assertEqual(out.shape[1], 9)

# ensure shard was called once with the expected dim and mesh
mock_shard.assert_called_once()
called_args, _ = mock_shard.call_args
# called_args = (tensor, dim, mesh)
self.assertEqual(called_args[1], cp_input.split_dim)
self.assertIs(called_args[2], self.parallel_config._flattened_mesh)

# The hook should have recorded the original sequence length and pad dim
# on the module's metadata so the gather hook can later trim.
self.assertTrue(hasattr(self.module._forward_metadata, "_cp_original_s"))
self.assertTrue(hasattr(self.module._forward_metadata, "_cp_pad_dim"))
self.assertEqual(self.module._forward_metadata._cp_original_s, 7)
self.assertEqual(self.module._forward_metadata._cp_pad_dim, 1)

def test_prepare_cp_input_pads_attention_mask_with_zeros(self):
# attention masks are typically shape (batch, seq)
# create seq_len = 7 mask with ones
mask = torch.ones(1, 7, dtype=torch.long)

cp_input = _DummyCPInput(split_dim=1, expected_dims=2, split_output=False)

# Patch shard to identity
with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
out_mask = self.hook._prepare_cp_input(mask, cp_input, name="encoder_attention_mask")

# After padding it should be shape (1, 9)
self.assertEqual(out_mask.shape[1], 9)
# The padded values should be zeros (pad_value used in code for masks)
# Check the last two positions are zero
padded_portion = out_mask[:, -2:]
self.assertTrue(torch.equal(padded_portion, torch.zeros_like(padded_portion)))

def test_prepare_cp_input_no_pad_when_divisible(self):
# seq_len is already divisible by world_size (3), e.g., 6
x = torch.randn(1, 6, 16)
cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False)

with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t):
out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states")

# no padding should be performed
self.assertEqual(out.shape[1], 6)
# and no _cp_original_s/_cp_pad_dim should be set, since no padding was applied
self.assertFalse(hasattr(self.hook.module_forward_metadata, "_cp_original_s"))
self.assertFalse(hasattr(self.hook.module_forward_metadata, "_cp_pad_dim"))
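
As a standalone illustration of the zero-padding convention the mask test above relies on, the sketch below (same assumptions, not part of this diff) reproduces the check without the hook machinery or a distributed setup:

import torch
import torch.nn.functional as F

world_size = 3
mask = torch.ones(1, 7, dtype=torch.long)                          # (batch, seq) attention mask
pad_len = (world_size - mask.size(1) % world_size) % world_size    # 2
padded_mask = F.pad(mask, (0, pad_len), mode="constant", value=0)  # right-pad the last dim with zeros
assert padded_mask.shape == (1, 9)
assert padded_mask[:, -pad_len:].eq(0).all()                       # padded positions stay masked out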