Add Bamba Model #10909
Merged
Commits (changes shown from 43 of the 78 commits)
62181d5 - initial pr without tp fix (fabianlim)
51bc78c - fix casting in rms norm gated (fabianlim)
81b93b4 - TP fix (fabianlim)
0f93e4a - fix mamba scan invalid address (fabianlim)
742ae79 - some fixes and remove unused kernels (fabianlim)
b2dc5ca - fmt + lint (fabianlim)
9ad9e20 - more comments (fabianlim)
25bf381 - initial fix for chunked prefill (incomplete) (fabianlim)
43ce07c - improve comments (fabianlim)
80f14b5 - do not attach seq_idx to attn_metadata (fabianlim)
6b8ac49 - activate initial states for chunked prefill (fabianlim)
d788db6 - reuse softplus and remove triton2 remark (fabianlim)
400db27 - add comment on weight loader and format (fabianlim)
bda8ea7 - rename test_jamba to test_hybrid and got rid of test_bamba (fabianlim)
66078d6 - Merge remote-tracking branch 'upstream/main' into bamba-pr (fabianlim)
a74de9f - update bamba to ishybrid and support pp (fabianlim)
b44caa7 - lint (fabianlim)
8cf3644 - add unit test for mamba ssd (fabianlim)
e375b40 - fix lint (fabianlim)
dcbae7b - full chunked-prefill fix (sans unit tests) (fabianlim)
2597105 - format and add cont batch unit tests (will need more cases) (fabianlim)
db5eea5 - fix kernel tests and add more chunked prefill cases (fabianlim)
dfbcb16 - bound adjustment (fabianlim)
7913009 - bound adjustment (fabianlim)
9c5d045 - lint errors (fabianlim)
6bc9dac - Add permalink correction from @tlrmchlsmth (fabianlim)
6d02e85 - improved comment for segsum, add more sizes for test_mamba_chunk_scan… (fabianlim)
e5882f2 - rename and comment functions, add more sizes for test_mamba_chunk_sca… (fabianlim)
6d6fa86 - addressed comments on mamba_mixer2.py (fabianlim)
773dd80 - replace with get_rope (fabianlim)
63f5340 - rope scaling (fabianlim)
89e36d8 - fixes (fabianlim)
7a4ae96 - zero out ssm states (fabianlim)
a9e149c - fix tests (sans updating dev checkpoint) (fabianlim)
5c9f48d - not replacing dev model for now (fabianlim)
55647b1 - update requirements (fabianlim)
2342bc0 - remove extraneous comment (fabianlim)
011c141 - update test (fabianlim)
503bc42 - fix lint (fabianlim)
312cf1d - fix lint (fabianlim)
c1db743 - fix requirements-test (fabianlim)
c956a30 - Mamba2 changes from #10909 (tlrmchlsmth)
17923ad - Get Mamba2 working! (tlrmchlsmth)
4183d45 - Add integration test -- something is wrong!! (tlrmchlsmth)
5377644 - format (tlrmchlsmth)
39f55d1 - fixes (tlrmchlsmth)
dd31f19 - update test registry, fixes (fabianlim)
e2e5aac - Fix for conv state shape and update placeholder_attn (tlrmchlsmth)
bc1b8af - back out placeholder_attn changes (tlrmchlsmth)
9db0dd5 - make seq_idx to chunk indices more efficient (fabianlim)
cd89283 - WIP debugging, restore local mamba and placeholder_attn changes (tlrmchlsmth)
9a838a3 - Integration tests are now green (tlrmchlsmth)
be8318e - remove bamba-specific files (tlrmchlsmth)
f34d434 - Merge branch 'main' into tms/mamba2 (tlrmchlsmth)
a65e2cb - Handle grouping in Mixer2RMSNormGated (tlrmchlsmth)
0d4bb0f - debug cruft (tlrmchlsmth)
74f6088 - Remove codestral integration test (tlrmchlsmth)
95583b8 - Merge branch 'tms/mamba2' into bamba-pr (fabianlim)
b72389c - update mamba_cache (fabianlim)
10d75eb - remove changes to requirements (fabianlim)
5aea1e6 - revert changes (fabianlim)
2ee8d07 - Merge remote-tracking branch 'upstream/main' into bamba-pr (fabianlim)
043e006 - fix lint (fabianlim)
7e4ce4f - fix lint (fabianlim)
8219480 - more reverts (fabianlim)
2a154e1 - remove unnecessary stuff (fabianlim)
b0536f7 - add mixer2 gated norm TP test (fabianlim)
b2e7952 - Merge remote-tracking branch 'upstream/main' into bamba-pr (fabianlim)
06c4e7f - add header (fabianlim)
851239a - fix lint (fabianlim)
6466c3c - Merge branch 'main' into bamba-pr (tlrmchlsmth)
64f6a4e - checkpoint renames (fabianlim)
266ce81 - (debug) test_mamba_ssm_ssd.py (fabianlim)
965620d - [debug] make all run same shard_id (fabianlim)
4a846ab - [debug] disable test case (fabianlim)
da380b1 - revert debugs and add @tlrmchlsmth fix! (fabianlim)
51d3762 - Merge branch 'main' into bamba-pr (tlrmchlsmth)
eba332a - update mamba and jamba for MambaCache changes (tlrmchlsmth)
New file test_mamba_ssm_ssd.py (302 added lines):

```python
from typing import Dict, Tuple

import pytest
import torch
import torch.nn.functional as F
from einops import rearrange, repeat

from vllm.model_executor.layers.mamba.ops.ssd_combined import (
    mamba_chunk_scan_combined)
from vllm.platforms import current_platform

# Added by the IBM Team, 2024

# Adapted from https://github.com/state-spaces/mamba/blob/v2.2.4/mamba_ssm/modules/ssd_minimal.py


# this is the segsum implementation taken from above
def segsum(x):
    """Calculates segment sum."""
    T = x.size(-1)
    x = repeat(x, "... d -> ... d e", e=T)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
                      diagonal=-1)
    x = x.masked_fill(~mask, 0)
    x_segsum = torch.cumsum(x, dim=-2)
    mask = torch.tril(torch.ones(T, T, device=x.device, dtype=bool),
                      diagonal=0)
    x_segsum = x_segsum.masked_fill(~mask, -torch.inf)
    return x_segsum
```
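A quick illustration (a minimal sketch, not part of the PR diff; it assumes the `segsum` defined above is in scope): `exp(segsum(x))` is the lower-triangular decay matrix that step 1 of the SSD reference below multiplies into the intra-chunk outputs.

```python
# Illustrative only: what segsum produces for a tiny 1-D input.
# exp(segsum(x))[i, j] == exp(x[j+1] + ... + x[i]) for j < i,
# 1.0 on the diagonal, and 0.0 above it (segsum is -inf there).
import torch

x = torch.tensor([0.1, 0.2, 0.3])
L = torch.exp(segsum(x))
# tensor([[1.0000, 0.0000, 0.0000],
#         [1.2214, 1.0000, 0.0000],
#         [1.6487, 1.3499, 1.0000]])
```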
```python
def ssd_minimal_discrete(X, A, B, C, block_len, initial_states=None):
    """
    Arguments:
        X: (batch, length, n_heads, d_head)
        A: (batch, length, n_heads)
        B: (batch, length, n_heads, d_state)
        C: (batch, length, n_heads, d_state)
    Return:
        Y: (batch, length, n_heads, d_head)
    """
    assert X.dtype == A.dtype == B.dtype == C.dtype
    assert X.shape[1] % block_len == 0

    # Rearrange into blocks/chunks
    X, A, B, C = (rearrange(x, "b (c l) ... -> b c l ...", l=block_len)
                  for x in (X, A, B, C))

    A = rearrange(A, "b c l h -> b h c l")
    A_cumsum = torch.cumsum(A, dim=-1)

    # 1. Compute the output for each intra-chunk (diagonal blocks)
    L = torch.exp(segsum(A))
    Y_diag = torch.einsum("bclhn,bcshn,bhcls,bcshp->bclhp", C, B, L, X)

    # 2. Compute the state for each intra-chunk
    # (right term of low-rank factorization of off-diagonal blocks; B terms)
    decay_states = torch.exp(A_cumsum[:, :, :, -1:] - A_cumsum)
    states = torch.einsum("bclhn,bhcl,bclhp->bchpn", B, decay_states, X)

    # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at
    # chunk boundaries
    # (middle term of factorization of off-diag blocks; A terms)
    if initial_states is None:
        initial_states = torch.zeros_like(states[:, :1])
    states = torch.cat([initial_states, states], dim=1)
    decay_chunk = torch.exp(segsum(F.pad(A_cumsum[:, :, :, -1], (1, 0))))
    new_states = torch.einsum("bhzc,bchpn->bzhpn", decay_chunk, states)
    states, final_state = new_states[:, :-1], new_states[:, -1]

    # 4. Compute state -> output conversion per chunk
    # (left term of low-rank factorization of off-diagonal blocks; C terms)
    state_decay_out = torch.exp(A_cumsum)
    Y_off = torch.einsum('bclhn,bchpn,bhcl->bclhp', C, states, state_decay_out)

    # Add output of intra-chunk and inter-chunk terms
    # (diagonal and off-diagonal blocks)
    Y = rearrange(Y_diag + Y_off, "b c l h p -> b (c l) h p")
    return Y, final_state
```
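The reference takes already-discretized inputs, which is why every call to it in this file folds the per-token, per-head step size `dt` into both `X` and `A` first. A minimal sketch of that call pattern (not part of the PR; it uses the random-input helper defined just below and assumes a CUDA device):

```python
# Illustrative only: dt scales X per token/head, and A * dt is the
# per-token decay, matching how the tests below drive the reference.
A, dt, X, B, C = generate_random_inputs(batch_size=1,
                                        seqlen=128,
                                        n_heads=4,
                                        d_head=32,
                                        itype=torch.float32)
Y_ref, final_state_ref = ssd_minimal_discrete(X * dt.unsqueeze(-1),
                                              A * dt,
                                              B,
                                              C,
                                              block_len=32)
```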
```python
def generate_random_inputs(batch_size,
                           seqlen,
                           n_heads,
                           d_head,
                           itype,
                           device='cuda'):

    current_platform.seed_everything(0)
    A = (-torch.exp(torch.rand(n_heads, dtype=itype, device=device)))
    dt = F.softplus(
        torch.randn(batch_size, seqlen, n_heads, dtype=itype, device=device) -
        4)
    X = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)
    B = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)
    C = torch.randn((batch_size, seqlen, n_heads, d_head),
                    dtype=itype,
                    device=device)

    return A, dt, X, B, C


def generate_continous_batched_examples(example_lens_by_batch,
                                        num_examples,
                                        full_length,
                                        last_taken,
                                        exhausted,
                                        n_heads,
                                        d_head,
                                        itype,
                                        device='cuda'):

    # this function generates a random example of a certain length,
    # then cuts it according to "example_lens_by_batch" and feeds
    # the pieces in continuous batches to the kernels

    # generate the full-length example
    A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
                                            d_head, itype)

    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
                                                  A * dt,
                                                  B,
                                                  C,
                                                  block_len=full_length // 4)

    # internal function that outputs a cont batch of examples
    # given a tuple of lengths for each example in the batch
    # e.g., example_lens=(8, 4) means take 8 samples from first eg,
    # 4 samples from second eg, etc
    def get_continuous_batch(example_lens: Tuple[int, ...]):

        indices = []
        for i, x in enumerate(example_lens):
            c = last_taken.get(i, 0)
            indices.append((c, c + x))
            last_taken[i] = (c + x) % full_length
            exhausted[i] = last_taken[i] == 0

        return (torch.concat([x[i, s:e] for i, (s, e) in enumerate(indices)
                              ]).unsqueeze(0) for x in (dt, X, B, C))

    # internal function that maps "n" to the appropriate right boundary
    # value when forming continuous batches from examples of length given
    # by "full_length".
    # - e.g., when n > full_length, returns n % full_length
    #         when n == full_length, returns full_length
    def end_boundary(n: int):
        return n - ((n - 1) // full_length) * full_length

    IND_E = None
    for spec in example_lens_by_batch:

        # get the (maybe partial) example seen in this cont batch
        dt2, X2, B2, C2 = get_continuous_batch(spec)

        # get the metadata
        cu_seqlens = torch.tensor((0, ) + spec, device=device).cumsum(dim=0)
        sed_idx = torch.zeros(cu_seqlens[-1],
                              dtype=torch.int32,
                              device=cu_seqlens.device)
        for i, (srt, end) in enumerate(zip(
                cu_seqlens,
                cu_seqlens[1:],
        )):
            sed_idx[srt:end] = i

        # for cont batch
        if IND_E is None:
            IND_S = [0 for _ in range(len(spec))]
        else:
            IND_S = [x % full_length for x in IND_E]
        IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

        yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
               cu_seqlens, sed_idx.unsqueeze(0), (A, dt2, X2, B2, C2))
```
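To make the chunked-prefill bookkeeping concrete, the sketch below (not part of the PR) shows the metadata the generator builds for one continuous batch: `cu_seqlens` holds cumulative example boundaries and the sequence-index tensor tags every token with the example it belongs to. The `end_boundary` helper simply wraps a running index back into [1, full_length] once an example cycles, e.g. `end_boundary(64) == 64` and `end_boundary(67) == 3` when `full_length` is 64.

```python
# Illustrative only: metadata for one continuous batch holding two
# (possibly partial) examples of lengths 8 and 4.
import torch

spec = (8, 4)
cu_seqlens = torch.tensor((0, ) + spec).cumsum(dim=0)
# cu_seqlens -> tensor([ 0,  8, 12])
seq_idx = torch.zeros(cu_seqlens[-1], dtype=torch.int32)
for i, (srt, end) in enumerate(zip(cu_seqlens, cu_seqlens[1:])):
    seq_idx[srt:end] = i
# seq_idx -> tensor([0, 0, 0, 0, 0, 0, 0, 0, 1, 1, 1, 1], dtype=torch.int32)
```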
```python
@pytest.mark.parametrize("itype",
                         [torch.float32, torch.float16, torch.bfloat16])
@pytest.mark.parametrize("n_heads", [3, 4, 11, 16, 32])
@pytest.mark.parametrize("d_head", [5, 8, 19, 32, 128])
@pytest.mark.parametrize("seq_len_chunk_size", [(119, 17), (128, 32)])
def test_mamba_chunk_scan_single_example(d_head, n_heads, seq_len_chunk_size,
                                         itype):

    # this tests the kernels on a single example (no batching)

    # set seed
    batch_size = 1  # batch_size
    # ssd_minimal_discrete requires chunk_size to divide seqlen
    # - this is only required for generating the reference seqs,
    #   it is not an operational limitation.
    seqlen, chunk_size = seq_len_chunk_size

    A, dt, X, B, C = generate_random_inputs(batch_size, seqlen, n_heads,
                                            d_head, itype)

    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1), A * dt,
                                                  B, C, chunk_size)

    Y, final_state = mamba_chunk_scan_combined(X,
                                               dt,
                                               A,
                                               B,
                                               C,
                                               chunk_size,
                                               D=None,
                                               return_final_states=True)

    # just test the last in sequence
    assert torch.allclose(Y[:, -1], Y_min[:, -1], atol=1e-3, rtol=1e-3)

    # just test the last head
    # NOTE, in the kernel we always cast states to fp32
    assert torch.allclose(final_state[:, -1],
                          final_state_min[:, -1].to(torch.float32),
                          atol=1e-3,
                          rtol=1e-3)
```
```python
@pytest.mark.parametrize("itype", [torch.float32, torch.float16])
@pytest.mark.parametrize("n_heads", [4, 8, 13])
@pytest.mark.parametrize("d_head", [5, 16, 21, 32])
@pytest.mark.parametrize(
    "seq_len_chunk_size_cases",
    [

        # small-ish chunk_size (8)
        (64, 8, 2, [(64, 32), (64, 32)]),
        (64, 8, 2, [(32, 32), (32, 32), (32, 32)]),
        (64, 8, 2, [(8, 8), (8, 8), (8, 8)]),  # chunk size boundary
        (64, 8, 2, [(4, 4), (4, 4), (4, 4),
                    (4, 4)]),  # chunk_size larger than cont batches
        (64, 8, 5, [
            (64, 32, 16, 8, 8),
            (8, 16, 32, 16, 8),
            (8, 8, 16, 32, 16),
        ]),  # more examples with varied lengths

        # odd chunk_size
        (64, 29, 2, [(11, 4), (13, 23), (19, 22),
                     (21, 15)]),  # irregular sizes

        # large-ish chunk_size (256)
        (64, 256, 1, [(5, ), (1, ), (1, ),
                      (1, )]),  # irregular sizes with small sequences
        (64, 256, 2, [(5, 30), (1, 2), (1, 2),
                      (1, 2)]),  # irregular sizes with small sequences
    ])
def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
                                     itype):

    # this tests multiple examples in a continuous batch
    # (i.e. chunked prefill)

    seqlen, chunk_size, num_examples, cases = seq_len_chunk_size_cases

    # hold state during the cutting process so we know if an
    # example has been exhausted and needs to cycle
    last_taken: Dict = {}  # map: eg -> pointer to last taken sample
    exhausted: Dict = {}  # map: eg -> boolean indicating example is exhausted

    states = None
    for Y_min, cu_seqlens, sed_idx, (A, dt, X, B,
                                     C) in generate_continous_batched_examples(
                                         cases, num_examples, seqlen,
                                         last_taken, exhausted, n_heads,
                                         d_head, itype):

        Y, new_states = mamba_chunk_scan_combined(
            X,
            dt,
            A,
            B,
            C,
            chunk_size,
            D=None,
            cu_seqlens=cu_seqlens,
            seq_idx=sed_idx,
            return_varlen_states=True,
            initial_states=states,
        )

        # just test the last in sequence
        for i in range(num_examples):

            # just test one dim and dstate
            Y_eg = Y[0, cu_seqlens[i]:cu_seqlens[i + 1], 0, 0]
            Y_min_eg = Y_min[i][:, 0, 0]
            assert torch.allclose(Y_eg, Y_min_eg, atol=1e-3, rtol=1e-3)

        # update states
        states = new_states
        for i, clear in exhausted.items():
            if clear:
                states[i].fill_(0.)
                exhausted[i] = False
```
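Assuming the file is placed at tests/kernels/test_mamba_ssm_ssd.py (the path is an assumption; the diff above only shows the file contents), the new kernel tests can be run on a CUDA machine with pytest, for example:

```python
# Illustrative only: select and run the two chunk-scan tests above.
# The path below is assumed, not confirmed by the diff.
import pytest

pytest.main([
    "tests/kernels/test_mamba_ssm_ssd.py",
    "-k", "test_mamba_chunk_scan",
    "-v",
])
```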