[Bugfix] Fix mamba2 prefill chunking #23279
Conversation
Code Review
This pull request addresses a bug in the prefill chunking mechanism for Mamba2 kernels, specifically correcting the calculation of dA_cumsum across chunk boundaries. The solution involves passing the complete dA_cumsum tensor and chunk offsets to the state passing kernel, allowing for accurate adjustments based on sequence boundaries. Additionally, the PR introduces a new unit test to validate this fix and enhances the documentation for related functions.
My review identified a critical issue within the bug fix implementation in _state_passing_fwd_kernel. The mask for loading chunk offsets incorrectly excludes the last logical chunk, potentially causing incorrect calculations. I have provided a code suggestion to rectify this.
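A toy model of the off-by-one the review describes, assuming the kernel loads one offset per logical chunk under a lane mask (the offset values and the `masked_load` helper here are hypothetical, standing in for Triton's `tl.load`):

```python
# Toy model (not the actual Triton kernel) of the review's point: a lane
# mask bounded by num_chunks - 1 replaces the last logical chunk's
# offset with the fill value, dropping it from the computation.
chunk_offsets = [0, 5, 6]   # hypothetical per-logical-chunk offsets
num_chunks = len(chunk_offsets)
lanes = range(4)            # padded lanes in a GPU block

def masked_load(values, bound, other=0):
    # mimics tl.load(ptr + lanes, mask=lanes < bound, other=other)
    return [values[i] if i < bound and i < len(values) else other
            for i in lanes]

print(masked_load(chunk_offsets, num_chunks - 1))  # [0, 5, 0, 0] - last chunk lost
print(masked_load(chunk_offsets, num_chunks))      # [0, 5, 6, 0]
```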
tdoublep left a comment:
Thanks for finding and fixing this. Especially appreciate the effort to write better docstrings.
@fabianlim Could you also PTAL at this PR?
```
query_start_loc = [0, 5, 10]
chunk_size = 8
total_seqlens = 10
-> chunk_indices = [0, 1, 0]
```
I'm a bit confused by this docstring. Shouldn't the second logical chunk in this example belong in the first physical chunk? E.g., shouldn't chunk_indices=[0,0,1] ?
Good catch, thanks! Fixed.
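For concreteness, here is a simplified sketch of the mapping the (now corrected) docstring describes; this is an illustration of the idea, not the actual `_query_start_loc_to_chunk_indices_offsets` implementation:

```python
def chunk_indices_offsets(query_start_loc, chunk_size, total_seqlens):
    # a logical chunk starts at every physical chunk boundary and at
    # every sequence start that falls strictly inside a physical chunk
    starts = sorted(set(range(0, total_seqlens, chunk_size))
                    | set(query_start_loc[:-1]))
    chunk_indices = [s // chunk_size for s in starts]  # physical chunk id
    chunk_offsets = [s % chunk_size for s in starts]   # offset inside it
    return chunk_indices, chunk_offsets

# docstring example: the sequence boundary at token 5 splits physical
# chunk 0 into two logical chunks, so chunk_indices = [0, 0, 1]
print(chunk_indices_offsets([0, 5, 10], chunk_size=8, total_seqlens=10))
# -> ([0, 0, 1], [0, 5, 0])
```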
```
     itype,
-    device='cuda'):
+    device='cuda',
+    return_ref=True):
```
Is there a specific reason why we need to add return_ref here? Couldn't we just ignore the output when we don't need it? Or does ssd_minimal_discrete update some of its inputs in-place?
That was my initial approach, but the problem is that ssd_minimal_discrete asserts the max sequence length is a multiple of the mamba chunk (aka block) size:
```
assert X.shape[1] % block_len == 0
```
This is an assumption I wanted to break in the new unit test, but I wanted to reuse the code that generates random inputs given a tuple of sequence lengths. Since I didn't really need the reference (PyTorch) outputs, I decided to add this return_ref flag.
I see.. how about changing return_ref to return_naive_ref and then documenting the behavior of that flag:
- if return_naive_ref=True, we will use the naive implementation ssd_minimal_discrete to compute and return the reference
Sounds good. Done
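A minimal sketch of how the renamed flag could look in the input-generation helper; the function name, tensor shapes, and the cumsum stand-in are all illustrative, with only the multiple-of-block_len restriction of ssd_minimal_discrete taken from the discussion above:

```python
import torch

def generate_inputs(seqlens, n_heads=4, d_head=8, block_len=8,
                    itype=torch.float32, device='cpu',
                    return_naive_ref=True):
    total = sum(seqlens)
    X = torch.randn(1, total, n_heads, d_head, dtype=itype, device=device)
    dt = torch.rand(1, total, n_heads, dtype=itype, device=device)
    if not return_naive_ref:
        # skip the naive reference entirely, so ragged totals
        # (total % block_len != 0) are allowed
        return X, dt, None
    # the naive reference only supports lengths that tile exactly into
    # blocks, mirroring ssd_minimal_discrete's assertion
    assert total % block_len == 0, "naive ref needs total % block_len == 0"
    Y_ref = X.cumsum(dim=1)  # stand-in for the real ssd_minimal_discrete call
    return X, dt, Y_ref

# ragged lengths are fine when the reference is skipped
X, dt, _ = generate_inputs((5, 7), return_naive_ref=False)
```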
```
exhausted[i] = False

@pytest.mark.parametrize("chunk_size", [8, 256])
```
this test looks like it's replicating test_mamba_chunk_scan_cont_batch, what is the key difference?
The key difference is that this test makes sure prefill chunking works as expected without using the PyTorch reference implementation. Instead, it compares the kernel output without prefill chunking to the concatenated outputs with prefill chunking. This is the most straightforward way to verify that prefill chunking works as expected.
Another crucial difference from test_mamba_chunk_scan_cont_batch is that this test covers cases where the sequence length is not a multiple of the mamba chunk size - in other words, cases where a sequence changes in the middle of a mamba chunk. These are the cases that currently fail on main and require the fixes in this PR. These cases are also not supported in the PyTorch implementation (see the other discussion), so they can't easily be added to test_mamba_chunk_scan_cont_batch, which compares kernel results with the reference PyTorch implementation.
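A toy model of that strategy, with a simple exponential-moving-average scan standing in for the mamba2 kernel (illustrative only, not the actual test code):

```python
import torch

def toy_scan(x, a=0.9, initial_state=None):
    # sequential stand-in for the SSM scan: state = a * state + x_t
    state = torch.zeros(x.shape[0]) if initial_state is None else initial_state
    ys = []
    for t in range(x.shape[1]):
        state = a * state + x[:, t]
        ys.append(state.clone())
    return torch.stack(ys, dim=1), state

x = torch.randn(2, 10)
y_full, state_full = toy_scan(x)

# chunked prefill: split at token 5, which need not align with any
# mamba chunk boundary - exactly the misaligned case this PR fixes
y1, s1 = toy_scan(x[:, :5])
y2, s2 = toy_scan(x[:, 5:], initial_state=s1)

torch.testing.assert_close(torch.cat([y1, y2], dim=1), y_full)
torch.testing.assert_close(s2, state_full)
```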
I see, thanks for the explanation. In this case I suggest adding some documentation to explain how test_mamba_chunk_scan_cont_batch_prefill_chunking differs from the previous test, since this test is a little long and it's hard to understand at a quick glance.
makes sense. Done
fabianlim left a comment:
my comments are addressed
LGTM - Thanks!
Purpose

Fix a few bugs with prefill chunking for the mamba2 kernels.

Prefill chunking with varlen batching support for the Mamba2 block was added in #10909, but it contained a few bugs relating to the handling of initial states across mamba chunk boundaries. Some of these bugs were fixed recently in #21783, specifically regarding the decay factor of the initial states. Yet, in cases where chunk boundaries and sequence boundaries don't align (a sequence changes in the middle of a chunk), the state passing kernel with initial states was still buggy. Namely, dA_cumsum was computed from the start of the mamba chunk instead of from the start of the current sequence. This PR fixes this.
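As an illustration of the fix (a sketch of the idea, not the actual Triton kernel): with the full dA_cumsum tensor and the chunk offsets available, the kernel can re-base the cumulative sum at the sequence start instead of decaying the initial state over the whole chunk:

```python
import torch

def initial_state_decay(dA_cs: torch.Tensor, seq_start: int) -> torch.Tensor:
    # dA_cs: within-chunk cumulative sum of dA; seq_start: in-chunk
    # offset where the current sequence begins (0 when aligned).
    # buggy: torch.exp(dA_cs[-1]) decays from the chunk start, i.e.
    # over dA terms that belong to the previous sequence.
    # fixed: subtract the cumsum at the sequence start, so the initial
    # state only decays over the current sequence's tokens.
    base = dA_cs[seq_start - 1] if seq_start > 0 else dA_cs.new_zeros(())
    return torch.exp(dA_cs[-1] - base)
```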
Other changes:
- Add a new unit test for mamba_chunk_scan_combined with prefill chunking and varlen batching, comparing chunked results to those of the full sequence.
- Add documentation for _query_start_loc_to_chunk_indices_offsets, which is a somewhat cryptic function.

Test Plan

Make sure all cases in tests/kernels/mamba/test_mamba_ssm_ssd.py::test_mamba_chunk_scan_cont_batch_prefill_chunking pass. These cases fail on main.

Test Result

Tests pass