@@ -115,21 +115,27 @@ def generate_continuous_batched_examples(example_lens_by_batch,
115115 n_heads ,
116116 d_head ,
117117 itype ,
118- device = 'cuda' ):
118+ device = 'cuda' ,
119+ return_naive_ref = True ):
119120
120121 # this function generates a random examples of certain length
121122 # and then cut according to "example_lens_by_batch" and feed
122- # them in continuous batches to the kernels
123+ # them in continuous batches to the kernels.
124+ # If return_naive_ref=True, the naive torch implementation
125+ # ssd_minimal_discrete will be used to compute and return
126+ # reference output.
123127
124128 # generate the full-length example
125129 A , dt , X , B , C = generate_random_inputs (num_examples , full_length , n_heads ,
126130 d_head , itype )
127131
128- Y_min , final_state_min = ssd_minimal_discrete (X * dt .unsqueeze (- 1 ),
129- A * dt ,
130- B ,
131- C ,
132- block_len = full_length // 4 )
132+ if return_naive_ref :
133+ Y_min , final_state_min = ssd_minimal_discrete (X * dt .unsqueeze (- 1 ),
134+ A * dt ,
135+ B ,
136+ C ,
137+ block_len = full_length //
138+ 4 )
133139
134140 # internal function that outputs a cont batch of examples
135141 # given a tuple of lengths for each example in the batch
@@ -179,7 +185,8 @@ def end_boundary(n: int):
179185 IND_S = [x % full_length for x in IND_E ]
180186 IND_E = [end_boundary (x + y ) for x , y in zip (IND_S , spec )]
181187
182- yield ([Y_min [s , IND_S [s ]:IND_E [s ]] for s in range (num_examples )],
188+ yield ([Y_min [s , IND_S [s ]:IND_E [s ]]
189+ for s in range (num_examples )] if return_naive_ref else None ,
183190 cu_seqlens , seq_idx .unsqueeze (0 ), (A , dt2 , X2 , B2 , C2 ))
184191
185192
@@ -324,3 +331,213 @@ def test_mamba_chunk_scan_cont_batch(d_head, n_heads, seq_len_chunk_size_cases,
324331 if clear :
325332 states [i ].fill_ (0. )
326333 exhausted [i ] = False
334+
335+
def _packed_layout(lens, device):
    """Build packing metadata for sequences laid out back to back.

    Args:
        lens: 1-D integer tensor holding the length of each sequence.
        device: device on which the metadata tensors are created.

    Returns:
        A ``(cu_seqlens, seq_idx)`` pair: ``cu_seqlens`` is the cumulative
        sum of ``lens`` with a leading 0, and ``seq_idx`` is an int32 tensor
        with a leading unit dimension that maps every packed position to the
        index of the sequence it belongs to.
    """
    cu = torch.cat(
        [torch.tensor([0], device=device),
         torch.cumsum(lens, dim=0)],
        dim=0)
    seq_idx = torch.repeat_interleave(
        torch.arange(len(lens), device=device),
        lens,
        output_size=cu[-1]).unsqueeze(0).to(torch.int32)
    return cu, seq_idx


def _split_packed(t, cu_seqlens, head_lens):
    """Cut every packed sequence of ``t`` (along dim 1) into head and tail.

    Sequence ``i`` occupies ``t[:, cu_seqlens[i]:cu_seqlens[i + 1]]``; its
    first ``head_lens[i]`` positions go to the head tensor and the rest to
    the tail tensor. Both results are re-packed back to back along dim 1.
    """
    heads, tails = [], []
    for i in range(len(head_lens)):
        cut = cu_seqlens[i] + head_lens[i]
        heads.append(t[:, cu_seqlens[i]:cut, ...])
        tails.append(t[:, cut:cu_seqlens[i + 1], ...])
    return torch.cat(heads, dim=1), torch.cat(tails, dim=1)


def _merge_packed(heads, tails, head_cu, tail_cu):
    """Inverse of ``_split_packed``: re-join each sequence's head and tail.

    ``head_cu`` / ``tail_cu`` are the cumulative-length tensors describing
    how ``heads`` / ``tails`` are packed along dim 1.
    """
    pieces = []
    for i in range(len(head_cu) - 1):
        pieces.append(heads[:, head_cu[i]:head_cu[i + 1], ...])
        pieces.append(tails[:, tail_cu[i]:tail_cu[i + 1], ...])
    return torch.cat(pieces, dim=1)


def _scan_packed(X,
                 dt,
                 A,
                 B,
                 C,
                 chunk_size,
                 cu_seqlens,
                 seq_idx,
                 initial_states=None):
    """Run ``mamba_chunk_scan_combined`` on one packed batch.

    Returns ``(Y, states)`` where ``Y`` is the buffer passed as ``out`` and
    ``states`` is the value returned by the kernel call (requested with
    ``return_varlen_states=True``; the test indexes it per sequence).
    """
    chunk_indices, chunk_offsets = \
        _query_start_loc_to_chunk_indices_offsets(
            cu_seqlens, chunk_size, cu_seqlens[-1])
    Y = torch.empty_like(X)
    states = mamba_chunk_scan_combined(
        X,
        dt,
        A,
        B,
        C,
        chunk_size,
        D=None,
        cu_seqlens=cu_seqlens,
        seq_idx=seq_idx,
        chunk_indices=chunk_indices,
        chunk_offsets=chunk_offsets,
        return_varlen_states=True,
        initial_states=initial_states,
        out=Y,
    )
    return Y, states


def _prefix_msg(prefix):
    """Return a ``msg`` callback for ``assert_close`` that prepends prefix.

    Building the callback through a function binds ``prefix`` eagerly, so
    no loop variable is captured late.
    """
    return lambda generated: prefix + generated


@pytest.mark.parametrize("chunk_size", [8, 256])
@pytest.mark.parametrize("seqlens", [
    (16, 2, 8, 13),
    (270, 88, 212, 203),
    (16, 20),
])
def test_mamba_chunk_scan_cont_batch_prefill_chunking(chunk_size, seqlens):
    """Check chunked prefill of the mamba2 ssd kernels against a full run.

    Every sequence is cut in two (first ``len // 2`` tokens, then the
    rest). The kernel is run on the packed first halves, then on the packed
    remainders seeded with the states returned by the first run. The
    re-joined outputs and the final states must match a single kernel run
    over the full packed batch.

    It differs from test_mamba_chunk_scan_cont_batch in two ways:
      1. It does not use the naive torch implementation
         (ssd_minimal_discrete) for reference outputs; it compares chunked
         kernel outputs to full-sequence kernel outputs, which is the most
         straightforward way to assert chunked prefill correctness.
      2. It focuses on cases where sequences change in the middle of mamba
         chunks, and not necessarily on chunk boundaries.
    """
    max_seqlen = max(seqlens)
    # This test can have larger error for longer sequences
    if max_seqlen > 256:
        atol, rtol = 1e-2, 5e-3
    else:
        atol, rtol = 5e-3, 5e-3

    num_sequences = len(seqlens)
    n_heads = 16
    d_head = 64
    itype = torch.float32

    # hold state during the cutting process so we know if an
    # example has been exhausted and needs to cycle
    last_taken: dict = {}  # map: eg -> pointer to last taken sample
    exhausted: dict = {}  # map: eg -> boolean indicating example is exhausted
    _, cu_seqlens, seq_idx, (A, dt, X, B, C) = next(
        generate_continuous_batched_examples([seqlens],
                                             num_sequences,
                                             max_seqlen,
                                             last_taken,
                                             exhausted,
                                             n_heads,
                                             d_head,
                                             itype,
                                             return_naive_ref=False))
    device = X.device
    seqlens_t = torch.tensor(seqlens, dtype=torch.int32, device=device)

    # per-sequence lengths and packing metadata of the two prefill steps
    head_lens = seqlens_t // 2
    tail_lens = seqlens_t - head_lens
    head_cu, head_seq_idx = _packed_layout(head_lens, device)
    tail_cu, tail_seq_idx = _packed_layout(tail_lens, device)

    # A is passed to the kernel unchanged; only the per-token inputs are cut
    X_head, X_tail = _split_packed(X, cu_seqlens, head_lens)
    dt_head, dt_tail = _split_packed(dt, cu_seqlens, head_lens)
    B_head, B_tail = _split_packed(B, cu_seqlens, head_lens)
    C_head, C_tail = _split_packed(C, cu_seqlens, head_lens)

    # assert input chunking is correct: re-joining must reproduce the inputs
    assert _merge_packed(X_head, X_tail, head_cu, tail_cu).equal(X)
    assert _merge_packed(dt_head, dt_tail, head_cu, tail_cu).equal(dt)
    assert _merge_packed(B_head, B_tail, head_cu, tail_cu).equal(B)
    assert _merge_packed(C_head, C_tail, head_cu, tail_cu).equal(C)

    ## full seqlen computation
    Y_ref, state_ref = _scan_packed(X, dt, A, B, C, chunk_size, cu_seqlens,
                                    seq_idx)

    ## chunked seqlen computation
    # first chunk
    Y_partial, partial_state = _scan_packed(X_head, dt_head, A, B_head,
                                            C_head, chunk_size, head_cu,
                                            head_seq_idx)
    # remaining chunk, continuing from the states of the first chunk
    Y_chunked, state_chunked = _scan_packed(X_tail,
                                            dt_tail,
                                            A,
                                            B_tail,
                                            C_tail,
                                            chunk_size,
                                            tail_cu,
                                            tail_seq_idx,
                                            initial_states=partial_state)
    Y = _merge_packed(Y_partial, Y_chunked, head_cu, tail_cu)

    # kernel chunked is same as kernel overall
    for i in range(num_sequences):
        Y_seq = Y[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
        Y_ref_seq = Y_ref[:, cu_seqlens[i]:cu_seqlens[i + 1], ...]
        # the two halves are checked separately so a failure message tells
        # which prefill step diverged
        torch.testing.assert_close(Y_seq[:, :head_lens[i], ...],
                                   Y_ref_seq[:, :head_lens[i], ...],
                                   atol=atol,
                                   rtol=rtol,
                                   msg=_prefix_msg(f"seq{i} output part1 "))
        torch.testing.assert_close(Y_seq[:, head_lens[i]:, ...],
                                   Y_ref_seq[:, head_lens[i]:, ...],
                                   atol=atol,
                                   rtol=rtol,
                                   msg=_prefix_msg(f"seq{i} output part2 "))

        torch.testing.assert_close(state_chunked[i],
                                   state_ref[i],
                                   atol=atol,
                                   rtol=rtol,
                                   msg=_prefix_msg(f"seq{i} state "))
0 commit comments