14 | 14 |
15 | 15 | import gc |
16 | 16 | import unittest |
| 17 | +from unittest.mock import patch |
17 | 18 |
18 | 19 | import torch |
19 | 20 |
20 | 21 | from diffusers.hooks import HookRegistry, ModelHook |
| 22 | +from diffusers.hooks.context_parallel import ContextParallelSplitHook, EquipartitionSharder |
21 | 23 | from diffusers.training_utils import free_memory |
22 | 24 | from diffusers.utils.logging import get_logger |
23 | 25 |
@@ -62,6 +64,28 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: |
62 | 64 |         return x |
63 | 65 |
64 | 66 |
| 67 | +# Small helpers to simulate the parallel_config._flattened_mesh used by the hook |
| 68 | +class _DummyMesh: |
| 69 | +    def __init__(self, size: int): |
| 70 | +        self._size = size |
| 71 | + |
| 72 | +    def size(self): |
| 73 | +        return self._size |
| 74 | + |
| 75 | + |
| 76 | +class _DummyParallelConfig: |
| 77 | +    def __init__(self, mesh_size: int): |
| 78 | +        self._flattened_mesh = _DummyMesh(mesh_size) |
| 79 | + |
| 80 | + |
| 81 | +# Lightweight object that behaves like a ContextParallelInput for testing. |
| 82 | +class _DummyCPInput: |
| 83 | +    def __init__(self, split_dim: int, expected_dims=None, split_output: bool = False): |
| 84 | +        self.split_dim = split_dim |
| 85 | +        self.expected_dims = expected_dims |
| 86 | +        self.split_output = split_output |
| 87 | + |
| 88 | + |
65 | 89 | class AddHook(ModelHook): |
66 | 90 |     def __init__(self, value: int): |
67 | 91 |         super().__init__() |
@@ -375,3 +399,75 @@ def test_invocation_order_stateful_last(self): |
375 | 399 |             .replace("\n", "") |
376 | 400 |         ) |
377 | 401 |         self.assertEqual(output, expected_invocation_order_log) |
| 402 | + |
| 403 | + |
| 404 | +class ContextParallelHooksTests(unittest.TestCase): |
| 405 | +    def setUp(self): |
| 406 | +        # a world size of 3 forces padding whenever seq_len is not divisible by 3 |
| 407 | +        self.parallel_config = _DummyParallelConfig(mesh_size=3) |
| 408 | +        # metadata can stay empty because we call _prepare_cp_input directly |
| 409 | +        self.hook = ContextParallelSplitHook(metadata={}, parallel_config=self.parallel_config) |
| 410 | +        self.module = DummyModel(in_features=1, hidden_features=1, out_features=1, num_layers=1) |
| 411 | +        # initialize_hook builds the hook's module_forward_metadata |
| 412 | +        self.hook.initialize_hook(self.module) |
| 413 | +        # attach the forward metadata to the module exactly as HookRegistry would |
| 414 | +        self.module._forward_metadata = self.hook.module_forward_metadata |
| 415 | + |
| 416 | +    def test_prepare_cp_input_pads_hidden_states_and_stores_original(self): |
| 417 | +        # create a tensor with seq_len = 7 along dim=1 (batch, seq, hidden) |
| 418 | +        x = torch.randn(1, 7, 16) |
| 419 | + |
| 420 | +        cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False) |
| 421 | + |
| 422 | +        # Patch shard to identity so we can inspect the padded tensor directly |
| 423 | +        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t) as mock_shard: |
| 424 | +            out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states") |
| 425 | + |
| 426 | +        # The hook should have padded seq_len from 7 -> 9 since world_size=3 |
| 427 | +        self.assertEqual(out.shape[1], 9) |
| 428 | + |
| 429 | +        # ensure shard was called once with the expected dim and mesh |
| 430 | +        mock_shard.assert_called_once() |
| 431 | +        called_args, _ = mock_shard.call_args |
| 432 | +        # called_args = (tensor, dim, mesh) |
| 433 | +        self.assertEqual(called_args[1], cp_input.split_dim) |
| 434 | +        self.assertIs(called_args[2], self.parallel_config._flattened_mesh) |
| 435 | + |
| 436 | +        # The hook should have recorded the original sequence length and pad dim |
| 437 | +        # on the module's metadata so the gather hook can later trim. |
| 438 | +        self.assertTrue(hasattr(self.module._forward_metadata, "_cp_original_s")) |
| 439 | +        self.assertTrue(hasattr(self.module._forward_metadata, "_cp_pad_dim")) |
| 440 | +        self.assertEqual(self.module._forward_metadata._cp_original_s, 7) |
| 441 | +        self.assertEqual(self.module._forward_metadata._cp_pad_dim, 1) |
| 442 | + |
| 443 | +    def test_prepare_cp_input_pads_attention_mask_with_zeros(self): |
| 444 | +        # attention masks are typically shape (batch, seq); |
| 445 | +        # create a seq_len = 7 mask of ones |
| 446 | +        mask = torch.ones(1, 7, dtype=torch.long) |
| 447 | + |
| 448 | +        cp_input = _DummyCPInput(split_dim=1, expected_dims=2, split_output=False) |
| 449 | + |
| 450 | +        # Patch shard to identity |
| 451 | +        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t): |
| 452 | +            out_mask = self.hook._prepare_cp_input(mask, cp_input, name="encoder_attention_mask") |
| 453 | + |
| 454 | +        # After padding, the mask should have shape (1, 9) |
| 455 | +        self.assertEqual(out_mask.shape[1], 9) |
| 456 | +        # The padded values should be zeros (the pad value the hook uses for masks), |
| 457 | +        # so check that the last two positions are zero |
| 458 | +        padded_portion = out_mask[:, -2:] |
| 459 | +        self.assertTrue(torch.equal(padded_portion, torch.zeros_like(padded_portion))) |
| 460 | + |
| 461 | +    def test_prepare_cp_input_no_pad_when_divisible(self): |
| 462 | +        # seq_len = 6 is already divisible by the world size (3) |
| 463 | +        x = torch.randn(1, 6, 16) |
| 464 | +        cp_input = _DummyCPInput(split_dim=1, expected_dims=3, split_output=False) |
| 465 | + |
| 466 | +        with patch.object(EquipartitionSharder, "shard", side_effect=lambda t, dim, mesh: t): |
| 467 | +            out = self.hook._prepare_cp_input(x, cp_input, name="hidden_states") |
| 468 | + |
| 469 | +        # no padding should be performed |
| 470 | +        self.assertEqual(out.shape[1], 6) |
| 471 | +        # and _cp_original_s / _cp_pad_dim should not be set, since nothing was padded |
| 472 | +        self.assertFalse(hasattr(self.module._forward_metadata, "_cp_original_s")) |
| 473 | +        self.assertFalse(hasattr(self.module._forward_metadata, "_cp_pad_dim")) |
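Note on the padding arithmetic these tests rely on: with a world size of 3, a sequence of length 7 must be padded to 9 (the next multiple of 3) before sharding, while a length of 6 needs no padding at all. A minimal, illustrative sketch of that expectation (not the hook's actual implementation, whose internals may differ):

    # Sketch only: assumes the hook pads the split dimension up to the next multiple of the world size.
    def expected_pad(seq_len: int, world_size: int) -> int:
        return (-seq_len) % world_size

    assert expected_pad(7, 3) == 2  # 7 -> 9; the two extra positions are zeros for attention masks
    assert expected_pad(6, 3) == 0  # already divisible, so no padding metadata is recorded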