
Conversation

Ratish1 commented Nov 5, 2025

What does this PR do?

This PR modifies the ContextParallelSplitHook and ContextParallelGatherHook to gracefully handle sequence lengths that are not divisible by the world size.

This PR makes the following changes:

  1. Generic Padding: The ContextParallelSplitHook now pads any input tensor along the split dimension to a length divisible by the world size before sharding.
  2. State Management: It temporarily stores the original sequence length on the module instance itself.
  3. Generic Trimming: The ContextParallelGatherHook uses this stored length to trim the padding from the final output tensor before returning it.

This ensures that the padding is completely transparent to the model and the end-user, preventing crashes without altering the output shape. The fix is now contained entirely within the hooks and requires no changes to the Qwen transformer or any other model.
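
For illustration, here is a minimal sketch of the pad-and-trim idea with a simplified interface; the function names, the `_cp_original_seq_len` attribute, and the `world_size`/`dim` arguments are placeholders rather than the actual hook API in src/diffusers/hooks/context_parallel.py:

```python
import torch
import torch.nn.functional as F


def pad_before_split(module: torch.nn.Module, hidden_states: torch.Tensor,
                     world_size: int, dim: int = 1) -> torch.Tensor:
    """Pad `dim` so its length is divisible by `world_size`, and stash the
    original length on the module so the gather side can trim it later."""
    seq_len = hidden_states.shape[dim]
    module._cp_original_seq_len = seq_len  # illustrative attribute name
    remainder = seq_len % world_size
    if remainder == 0:
        return hidden_states
    pad_len = world_size - remainder
    # F.pad consumes the pad spec from the last dimension backwards.
    pad = [0, 0] * (hidden_states.dim() - dim - 1) + [0, pad_len]
    return F.pad(hidden_states, pad)


def trim_after_gather(module: torch.nn.Module, gathered: torch.Tensor,
                      dim: int = 1) -> torch.Tensor:
    """Trim the padding added before the split, restoring the original length."""
    original_len = getattr(module, "_cp_original_seq_len", None)
    if original_len is None or gathered.shape[dim] == original_len:
        return gathered
    return gathered.narrow(dim, 0, original_len)
```

For example, with world_size=4 and a sequence length of 10, the split hook pads to 12 so each rank receives 3 tokens, and the gather hook trims the output back to 10, so the caller never sees the padding.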

I have also added a new unit test in tests/hooks/test_hooks.py that directly tests this new padding and trimming logic.

Fixes #12568

Before submitting

Who can review?

Anyone in the community is free to review the PR once the tests have passed. Feel free to tag
members/contributors who may be interested in your PR.
@sayakpaul @yiyixuxu @DN6

Ratish1 changed the title from "fix(qwenimage): Correct context parallelism padding" to "fix(qwenimage): Add padding for context parallelism" on Nov 5, 2025
yiyixuxu (Collaborator) commented Nov 5, 2025

Thanks for the PR! However, we don't want any of this logic to go into the Qwen transformer.
Would you be interested in looking into how to support this case (not just Qwen) from the context parallel hooks?
https://github.com/huggingface/diffusers/blob/main/src/diffusers/hooks/context_parallel.py#L204

Ratish1 (Author) commented Nov 5, 2025

Hi @yiyixuxu, yes, I would be interested in supporting this change.

I have some follow-up questions regarding the model-specific consequences of this padding. The Qwen transformer calculates RoPE based on the original sequence length. When the hook pads the input tensor, the model still needs to be aware of the new, padded length to avoid a shape mismatch inside the RoPE calculation.

Previously, I fixed this by recalculating the sequence length inside the transformer's forward method based on the padded tensor's shape. With the padding logic now in the hook, could you advise on the preferred way to handle this?

Is it acceptable to keep that small, model-specific part of the logic (recalculating the sequence length for RoPE) inside the Qwen transformer, or is there a more general way to communicate the new padded length from the hook back to the model that I should use instead? Thanks for your help.
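
To make the mismatch concrete, here is a small, hypothetical illustration (the shapes and the element-wise multiply stand in for how rotary embeddings get applied; this is not the actual Qwen RoPE code):

```python
import torch

original_len, world_size = 10, 4
padded_len = ((original_len + world_size - 1) // world_size) * world_size  # 12

hidden_states = torch.randn(1, padded_len, 64)  # what the model sees after the hook pads
rotary_cos = torch.randn(original_len, 64)      # built from the pre-padding length

try:
    _ = hidden_states * rotary_cos  # broadcasting fails: 12 vs 10 along the sequence dim
except RuntimeError as err:
    print("shape mismatch:", err)
```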

Ratish1 (Author) commented Nov 8, 2025

Hi @yiyixuxu, just wanted to follow up. After looking at the hook implementation as you suggested, I've updated the PR with a new approach that is fully generic and keeps all the logic within the hooks, with no changes to the transformer.

The solution now involves adding padding in the ContextParallelSplitHook and then trimming it in the ContextParallelGatherHook, using the module instance to temporarily store the original sequence length. I've also added a new unit test for this logic in test_hooks.py. Thanks, and let me know if you need any more changes. I've updated the PR description with the full details.
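
Roughly, the test checks that a non-divisible sequence length round-trips through the pad/trim path unchanged. A sketch of that shape, reusing the illustrative pad_before_split/trim_after_gather helpers from the PR description above (the real test in tests/hooks/test_hooks.py exercises the actual hook classes):

```python
import torch


def test_padding_is_trimmed_after_gather():
    module = torch.nn.Identity()  # any module instance works as the place to stash state
    world_size = 4
    hidden_states = torch.randn(2, 10, 8)  # sequence length 10 is not divisible by 4

    padded = pad_before_split(module, hidden_states, world_size)
    assert padded.shape[1] == 12  # padded up to the next multiple of 4

    # Pretend the shards were processed and gathered back unchanged.
    output = trim_after_gather(module, padded)
    assert output.shape == hidden_states.shape
    assert torch.equal(output, hidden_states)
```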

CC @sayakpaul @DN6

Ratish1 changed the title from "fix(qwenimage): Add padding for context parallelism" to "fix(hooks): Add padding support to context parallel hooks" on Nov 8, 2025
Development

Successfully merging this pull request may close these issues.

Context parallel bug when using QwenImage
