233 changes: 225 additions & 8 deletions tests/kernels/mamba/test_mamba_ssm_ssd.py
@@ -115,21 +115,27 @@ def generate_continuous_batched_examples(example_lens_by_batch,
n_heads,
d_head,
itype,
device='cuda'):
device='cuda',
return_naive_ref=True):

# this function generates a random example of a certain length
# and then cuts it according to "example_lens_by_batch" and feeds
# them in continuous batches to the kernels
# them in continuous batches to the kernels.
# If return_naive_ref=True, the naive torch implementation
# ssd_minimal_discrete will be used to compute and return a
# reference output.

# generate the full-length example
A, dt, X, B, C = generate_random_inputs(num_examples, full_length, n_heads,
d_head, itype)

Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
A * dt,
B,
C,
block_len=full_length // 4)
if return_naive_ref:
Y_min, final_state_min = ssd_minimal_discrete(X * dt.unsqueeze(-1),
A * dt,
B,
C,
block_len=full_length //
4)

# internal function that outputs a cont batch of examples
# given a tuple of lengths for each example in the batch
@@ -179,7 +185,8 @@ def end_boundary(n: int):
IND_S = [x % full_length for x in IND_E]
IND_E = [end_boundary(x + y) for x, y in zip(IND_S, spec)]

yield ([Y_min[s, IND_S[s]:IND_E[s]] for s in range(num_examples)],
yield ([Y_min[s, IND_S[s]:IND_E[s]]
for s in range(num_examples)] if return_naive_ref else None,
cu_seqlens, seq_idx.unsqueeze(0), (A, dt2, X2, B2, C2))


@@ -324,3 +331,213 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
if clear:
states[i].fill_(0.)
exhausted[i] = False


@pytest.mark.parametrize("chunk_size", [8, 256])
Contributor:
This test looks like it's replicating test_mamba_chunk_scan_cont_batch; what is the key difference?

Contributor Author:
The key difference is that this test verifies that prefill chunking works as expected without using the pytorch reference implementation. Instead, it compares the kernel output without prefill chunking to the concatenated outputs produced with prefill chunking. This is the most straightforward way to verify that prefill chunking works as expected.

Another crucial difference from test_mamba_chunk_scan_cont_batch is that this test covers cases where the sequence length is not a multiple of the mamba chunk size, i.e. cases where a sequence changes in the middle of a mamba chunk. These are the cases that currently fail on main and require the fixes in this PR. They are also not supported in the pytorch implementation (see other discussion), so they can't easily be added to test_mamba_chunk_scan_cont_batch, which compares kernel results with the reference pytorch implementation.
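
To illustrate the property this comparison relies on, here is a minimal sketch that uses a toy scalar recurrence in place of the real kernel (`toy_scan` and everything else in this snippet are illustrative names only, not part of the test or of mamba_chunk_scan_combined): scanning the full sequence must give the same outputs and final state as scanning a prefix and then resuming from the prefix's final state.

```python
import torch


def toy_scan(x, a, h0=None):
    # h_t = a * h_{t-1} + x_t; returns per-step outputs and the final state
    h = torch.zeros(()) if h0 is None else h0
    ys = []
    for t in range(x.shape[0]):
        h = a * h + x[t]
        ys.append(h)
    return torch.stack(ys), h


x = torch.randn(13)  # length deliberately not a multiple of a power-of-two chunk size
a = 0.9

y_full, state_full = toy_scan(x, a)

# "prefill chunking": scan the first half, then resume from its final state
split = x.shape[0] // 2
y_first, s_first = toy_scan(x[:split], a)
y_rest, s_rest = toy_scan(x[split:], a, h0=s_first)

assert torch.allclose(torch.cat([y_first, y_rest]), y_full)
assert torch.allclose(s_rest, state_full)
```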

Contributor:
I see, thanks for the explanation. In this case I suggest adding some documentation to explain how test_mamba_chunk_scan_cont_batch_prefill_chunking differs from the previous test, since this test is a little long and it's hard to understand at a quick glance.

Contributor Author:
makes sense. Done

@pytest.mark.parametrize("seqlens", [
(16, 2, 8, 13),
(270, 88, 212, 203),
(16, 20),
])
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):

# This test verifies the correctness of the chunked prefill implementation
# in the mamba2 ssd kernels, by comparing concatenation (in the sequence
# dimension) of chunked results with the full sequence result.
# It differs from test_mamba_chunk_scan_cont_batch in two ways:
# 1. It does not use the naive torch implementation (ssd_minimal_discrete) to get
# reference outputs. Instead, it compares chunked kernel outputs to full
# sequence kernel outputs. This is the most straightforward way to
# assert chunked prefill correctness.
# 2. It focuses on cases where sequences change in the middle of mamba
# chunks, and not necessarily on chunk boundaries.

max_seqlen = max(seqlens)
# This test can have larger error for longer sequences
if max_seqlen > 256:
atol, rtol = 1e-2, 5e-3
else:
atol, rtol = 5e-3, 5e-3

num_sequences = len(seqlens)
n_heads = 16
d_head = 64
itype = torch.float32

# hold state during the cutting process so we know if an
# example has been exhausted and needs to cycle
last_taken: dict = {} # map: eg -> pointer to last taken sample
exhausted: dict = {} # map: eg -> boolean indicating example is exhausted
_, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
generate_continuous_batched_examples([seqlens],
num_sequences,
max_seqlen,
last_taken,
exhausted,
n_heads,
d_head,
itype,
return_naive_ref=False))
seqlens = torch.tensor(seqlens, dtype=torch.int32, device=X.device)
device = X.device

## full seqlen computation
chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
cu_seqlens, chunk_size, cu_seqlens[-1])
Y_ref = torch.empty_like(X)
state_ref = mamba_chunk_scan_combined(
X,
dt,
A,
B,
C,
chunk_size,
D=None,
cu_seqlens=cu_seqlens,
seq_idx=seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_ref,
)

## chunked seqlen computation
# first chunk
chunked_seqlens = seqlens // 2
chunked_cu_seqlens = torch.cat([
torch.tensor([0], device=device),
torch.cumsum(chunked_seqlens, dim=0)
],
dim=0)
chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(chunked_seqlens), device=device),
chunked_seqlens,
output_size=chunked_cu_seqlens[-1]).unsqueeze(0).to(torch.int32)
chunked_input_seq_len = chunked_cu_seqlens[-1]
X_chunked = torch.zeros_like(X)[:, :chunked_input_seq_len, ...]
dt_chunked = torch.zeros_like(dt)[:, :chunked_input_seq_len, ...]
B_chunked = torch.zeros_like(B)[:, :chunked_input_seq_len, ...]
C_chunked = torch.zeros_like(C)[:, :chunked_input_seq_len, ...]
for i in range(num_sequences):
# fmt: off
chunk_f = lambda x, i: x[:, cu_seqlens[i]:cu_seqlens[i] + chunked_seqlens[i], ...] # noqa: E501

X_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(X, i) # noqa: E501
dt_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(dt, i) # noqa: E501
B_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(B, i) # noqa: E501
C_chunked[:, chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1], ...] = chunk_f(C, i) # noqa: E501
# fmt: on

chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
chunked_cu_seqlens, chunk_size, chunked_cu_seqlens[-1])
Y_partial = torch.empty_like(X_chunked)
partial_state = mamba_chunk_scan_combined(
X_chunked,
dt_chunked,
A,
B_chunked,
C_chunked,
chunk_size,
D=None,
cu_seqlens=chunked_cu_seqlens,
seq_idx=chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=None,
out=Y_partial,
)

# remaining chunk
remaining_chunked_seqlens = seqlens - chunked_seqlens
remaining_chunked_cu_seqlens = torch.cat([
torch.tensor([0], device=device),
torch.cumsum(remaining_chunked_seqlens, dim=0)
],
dim=0)
remaining_chunked_seq_idx = torch.repeat_interleave(
torch.arange(len(remaining_chunked_seqlens), device=device),
remaining_chunked_seqlens,
output_size=remaining_chunked_cu_seqlens[-1]).unsqueeze(0).to(
torch.int32)
remaining_chunked_input_seq_len = remaining_chunked_cu_seqlens[-1]
# fmt: off
remaining_X_chunked = torch.zeros_like(X)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_dt_chunked = torch.zeros_like(dt)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_B_chunked = torch.zeros_like(B)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
remaining_C_chunked = torch.zeros_like(C)[:, :remaining_chunked_input_seq_len, ...] # noqa: E501
for i in range(num_sequences):
remaining_chunk_f = lambda x, i: x[:, cu_seqlens[i] + chunked_seqlens[i]:cu_seqlens[i+1], ...] # noqa: E501

remaining_X_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(X, i) # noqa: E501
remaining_dt_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(dt, i) # noqa: E501
remaining_B_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(B, i) # noqa: E501
remaining_C_chunked[:, remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1], ...] = remaining_chunk_f(C, i) # noqa: E501

# assert input chunking is correct
concat_chunk_f = lambda pt1, pt2, i: torch.cat([
pt1[:,chunked_cu_seqlens[i]:chunked_cu_seqlens[i+1],...],
pt2[:,remaining_chunked_cu_seqlens[i]:remaining_chunked_cu_seqlens[i+1],...],
],
dim=1)
concat_batch_f = lambda pt1, pt2: torch.cat([concat_chunk_f(pt1, pt2, i) for i in range(num_sequences)], dim=1) # noqa: E501
# fmt: on

assert concat_batch_f(X_chunked, remaining_X_chunked).equal(X)
assert concat_batch_f(dt_chunked, remaining_dt_chunked).equal(dt)
assert concat_batch_f(B_chunked, remaining_B_chunked).equal(B)
assert concat_batch_f(C_chunked, remaining_C_chunked).equal(C)

chunk_indices, chunk_offsets = \
_query_start_loc_to_chunk_indices_offsets(
remaining_chunked_cu_seqlens,
chunk_size,
remaining_chunked_cu_seqlens[-1])

Y_chunked = torch.empty_like(remaining_X_chunked)
state_chunked = mamba_chunk_scan_combined(
remaining_X_chunked,
remaining_dt_chunked,
A,
remaining_B_chunked,
remaining_C_chunked,
chunk_size,
D=None,
cu_seqlens=remaining_chunked_cu_seqlens,
seq_idx=remaining_chunked_seq_idx,
chunk_indices=chunk_indices,
chunk_offsets=chunk_offsets,
return_varlen_states=True,
initial_states=partial_state,
out=Y_chunked,
)
Y = concat_batch_f(Y_partial, Y_chunked)

# kernel chunked is same as kernel overall
for i in range(num_sequences):
Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
torch.testing.assert_close(
Y_seq[:, :chunked_seqlens[i], ...],
Y_ref_seq[:, :chunked_seqlens[i], ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part1 " + x) # noqa: B023
torch.testing.assert_close(
Y_seq[:, chunked_seqlens[i]:, ...],
Y_ref_seq[:, chunked_seqlens[i]:, ...],
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} output part2 " + x) # noqa: B023

state_seq = state_chunked[i]
state_seq_ref = state_ref[i]
torch.testing.assert_close(
state_seq,
state_seq_ref,
atol=atol,
rtol=rtol,
msg=lambda x: f"seq{i} state " + x) # noqa: B023
3 changes: 3 additions & 0 deletions vllm/model_executor/layers/mamba/ops/ssd_chunk_scan.py
@@ -289,6 +289,9 @@ def _chunk_scan_fwd_kernel(

# get the cs at the offset boundary
# - c_off == 0 is a passthrough
# - We need dA_cs at the boundary defined by c_off; there is no need
# to advance the pointer by pid_m (it is a constant offset,
# i.e. the same for all blocks)
dA_cs_m_boundary = tl.load(
dA_cumsum_ptr + (c_off - 1) * stride_dA_cs_csize,
mask=(((c_off - 1) > -1) and ((c_off) < chunk_size)),
9 changes: 6 additions & 3 deletions vllm/model_executor/layers/mamba/ops/ssd_combined.py
@@ -106,21 +106,24 @@ def _mamba_chunk_scan_combined_fwd(x,
# 3. Compute the inter-chunk SSM recurrence; produces correct SSM states at chunk boundaries
# (middle term of factorization of off-diag blocks; A terms)
# - for handling chunked prefill, this requires i) initial_states
# ii) seq_idx and iii) is_cont_batched to be all specified.
# ii) seq_idx, iii) is_cont_batched and iv) chunk_offsets to all be specified.
# - When a new seq_idx is detected, we will stop passing the prev_state
# and switch accordingly to the init_state corresponding to the new seq_idx.
# - We will also make sure that the dA_cumsum is taken only from the start of the
# sequence (hence we need the full dA_cumsum tensor and not just the values at chunk boundaries)
# - this will ensure that states will be updated with the rightmost flushed seq_idx
# of the previous chunk. This implies that the first chunk of states is either 0
# or equal to init_states of the first example.
states, final_states = _state_passing_fwd(
rearrange(states, "... p n -> ... (p n)"),
dA_cumsum[:, :, :, -1],
dA_cumsum,
initial_states=rearrange(initial_states, "... p n -> ... (p n)")
if initial_states is not None else None,
seq_idx=seq_idx,
chunk_size=chunk_size,
out_dtype=state_dtype if state_dtype is not None else C.dtype,
is_cont_batched=cu_seqlens is not None)
is_cont_batched=cu_seqlens is not None,
chunk_offsets=chunk_offsets)
states, final_states = (rearrange(t, "... (p n) -> ... p n", n=dstate)
for t in [states, final_states])
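
To make the inter-chunk recurrence described in the comment above concrete, here is a minimal plain-torch sketch for the simplified case where every sequence starts exactly on a mamba-chunk boundary; handling sequences that start mid-chunk is what chunk_offsets and the full dA_cumsum add. `state_passing_toy` and its arguments are illustrative names and do not mirror the kernel's actual signature.

```python
import torch


def state_passing_toy(chunk_states, chunk_decay, chunk_seq_idx, init_states):
    # chunk_states[c]:  contribution of chunk c to the SSM state
    # chunk_decay[c]:   total decay, exp(dA_cumsum), accumulated over chunk c
    # chunk_seq_idx[c]: index of the sequence that chunk c belongs to
    # init_states[s]:   initial state of sequence s
    entering = []  # state entering each chunk
    carry = init_states[0]
    prev_seq = 0
    for c in range(len(chunk_states)):
        if chunk_seq_idx[c] != prev_seq:
            # a new sequence starts at this chunk: stop passing the previous
            # state and switch to the new sequence's initial state
            carry = init_states[chunk_seq_idx[c]]
            prev_seq = chunk_seq_idx[c]
        entering.append(carry)
        carry = chunk_decay[c] * carry + chunk_states[c]
    return torch.stack(entering), carry  # per-chunk entering states, final state


# toy usage: two sequences spanning 2 and 1 chunks, zero initial states
states = torch.randn(3, 4)
decay = torch.rand(3)
print(state_passing_toy(states, decay, [0, 0, 1], torch.zeros(2, 4)))
```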
