-
-
Notifications
You must be signed in to change notification settings - Fork 11.8k
[Bugfix] Mamba2 SSD varlen bug fix initstates decay, improve test, assert chunk pwr 2 #21783
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
|
👋 Hi! Thank you for contributing to the vLLM project. 💬 Join our developer Slack at https://slack.vllm.ai to discuss your PR in #pr-reviews, coordinate on features in #feat- channels, or join special interest groups in #sig- channels. Just a reminder: PRs would not trigger full CI run by default. Instead, it would only run Once the PR is approved and ready to go, your PR reviewer(s) can run CI to test the changes comprehensively before merging. To run CI, PR reviewers can either: Add 🚀 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Code Review
This pull request provides a well-reasoned and correctly implemented fix for a correctness bug in the Mamba2 SSD implementation, specifically addressing an issue with initial state decay in variable-length sequences. The core change in the Triton kernel logic appears sound and directly targets the described bug.
The pull request also improves the overall robustness and performance characteristics of the code by enforcing that chunk_size must be a power of two, accompanied by a clear assertion. The test suite has been thoughtfully updated to reflect these changes, with adjusted test cases, tighter tolerances, and a new test to validate the bug fix under scale. The changes are consistent and of high quality.
tdoublep
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
|
This pull request has merge conflicts that must be resolved before it can be |
959185a to
3526dbe
Compare
|
I rebased it and it still passes all tests |
tlrmchlsmth
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Thank you!
Head branch was pushed to by a user without write access
|
Mamba2 SSD tests failed on GitHub but worked locally on an A100 and H100. I increased the tolerance slightly based on the GitHub test output. The tolerances are still less than the test_flash_attn.py tolerances. |
|
@RishiAstra Could you merge from main and retry? Some failures might have been fixed in main |
Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com>
Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com>
Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com>
|
@cyang49 done, also had rebased 3 days ago but it didn't fix the test failures then. Is it possible to merge with some test failures not related to the bug fix? |
|
Should be fine to merge |
…sert chunk pwr 2 (vllm-project#21783) Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com> Signed-off-by: Paul Pak <paulpak58@gmail.com>
…sert chunk pwr 2 (vllm-project#21783) Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com> Signed-off-by: Diego-Castan <diego.castan@ibm.com>
…sert chunk pwr 2 (vllm-project#21783) Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com>
…sert chunk pwr 2 (vllm-project#21783) Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com>
…sert chunk pwr 2 (vllm-project#21783) Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com> Signed-off-by: Xiao Yu <xiao.yu@amd.com>
…sert chunk pwr 2 (vllm-project#21783) Signed-off-by: Rishi Astra <40644327+RishiAstra@users.noreply.github.com>
Essential Elements of an Effective PR Description Checklist
supported_models.mdandexamplesfor a new model.Purpose
This PR fixes a correctness bug for varlen Mamba2 SSD (prefill).
Some of the existing tests required a high tolerance such as 50% rtol (see #21379), and a larger test of 768 tokens and 128 chunk_size failed (other tests seem to be only 64 tokens). These issues are fixed in this PR.
Here is an instinctive explanation of the bug:
dA_cs_m_boundary, causing a "later" (more decay) value to be read ifBLOCK_SIZE_M<chunk_sizedAwhich was offset (subtract) bydA_cs_m_boundaryand was therefore "earlier" (less decay)scale_mwas therefore off, representing less decayprev_statesorinitstateswould have too much influence on the current token due toscale_mNote: we require that chunk_size is a power of 2. The chunk_size does not affect correctness and can be tuned for performance, so it's not useful to have a non-power-of-2 chunk_size since Triton tensors would anyway need to be padded to powers of 2. One of the test cases was removed, and one was adjusted to use chunk_size=16 instead of 17.
After these changes, all tests pass. The
test_mamba_chunk_scan_cont_batchtests that used to have 50% rtol in some cases now all useatol, rtol = 5e-3, 5e-3and the correspondingifandTODOwere removed.Test Plan
Add new larger test size, remove a bad test, adjust a bad test, fix tolerances, and run:
pytest tests/kernels/mamba/test_mamba_ssm_ssd.pyTest Result
All tests pass
342 passed in 933.16s (0:15:33)(Optional) Documentation Update
None, but an assert will trigger if chunk_size is not a power of 2. This probably does not need to be documented, especially given the assertion message
chunk_size must be integer power of 2.