Commit d5a217e

tomeras91 and gemini-code-assist[bot] authored and committed
[Bugfix] Fix mamba2 prefill chunking (vllm-project#23279)
Signed-off-by: Tomer Asida <57313761+tomeras91@users.noreply.github.com>
Signed-off-by: tomeras91 <57313761+tomeras91@users.noreply.github.com>
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
Signed-off-by: xuebwang-amd <xuebwang@amd.com>
1 parent 7687771 commit d5a217e

File tree

5 files changed: +348 −34 lines changed


tests/kernels/mamba/test_mamba_ssm_ssd.py

Lines changed: 225 additions & 8 deletions
@@ -115,21 +115,27 @@ def generate_continuous_batched_examples(example_lens_by_batch,
                                          n_heads,
                                          d_head,
                                          itype,
-                                         device='cuda'):
+                                         device='cuda',
+                                         return_naive_ref=True):
 
     # this function generates a random examples of certain length
     # and then cut according to "example_lens_by_batch" and feed
-    # them in continuous batches to the kernels
+    # them in continuous batches to the kernels.
+    # If return_naive_ref=True, the naive torch implementation
+    # ssd_minimal_discrete will be used to compute and return
+    # reference output.
 
     # generate the full-length example
     A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
                                             d_head, itype)
 
-    Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
-                                                  A * dt,
-                                                  B,
-                                                  C,
-                                                  block_len=full_length // 4)
+    if return_naive_ref:
+        Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
+                                                      A * dt,
+                                                      B,
+                                                      C,
+                                                      block_len=full_length //
+                                                      4)
 
     # internal function that outputs a cont batch of examples
     # given a tuple of lengths for each example in the batch
@@ -179,7 +185,8 @@ def end_boundary(n: int):
         IND_S = [x % full_length for x in IND_E]
         IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]
 
-        yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
+        yield ([Y_min[s, IND_S[s]:IND_E[s]]
+                for s in range(num_examples)] if return_naive_ref else None,
                cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))

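For readers unfamiliar with the packed varlen layout this generator yields alongside the (now optional) reference chunks: cu_seqlens gives cumulative sequence boundaries in the packed token dimension and seq_idx labels each packed token with its sequence. A small illustrative sketch with made-up lengths (not values from this commit):

import torch

# Hypothetical toy lengths, for illustration only.
seqlens = torch.tensor([3, 2], dtype=torch.int32)

# cu_seqlens marks where each sequence starts/ends in the packed dimension.
cu_seqlens = torch.cat(
    [torch.zeros(1, dtype=torch.int32),
     torch.cumsum(seqlens, dim=0, dtype=torch.int32)])  # tensor([0, 3, 5])

# seq_idx labels every packed token with the sequence it belongs to.
seq_idx = torch.repeat_interleave(torch.arange(len(seqlens)),
                                  seqlens.to(torch.int64))  # tensor([0, 0, 0, 1, 1])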
@@ -324,3 +331,213 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
         if clear:
             states[i].fill_(0.)
             exhausted[i] = False
+
+
+@pytest.mark.parametrize("chunk_size", [8, 256])
+@pytest.mark.parametrize("seqlens", [
+    (16, 2, 8, 13),
+    (270, 88, 212, 203),
+    (16, 20),
+])
+def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
+
+    # This test verifies the correctness of the chunked prefill implementation
+    # in the mamba2 ssd kernels, by comparing concatenation (in the sequence
+    # dimension) of chunked results with the full sequence result.
+    # It is different from test_mamba_chunk_scan_cont_batch by:
+    # 1. Not using the naive torch implementation (ssd_minimal_discrete) to get
+    #    reference outputs. Instead, it compares chunked kernel outputs to full
+    #    sequence kernel outputs. This is the most straightforward way to
+    #    assert chunked prefill correctness.
+    # 2. It focuses on cases where sequences change in the middle of mamba
+    #    chunks, and not necessarily on chunk boundaries.
+
+    max_seqlen = max(seqlens)
+    # This test can have larger error for longer sequences
+    if max_seqlen > 256:
+        atol, rtol = 1e-2, 5e-3
+    else:
+        atol, rtol = 5e-3, 5e-3
+
+    num_sequences = len(seqlens)
+    n_heads = 16
+    d_head = 64
+    itype = torch.float32
+
+    # hold state during the cutting process so we know if an
+    # example has been exhausted and needs to cycle
+    last_taken: dict = {}  # map: eg -> pointer to last taken sample
+    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
+    _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
+        generate_continuous_batched_examples([seqlens],
+                                             num_sequences,
+                                             max_seqlen,
+                                             last_taken,
+                                             exhausted,
+                                             n_heads,
+                                             d_head,
+                                             itype,
+                                             return_naive_ref=False))
+    seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
+    device = X.device
+
+    ## full seqlen computation
+    chunk_indices, chunk_offsets = \
+        _query_start_loc_to_chunk_indices_offsets(
+            cu_seqlens, chunk_size, cu_seqlens[-1])
+    Y_ref = torch.empty_like(X)
+    state_ref = mamba_chunk_scan_combined(
+        X,
+        dt,
+        A,
+        B,
+        C,
+        chunk_size,
+        D=None,
+        cu_seqlens=cu_seqlens,
+        seq_idx=seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+        return_varlen_states=True,
+        initial_states=None,
+        out=Y_ref,
+    )
+
+    ## chunked seqlen computation
+    # first chunk
+    chunked_seqlens = seqlens // 2
+    chunked_cu_seqlens = torch.cat([
+        torch.tensor([0], device=device),
+        torch.cumsum(chunked_seqlens, dim=0)
+    ],
+                                   dim=0)
+    chunked_seq_idx = torch.repeat_interleave(
+        torch.arange(len(chunked_seqlens), device=device),
+        chunked_seqlens,
+        output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
+    chunked_input_seq_len = chunked_cu_seqlens[-1]
+    X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
+    dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
+    B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
+    C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
+    for i in range(num_sequences):
+        # fmt: off
+        chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...]  # noqa: E501
+
+        X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i)  # noqa: E501
+        dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i)  # noqa: E501
+        B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i)  # noqa: E501
+        C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i)  # noqa: E501
+        # fmt: on
+
+    chunk_indices, chunk_offsets = \
+        _query_start_loc_to_chunk_indices_offsets(
+            chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
+    Y_partial = torch.empty_like(X_chunked)
+    partial_state = mamba_chunk_scan_combined(
+        X_chunked,
+        dt_chunked,
+        A,
+        B_chunked,
+        C_chunked,
+        chunk_size,
+        D=None,
+        cu_seqlens=chunked_cu_seqlens,
+        seq_idx=chunked_seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+        return_varlen_states=True,
+        initial_states=None,
+        out=Y_partial,
+    )
+
+    # remaining chunk
+    remaining_chunked_seqlens = seqlens - chunked_seqlens
+    remaining_chunked_cu_seqlens = torch.cat([
+        torch.tensor([0], device=device),
+        torch.cumsum(remaining_chunked_seqlens, dim=0)
+    ],
+                                             dim=0)
+    remaining_chunked_seq_idx = torch.repeat_interleave(
+        torch.arange(len(remaining_chunked_seqlens), device=device),
+        remaining_chunked_seqlens,
+        output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
+            torch.int32)
+    remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
+    # fmt: off
+    remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
+    remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...]  # noqa: E501
+    for i in range(num_sequences):
+        remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...]  # noqa: E501
+
+        remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i)  # noqa: E501
+        remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i)  # noqa: E501
+        remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i)  # noqa: E501
+        remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i)  # noqa: E501
+
+    # assert input chunking is correct
+    concat_chunk_f = lambda pt1, pt2, i: torch.cat([
+        pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
+        pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
+    ],
+                                                   dim=1)
+    concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1)  # noqa: E501
+    # fmt: on
+
+    assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
+    assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
+    assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
+    assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)
+
+    chunk_indices, chunk_offsets = \
+        _query_start_loc_to_chunk_indices_offsets(
+            remaining_chunked_cu_seqlens,
+            chunk_size,
+            remaining_chunked_cu_seqlens[-1])
+
+    Y_chunked = torch.empty_like(remaining_X_chunked)
+    state_chunked = mamba_chunk_scan_combined(
+        remaining_X_chunked,
+        remaining_dt_chunked,
+        A,
+        remaining_B_chunked,
+        remaining_C_chunked,
+        chunk_size,
+        D=None,
+        cu_seqlens=remaining_chunked_cu_seqlens,
+        seq_idx=remaining_chunked_seq_idx,
+        chunk_indices=chunk_indices,
+        chunk_offsets=chunk_offsets,
+        return_varlen_states=True,
+        initial_states=partial_state,
+        out=Y_chunked,
+    )
+    Y = concat_batch_f(Y_partial, Y_chunked)
+
+    # kernel chunked is same as kernel overall
+    for i in range(num_sequences):
+        Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
+        Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
+        torch.testing.assert_close(
+            Y_seq[:, :chunked_seqlens[i], ...],
+            Y_ref_seq[:, :chunked_seqlens[i], ...],
+            atol=atol,
+            rtol=rtol,
+            msg=lambda x: f"seq{i} output part1 " + x)  # noqa: B023
+        torch.testing.assert_close(
+            Y_seq[:, chunked_seqlens[i]:, ...],
+            Y_ref_seq[:, chunked_seqlens[i]:, ...],
+            atol=atol,
+            rtol=rtol,
+            msg=lambda x: f"seq{i} output part2 " + x)  # noqa: B023
+
+        state_seq = state_chunked[i]
+        state_seq_ref = state_ref[i]
+        torch.testing.assert_close(
+            state_seq,
+            state_seq_ref,
+            atol=atol,
+            rtol=rtol,
+            msg=lambda x: f"seq{i} state " + x)  # noqa: B023
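The new test leans on the defining property of the SSM recurrence: splitting a prompt at an arbitrary token and seeding the second pass with the carried state (initial_states=partial_state above) must reproduce the single-pass result. A minimal sketch of that property with a scalar toy recurrence (illustrative only, not the vLLM kernels):

import torch

def toy_ssm_scan(x, a, h0=None):
    # Naive per-token linear recurrence h_t = a_t * h_{t-1} + x_t, y_t = h_t.
    # Stands in for the mamba2 state update; not the actual Triton kernel.
    h = torch.zeros(()) if h0 is None else h0
    ys = []
    for t in range(x.shape[0]):
        h = a[t] * h + x[t]
        ys.append(h)
    return torch.stack(ys), h

torch.manual_seed(0)
x, a = torch.randn(16), torch.rand(16)

# Single pass over the whole prompt.
y_full, state_full = toy_ssm_scan(x, a)

# Chunked prefill: process the first 7 tokens, then resume from that state.
y1, mid_state = toy_ssm_scan(x[:7], a[:7])
y2, state_chunked = toy_ssm_scan(x[7:], a[7:], h0=mid_state)

assert torch.allclose(torch.cat([y1, y2]), y_full)
assert torch.allclose(state_chunked, state_full)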

vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py

Lines changed: 3 additions & 0 deletions
@@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel(
 
     # get the cs at the offset boundary
    # - c_off == 0 is a passthrough
+    # - We need dA_cs at the boundary, defined by c_off - no need
+    #   to increase pointer by pid_m (it is a constant offset,
+    #   i.e. the same for all blocks)
     dA_cs_m_boundary = tl.load(
         dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
         mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
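In plain-Python terms, the new comment describes a single scalar load per chunk: when the current sequence starts c_off tokens into a mamba chunk, the cumulative dA just before that start is used to rebase the within-chunk decay. A rough sketch of the intent (not the Triton kernel; names are illustrative):

def boundary_dA_cs(dA_cumsum_chunk, c_off):
    # dA_cumsum_chunk: cumulative dA over one mamba chunk, shape (chunk_size,)
    # c_off: offset of the sequence start inside this chunk (0 = chunk-aligned)
    chunk_size = dA_cumsum_chunk.shape[0]
    if 0 < c_off < chunk_size:
        return dA_cumsum_chunk[c_off - 1]
    return 0.0  # c_off == 0 is a passthrough: nothing to subtract

# Decay from the sequence start (rather than the chunk start) to token t is then
#     exp(dA_cumsum_chunk[t] - boundary_dA_cs(dA_cumsum_chunk, c_off))
# which is why the boundary value is one constant per chunk and needs no pid_m offset.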

vllm/model_executor/layers/mamba/ops/ssd_combined.py

Lines changed: 6 additions & 3 deletions
@@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x,
     # 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
     # (middle term of factorization of off-diag blocks; A terms)
     # - for handling chunked prefill, this requires i) initial_states
-    # ii) seq_idx and iii) is_cont_batched to be all specified.
+    # ii) seq_idx iii) is_cont_batched and (iv) chunk_offsets to be all specified.
     # - When a new seq_idx is detected, we will stop passing the prev_state
     # and switch accordingly to the init_state corresponding to the new seq_idx.
+    # - We will also make sure that the dA_cumsum is taken only from the start of the
+    # sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
     # - this will ensure that states will be updated with the rightmost flushed seq_idx
     # of the previous chunk. This implies that the first chunk of states is either 0
     # or equal to init_states of the first example.
     states, final_states = _state_passing_fwd(
         rearrange(states, "... p n -> ... (p n)"),
-        dA_cumsum[:, :, :, -1],
+        dA_cumsum,
         initial_states=rearrange(initial_states, "... p n -> ... (p n)")
         if initial_states is not None else None,
         seq_idx=seq_idx,
         chunk_size=chunk_size,
         out_dtype=state_dtype if state_dtype is not None else C.dtype,
-        is_cont_batched=cu_seqlens is not None)
+        is_cont_batched=cu_seqlens is not None,
+        chunk_offsets=chunk_offsets)
     states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
                             for t in [states, final_states])
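The functional change here: _state_passing_fwd previously received only the chunk-boundary values dA_cumsum[:, :, :, -1], but it now gets the full dA_cumsum plus chunk_offsets, so that when a sequence begins in the middle of a chunk its inter-chunk decay can be measured from the sequence start instead of the chunk start. A rough sketch of that correction (illustrative pseudo-logic, not the Triton kernel):

import torch

def chunk_state_decay(dA_cumsum_chunk, seq_start_offset):
    # dA_cumsum_chunk: cumulative dA within one chunk, shape (chunk_size,)
    # seq_start_offset: where the current sequence starts inside this chunk
    total = dA_cumsum_chunk[-1]  # decay measured from the chunk start
    if seq_start_offset > 0:
        # Rebase to the sequence start; this is what passing the full tensor
        # and chunk_offsets makes possible (the old code only saw `total`).
        total = total - dA_cumsum_chunk[seq_start_offset - 1]
    return torch.exp(total)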
