@@ -595,21 +595,32 @@ def forward_cuda(
595595 if prefix_caching_enabled :
596596 # If prefix caching is enabled, retrieve the relevant variables
597597 # for prefill and decode
598- last_state_idx_d , last_state_idx_p = torch .split (
599- attn_metadata .last_state_idx , [num_decodes , num_prefills ], dim = 0
598+ block_idx_last_computed_token_d , block_idx_last_computed_token_p = (
599+ torch .split (
600+ attn_metadata .block_idx_last_computed_token ,
601+ [num_decodes , num_prefills ],
602+ dim = 0 ,
603+ )
600604 )
601- current_last_idx_d , current_last_idx_p = torch .split (
602- attn_metadata .current_last_idx , [num_decodes , num_prefills ], dim = 0
605+ block_idx_last_scheduled_token_d , block_idx_last_scheduled_token_p = (
606+ torch .split (
607+ attn_metadata .block_idx_last_scheduled_token ,
608+ [num_decodes , num_prefills ],
609+ dim = 0 ,
610+ )
603611 )
604612 # Prefill-only variables:
605- current_first_idx_p = attn_metadata .current_first_idx_p
606- context_lens_p = attn_metadata .context_lens_p
607- last_computed_offset_p = attn_metadata .last_computed_offset_p
613+ block_idx_first_scheduled_token_p = (
614+ attn_metadata .block_idx_first_scheduled_token_p
615+ )
616+ num_computed_tokens_p = attn_metadata .num_computed_tokens_p
608617 else :
609- last_state_idx_d , last_state_idx_p = None , None
610- current_last_idx_d , current_last_idx_p = None , None
611- current_first_idx_p = None
612- context_lens_p = None
618+ block_idx_last_computed_token_d = None
619+ block_idx_last_computed_token_p = None
620+ block_idx_last_scheduled_token_d = None
621+ block_idx_last_scheduled_token_p = None
622+ block_idx_first_scheduled_token_p = None
623+ num_computed_tokens_p = None
613624
614625 # Preallocate output tensor to avoid memcpy cost for merging prefill
615626 # and decode outputs
@@ -637,7 +648,8 @@ def forward_cuda(
637648 # to by "state_indices_tensor_p".
638649 # In particular, it will always write the state at the
639650 # sequence end.
640- # In addition, "current_first_idx_p" and "current_last_idx_p"
651+ # In addition, if "block_idx_first_scheduled_token_p" and
652+ # "block_idx_last_scheduled_token_p"
641653 # are provided (which are pointers into
642654 # "state_indices_tensor_p"), it will write additional cache
643655 # states aligned at "block_size_to_align".
@@ -652,10 +664,10 @@ def forward_cuda(
652664 conv_states = conv_state ,
653665 has_initial_state = has_initial_states_p ,
654666 cache_indices = state_indices_tensor_p ,
655- current_first_idx = current_first_idx_p ,
656- current_last_idx = current_last_idx_p ,
657- initial_state_idx = last_state_idx_p ,
658- context_lens = context_lens_p ,
667+ block_idx_first_scheduled_token = block_idx_first_scheduled_token_p ,
668+ block_idx_last_scheduled_token = block_idx_last_scheduled_token_p ,
669+ initial_state_idx = block_idx_last_computed_token_p ,
670+ num_computed_tokens = num_computed_tokens_p ,
659671 block_size_to_align = mamba_block_size ,
660672 metadata = attn_metadata ,
661673 query_start_loc = query_start_loc_p ,
@@ -669,7 +681,7 @@ def forward_cuda(
669681 kernel_ssm_indices = state_indices_tensor_p
670682 if prefix_caching_enabled :
671683 kernel_ssm_indices = state_indices_tensor_p .gather (
672- 1 , last_state_idx_p .unsqueeze (1 )
684+ 1 , block_idx_last_computed_token_p .unsqueeze (1 )
673685 ).squeeze (1 )
674686 initial_states = torch .where (
675687 has_initial_states_p [:, None , None , None ],
@@ -703,52 +715,76 @@ def forward_cuda(
703715 )
704716
705717 if prefix_caching_enabled :
706- # Save states for sequences with more than just the final state:
707- n_blocks_to_fill = current_last_idx_p - current_first_idx_p
708- for seq_idx in (n_blocks_to_fill > 0 ).nonzero ().squeeze (1 ):
718+ # The chunk_stride is the number of chunks per mamba block
719+ # e.g., if mamba_block_size = 512 and chunk_size = 256,
720+ # then chunk_stride = 2
721+ chunk_stride = mamba_block_size // chunk_size
722+
723+ # Save states for sequences with more than just the final state
724+ for seq_idx in range (num_prefills ):
725+ # Block index for the first scheduled token
726+ block_idx_first_scheduled_token = block_idx_first_scheduled_token_p [
727+ seq_idx
728+ ]
729+
730+ # Block index for the last scheduled token
731+ block_idx_last_scheduled_token = block_idx_last_scheduled_token_p [
732+ seq_idx
733+ ]
734+
735+ # Number of blocks that need to be written
736+ n_blocks_to_fill = (
737+ block_idx_last_scheduled_token - block_idx_first_scheduled_token
738+ )
739+
740+ # Skip sequences that don't have any blocks to fill
741+ if n_blocks_to_fill == 0 :
742+ continue
743+
744+ # Look up the state indices
709745 cache_blocks_to_fill = state_indices_tensor_p [
710746 seq_idx ,
711- current_first_idx_p [seq_idx ] : current_first_idx_p [seq_idx ]
712- + n_blocks_to_fill [seq_idx ],
747+ block_idx_first_scheduled_token :block_idx_last_scheduled_token ,
713748 ]
714- # chunks = [0 1 2 3 4 5 6 ...]
715- # First aligned chunk would typically be:
716- # mamba_block_size = 1024, chunk_size = 256
717- # 1024 // 256 - 1 --> chunks[3]
718- # But when last chunk wasn't block aligned:
719- # - last_computed_offset_p[seq_idx] // chunk_size
720- # e.g. 1000 // 256 -> 3 completed --> store chunk[0]
721- # e.g. 513 // 256 -> 2 completed --> store chunk[1] (skip 1)
722- # e.g. 256 // 256 -> 1 completed --> store chunk[2] (skip 2)
723- # e.g. 10 // 256 -> 0 completed --> store chunk[3] (skip 3)
724- chunk_stride = mamba_block_size // chunk_size
725- first_aligned_chunk = (
726- torch .concat (
727- [
728- torch .zeros (
729- 1 ,
730- dtype = last_chunk_indices_p .dtype ,
731- device = last_chunk_indices_p .device ,
732- ),
733- last_chunk_indices_p + 1 ,
734- ]
735- )[seq_idx ]
736- + chunk_stride
737- - 1
738- - last_computed_offset_p [seq_idx ] // chunk_size
749+
750+ # First chunk index for this sequence
751+ if seq_idx == 0 :
752+ first_chunk = 0
753+ else :
754+ first_chunk = 1 + last_chunk_indices_p [seq_idx - 1 ]
755+
756+ # First chunk that is aligned on the mamba block boundary
757+ first_aligned_chunk = first_chunk + chunk_stride - 1
758+
759+ # Calculate the number of computed tokens that were not
760+ # already cached
761+ num_unaligned_computed_tokens = (
762+ num_computed_tokens_p [seq_idx ] % mamba_block_size
739763 )
764+
765+ if num_unaligned_computed_tokens > 0 :
766+ # If the number of computed tokens is not block aligned,
767+ # then we need to shift the index accordingly
768+ first_aligned_chunk -= (
769+ num_unaligned_computed_tokens // chunk_size
770+ )
771+
772+ # Get states to write
740773 from_where = varlen_states [
741774 first_aligned_chunk : first_aligned_chunk
742- + n_blocks_to_fill [ seq_idx ] * chunk_stride : chunk_stride
775+ + n_blocks_to_fill * chunk_stride : chunk_stride
743776 ]
777+
778+ # Write the states
744779 ssm_state [cache_blocks_to_fill ] = from_where
745780
746- # For all seqs, store the last state (Note : might be partial):
781+ # For all seqs, store the last state (note : might be partial):
747782 ssm_state [
748783 state_indices_tensor_p .gather (
749- 1 , current_last_idx_p .unsqueeze (1 )
784+ 1 , block_idx_last_scheduled_token_p .unsqueeze (1 )
750785 ).squeeze (1 )
751786 ] = varlen_states [last_chunk_indices_p ]
787+
752788 else :
753789 # update ssm states
754790 # - varlen state is a (num_prefills, nheads, headdim, dstate)
@@ -759,14 +795,17 @@ def forward_cuda(
759795 if has_decode :
760796 if prefix_caching_enabled :
761797 state_indices_tensor_d_input = state_indices_tensor_d .gather (
762- 1 , last_state_idx_d .unsqueeze (1 )
798+ 1 , block_idx_last_computed_token_d .unsqueeze (1 )
763799 ).squeeze (1 )
764800 state_indices_tensor_d_output = state_indices_tensor_d .gather (
765- 1 , current_last_idx_d .unsqueeze (1 )
801+ 1 , block_idx_last_scheduled_token_d .unsqueeze (1 )
766802 ).squeeze (1 )
767- # Note:
768- # for decode always: current_first_idx_d == current_last_idx_d
769- # at block boundaries: current_first_idx_d > last_state_idx_d
803+ # for decode:
804+ # block_idx_first_scheduled_token_d ==
805+ # block_idx_last_scheduled_token_d
806+ # at block boundaries:
807+ # block_idx_first_scheduled_token_d >
808+ # block_idx_last_computed_token_d
770809 else :
771810 # Without caching, read and write in-place to the same blocks:
772811 state_indices_tensor_d_input = state_indices_tensor_d
@@ -780,8 +819,8 @@ def forward_cuda(
780819 self .conv1d .bias ,
781820 self .activation ,
782821 conv_state_indices = state_indices_tensor_d ,
783- current_last_idx = current_last_idx_d ,
784- initial_state_idx = last_state_idx_d ,
822+ block_idx_last_scheduled_token = block_idx_last_scheduled_token_d ,
823+ initial_state_idx = block_idx_last_computed_token_d ,
785824 )
786825
787826 hidden_states_d , B_d , C_d = split_hidden_states_B_C_fn (hidden_states_B_C_d )
0 commit comments