@@ -111,11 +111,8 @@ class FlashAttentionMetadata(AttentionMetadata):
     # Maximum query length in the batch.
     max_query_len: Optional[int]

-    # Number of query tokens for each request in the batch.
-    # Currently, we require that all requests have the same number of query
-    # tokens during the decoding phase. When speculavie decoding is enabled,
-    # decode_query_len might be greater than 1. In all other cases, it is 1.
-    decode_query_len: Optional[int]
+    # Max number of query tokens among requests in the batch.
+    max_decode_query_len: Optional[int]

     # Maximum sequence length among prefill batch. 0 if there are decoding
     # requests only.
@@ -173,9 +170,9 @@ def prefill_metadata(self) -> Optional["FlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[:self.num_prefill_tokens],
             seq_lens=self.seq_lens[:self.num_prefills],
             seq_lens_tensor=self.seq_lens_tensor[:self.num_prefills],
-            decode_query_len=0,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=self.max_prefill_seq_len,
+            max_decode_query_len=0,
             max_decode_seq_len=0,
             query_start_loc=self.query_start_loc[:self.num_prefills + 1],
             seq_start_loc=self.seq_start_loc[:self.num_prefills + 1],
@@ -202,12 +199,14 @@ def decode_metadata(self) -> Optional["FlashAttentionMetadata"]:
             slot_mapping=self.slot_mapping[self.num_prefill_tokens:],
             seq_lens=None,
             seq_lens_tensor=self.seq_lens_tensor[self.num_prefills:],
-            decode_query_len=self.decode_query_len,
+            max_decode_query_len=self.max_decode_query_len,
             max_query_len=self.max_query_len,
             max_prefill_seq_len=0,
             max_decode_seq_len=self.max_decode_seq_len,
-            query_start_loc=None,
-            seq_start_loc=None,
+            query_start_loc=self.query_start_loc[self.num_prefills:]
+            if self.query_start_loc is not None else None,
+            seq_start_loc=self.seq_start_loc[self.num_prefills:]
+            if self.seq_start_loc is not None else None,
             context_lens_tensor=None,
             block_tables=self.block_tables[self.num_prefills:],
             use_cuda_graph=self.use_cuda_graph,
@@ -413,9 +412,9 @@ def build(self, seq_lens: List[int], query_lens: List[int],
         max_query_len = max(query_lens)
         decode_query_lens = query_lens[self.num_prefills:]
         if len(decode_query_lens) > 0:
-            decode_query_len = max(decode_query_lens)
+            max_decode_query_len = max(decode_query_lens)
         else:
-            decode_query_len = 1
+            max_decode_query_len = 1
         max_prefill_seq_len = max(self.prefill_seq_lens, default=0)
         max_decode_seq_len = max(self.curr_seq_lens, default=0)
         num_decode_tokens = self.num_decode_tokens
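For context, a small worked example of how this `build` logic derives `max_decode_query_len` from `query_lens` (the batch values here are illustrative, not taken from the PR):

# Hypothetical batch: two prefill requests followed by two decode requests.
query_lens = [5, 3, 1, 1]
num_prefills = 2

decode_query_lens = query_lens[num_prefills:]  # [1, 1]
max_decode_query_len = max(decode_query_lens) if decode_query_lens else 1
# -> 1 for normal decoding; with speculative decoding the decode entries can
#    be larger (e.g. [4, 4]), giving max_decode_query_len = 4.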
@@ -468,7 +467,7 @@ def build(self, seq_lens: List[int], query_lens: List[int],
             seq_lens=seq_lens,
             seq_lens_tensor=seq_lens_tensor,
             max_query_len=max_query_len,
-            decode_query_len=decode_query_len,
+            max_decode_query_len=max_decode_query_len,
             max_prefill_seq_len=max_prefill_seq_len,
             max_decode_seq_len=max_decode_seq_len,
             query_start_loc=query_start_loc,
@@ -714,20 +713,37 @@ def unified_flash_attention(

     if decode_meta := attn_metadata.decode_metadata:
         # Decoding run.
-        _, num_head, head_dim = decode_query.shape
-        decode_query = decode_query.reshape(-1, decode_meta.decode_query_len,
-                                            num_head, head_dim)
-        decode_output = flash_attn_with_kvcache(
-            q=decode_query,
-            k_cache=key_cache,
-            v_cache=value_cache,
-            block_table=decode_meta.block_tables,
-            cache_seqlens=decode_meta.seq_lens_tensor,
-            softmax_scale=softmax_scale,
-            causal=True,
-            alibi_slopes=alibi_slopes,
-            softcap=logits_soft_cap,
-        ).squeeze(1)
+        # Use flash_attn_varlen_func kernel for speculative decoding
+        # because different queries might have different lengths.
+        assert decode_meta.max_decode_query_len is not None
+        if decode_meta.max_decode_query_len > 1:
+            decode_output = flash_attn_varlen_func(
+                q=decode_query,
+                k=key_cache,
+                v=value_cache,
+                cu_seqlens_q=decode_meta.query_start_loc,
+                max_seqlen_q=decode_meta.max_decode_query_len,
+                cu_seqlens_k=decode_meta.seq_start_loc,
+                max_seqlen_k=decode_meta.max_decode_seq_len,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+                block_table=decode_meta.block_tables,
+            )
+        else:
+            # Use flash_attn_with_kvcache for normal decoding.
+            decode_output = flash_attn_with_kvcache(
+                q=decode_query.unsqueeze(1),
+                k_cache=key_cache,
+                v_cache=value_cache,
+                block_table=decode_meta.block_tables,
+                cache_seqlens=decode_meta.seq_lens_tensor,
+                softmax_scale=softmax_scale,
+                causal=True,
+                alibi_slopes=alibi_slopes,
+                softcap=logits_soft_cap,
+            ).squeeze(1)

     if prefill_output is None:
         assert decode_output is not None
@@ -739,7 +755,6 @@ def unified_flash_attention(
     # Chunked prefill does not work with speculative decoding.
     # Therefore, the query length for decode should be 1 in chunked prefill.
     assert decode_meta is not None
-    assert decode_meta.decode_query_len == 1
     decode_output = decode_output.squeeze(1)
     output = torch.cat([prefill_output, decode_output], dim=0)
     return output.view(num_tokens, hidden_size)
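Outside the diff itself, a minimal sketch of the decode-path dispatch this change introduces, assuming vLLM's FlashAttention fork exposes flash_attn_varlen_func and flash_attn_with_kvcache with the keyword arguments used above, and a metadata object shaped like FlashAttentionMetadata; the helper name run_decode_attention is hypothetical:

from vllm_flash_attn import flash_attn_varlen_func, flash_attn_with_kvcache


def run_decode_attention(decode_query, key_cache, value_cache, decode_meta,
                         softmax_scale, alibi_slopes, logits_soft_cap):
    # Speculative decoding can give decode requests queries of different
    # lengths, so the variable-length kernel is needed when any decode
    # query is longer than one token.
    if decode_meta.max_decode_query_len > 1:
        return flash_attn_varlen_func(
            q=decode_query,
            k=key_cache,
            v=value_cache,
            cu_seqlens_q=decode_meta.query_start_loc,
            max_seqlen_q=decode_meta.max_decode_query_len,
            cu_seqlens_k=decode_meta.seq_start_loc,
            max_seqlen_k=decode_meta.max_decode_seq_len,
            softmax_scale=softmax_scale,
            causal=True,
            alibi_slopes=alibi_slopes,
            softcap=logits_soft_cap,
            block_table=decode_meta.block_tables,
        )
    # Normal decoding: exactly one query token per request, so the paged
    # KV-cache kernel is used with a [batch, 1, heads, head_dim] query.
    return flash_attn_with_kvcache(
        q=decode_query.unsqueeze(1),
        k_cache=key_cache,
        v_cache=value_cache,
        block_table=decode_meta.block_tables,
        cache_seqlens=decode_meta.seq_lens_tensor,
        softmax_scale=softmax_scale,
        causal=True,
        alibi_slopes=alibi_slopes,
        softcap=logits_soft_cap,
    ).squeeze(1)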