Conversation

Contributor

@tomeras91 tomeras91 commented Aug 20, 2025

Purpose

Fix a few bugs with prefill chunking for the mamba2 kernels.

Prefill chunking with varlen batching support for the Mamba2 block was added in #10909, but it contained a few bugs related to the handling of initial states across mamba chunk boundaries. Some of these bugs were fixed recently in #21783 - specifically the decay factor applied to the initial states. Yet, in cases where chunk boundaries and sequence boundaries don't align (a sequence changes in the middle of a chunk), the state passing kernel with initial states was still buggy: dA_cumsum was computed from the start of the mamba chunk instead of from the start of the current sequence. This PR fixes this.
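Conceptually, the decay applied to an initial state must accumulate only over tokens that belong to the current sequence. A minimal sketch of the idea (illustrative names and shapes, not the actual Triton kernel code):

```python
import torch

def initial_state_decay(dA_cs: torch.Tensor, seq_offset: int) -> torch.Tensor:
    """dA_cs: cumulative sum of dA over one physical mamba chunk, shape (chunk_len,).
    seq_offset: index within the chunk where the current sequence starts."""
    # Buggy: decay measured from the chunk start -> torch.exp(dA_cs[-1])
    # Fixed: subtract the cumsum accumulated before the sequence start, so
    # the decay only spans tokens of the current sequence.
    base = dA_cs[seq_offset - 1] if seq_offset > 0 else dA_cs.new_zeros(())
    return torch.exp(dA_cs[-1] - base)
```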

Other changes:

Test Plan

Make sure all cases in tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking pass. These cases fail on main.
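For example:

```bash
pytest tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking
```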

Test Result

Tests pass



Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
…actor to change names for better readability)

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
@github-actions

👋 Hi! Thank you for contributing to the vLLM project.

💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels.

Just a reminder: PRs do not trigger a full CI run by default. Instead, they only run fastcheck CI, which runs a small and essential subset of CI tests to quickly catch errors. You can run other CI tests on top of those by going to your fastcheck build on the Buildkite UI (linked in the PR checks section) and unblocking them. If you do not have permission to unblock, ping simon-mo or khluu to add you to our Buildkite org.

Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging.

To run CI, PR reviewers can either add the ready label to the PR or enable auto-merge.

🚀

@mergify mergify bot added the v1 label Aug 20, 2025
Contributor

@gemini-code-assist gemini-code-assist bot left a comment


Code Review

This pull request addresses a bug in the prefill chunking mechanism for Mamba2 kernels, specifically correcting the calculation of dA_cumsum across chunk boundaries. The solution involves passing the complete dA_cumsum tensor and chunk offsets to the state passing kernel, allowing for accurate adjustments based on sequence boundaries. Additionally, the PR introduces a new unit test to validate this fix and enhances the documentation for related functions.

My review identified a critical issue within the bug fix implementation in _state_passing_fwd_kernel. The mask for loading chunk offsets incorrectly excludes the last logical chunk, potentially causing incorrect calculations. I have provided a code suggestion to rectify this.
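As a hypothetical illustration of this class of off-by-one (not the actual kernel code), a mask built with `< num_chunks - 1` silently drops the final entry:

```python
import torch

num_logical_chunks = 3
idx = torch.arange(num_logical_chunks)

buggy_mask = idx < num_logical_chunks - 1  # tensor([True, True, False]) - last chunk excluded
fixed_mask = idx < num_logical_chunks      # tensor([True, True, True])  - all chunks loaded
```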

tomeras91 and others added 2 commits August 20, 2025 22:37
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Member

@tdoublep tdoublep left a comment


Thanks for finding this and fixing it. I especially appreciate the effort to write better docstrings.

@fabianlim Could you also PTAL at this PR?

query_start_loc = [0, 5, 10]
chunk_size = 8
total_seqlens = 10
-> chunk_indices = [0, 1, 0]
Member


I'm a bit confused by this docstring. Shouldn't the second logical chunk in this example belong in the first physical chunk? E.g., shouldn't chunk_indices=[0,0,1] ?

Contributor Author


Good catch, thanks! Fixed.
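For reference, a hedged sketch of how the corrected mapping can be derived for the docstring example (the helper name is illustrative, not the actual vLLM implementation):

```python
def logical_to_physical(query_start_loc, chunk_size, total_seqlens):
    # A new logical chunk starts at every physical chunk boundary and at
    # every sequence start. Map each logical chunk to the index of the
    # physical chunk containing it and its offset within that chunk.
    starts = sorted(set(range(0, total_seqlens, chunk_size)) | set(query_start_loc[:-1]))
    chunk_indices = [s // chunk_size for s in starts]
    chunk_offsets = [s % chunk_size for s in starts]
    return chunk_indices, chunk_offsets

# query_start_loc=[0, 5, 10], chunk_size=8, total_seqlens=10
# logical chunks start at 0, 5, 8 -> ([0, 0, 1], [0, 5, 0])
print(logical_to_physical([0, 5, 10], 8, 10))
```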

  itype,
- device='cuda'):
+ device='cuda',
+ return_ref=True):
Member


Is there a specific reason why we need to add return_ref here? Couldn't we just ignore the output when we don't need it? Or does ssd_minimal_discrete update some of its inputs in-place?

Contributor Author

@tomeras91 tomeras91 Aug 31, 2025


That was my initial approach, but the problem is that ssd_minimal_discrete asserts that the max sequence length is a multiple of the mamba chunk (aka block) size:

assert X.shape[1] % block_len == 0

This is an assumption I wanted to break in the new unit test, while still reusing the code that generates random inputs given a tuple of sequence lengths. Since I didn't really need the reference (PyTorch) outputs, I decided to add this return_ref flag.

Contributor

@fabianlim fabianlim Sep 1, 2025


I see. How about changing return_ref to return_naive_ref and then documenting the behavior of that flag:

  • if return_naive_ref=True, we will use the naive implementation ssd_minimal_discrete to compute and return the reference

Contributor Author


Sounds good. Done

exhausted[i] = False


@pytest.mark.parametrize("chunk_size", [8, 256])
Contributor


This test looks like it's replicating test_mamba_chunk_scan_cont_batch. What is the key difference?

Contributor Author


The key difference is that this test makes sure prefill chunking works as expected without using the PyTorch reference implementation. Instead, it compares the kernel output without prefill chunking to the concatenated outputs with prefill chunking. This is the most straightforward way to verify that prefill chunking works as expected.

Another crucial difference from test_mamba_chunk_scan_cont_batch is that this test covers cases where the sequence length is not a multiple of the mamba chunk size - in other words, cases where a sequence changes in the middle of a mamba chunk. These are the cases that currently fail on main and require the fixes in this PR. These cases are also not supported in the PyTorch implementation (see the other discussion), so they can't easily be added to test_mamba_chunk_scan_cont_batch, which compares kernel results against the reference PyTorch implementation.
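In pseudocode terms, the comparison strategy is roughly the following (a hedged sketch; run_kernel and the shapes are illustrative, not the actual test code):

```python
import torch

def check_prefill_chunking(run_kernel, x, split):
    # Reference: process the full prefill in a single kernel call.
    full_out, full_state = run_kernel(x, initial_state=None)

    # Chunked: process the same prefill in two pieces, feeding the state
    # of the first piece in as the initial state of the second.
    out1, state1 = run_kernel(x[:, :split], initial_state=None)
    out2, state2 = run_kernel(x[:, split:], initial_state=state1)

    # Concatenated chunked outputs must match the single-shot output.
    torch.testing.assert_close(torch.cat([out1, out2], dim=1), full_out)
    torch.testing.assert_close(state2, full_state)
```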

Contributor


I see, thanks for the explanation. In that case I suggest adding some documentation to explain how test_mamba_chunk_scan_cont_batch_prefill_chunking differs from the previous test, since this test is a little long and hard to understand at a quick glance.

Contributor Author


makes sense. Done

Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Contributor

@fabianlim fabianlim left a comment


my comments are addressed

Member

@tdoublep tdoublep left a comment


LGTM - Thanks!

@tdoublep tdoublep enabled auto-merge (squash) September 8, 2025 09:15
@github-actions github-actions bot added the ready label Sep 8, 2025
@tdoublep tdoublep merged commit e041314 into vllm-project:main Sep 8, 2025
50 checks passed
@tomeras91 tomeras91 deleted the fix-mamba2-prefill-chunking branch September 9, 2025 07:05
eicherseiji pushed a commit to eicherseiji/vllm that referenced this pull request Sep 9, 2025
skyloevil pushed a commit to skyloevil/vllm that referenced this pull request Sep 13, 2025
FeiDaLI pushed a commit to FeiDaLI/vllm that referenced this pull request Sep 25, 2025
netanel-haber added a commit to netanel-haber/sglang that referenced this pull request Sep 28, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 10, 2025
xuebwang-amd pushed a commit to xuebwang-amd/vllm that referenced this pull request Oct 24, 2025

Labels

ready, v1
