@@ -117,14 +117,14 @@ def add_req_new(
     ) -> None:
         self.blocks = [x for x in add_req_state.block_ids[0]]
         self.index_in_batch = index_in_batch
-        self._init_slot(offset)
         self.num_computed_tokens = add_req_state.num_computed_tokens
         self.num_scheduled_tokens = num_scheduled_tokens
         self.num_prompt_tokens = len(add_req_state.prompt_token_ids)
         self.num_output_tokens = len(add_req_state.output_token_ids)
         self.is_use_gsa = (
             True if self.num_prompt_tokens > SEG_PREFILL_THRESHOLD else False
         )
+        self._init_slot(offset)
 
     def updata_req_state(
         self, num_scheduled_tokens, add_req_state, index_in_batch
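
The reorder above only matters if slot initialization depends on state that used to be set after it. A minimal sketch of that dependency, assuming (hypothetically) that _init_slot() consults self.is_use_gsa; this is not part of the commit:

# Minimal sketch, not part of the commit: assumes _init_slot() reads
# self.is_use_gsa, so the flag must be computed before the call.
SEG_PREFILL_THRESHOLD = 4096

class ReqStateSketch:
    def __init__(self, num_prompt_tokens: int, offset: int) -> None:
        self.num_prompt_tokens = num_prompt_tokens
        # Compute the flag first ...
        self.is_use_gsa = self.num_prompt_tokens > SEG_PREFILL_THRESHOLD
        # ... then initialize the slot, mirroring the moved call.
        self._init_slot(offset)

    def _init_slot(self, offset: int) -> None:
        # Reading self.is_use_gsa here would raise AttributeError if
        # _init_slot() ran before the flag was assigned.
        self.slot_offset = offset if self.is_use_gsa else None
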
@@ -134,31 +134,24 @@ def updata_req_state(
         self.num_output_tokens = len(add_req_state.output_token_ids)
         self.index_in_batch = index_in_batch
         if self.stage() == SequenceStage.PREFILL:
-            if self.is_last_chunk():
-                add_blocks = [
-                    x for x in add_req_state.block_ids[0][:-1] if x not in self.blocks
-                ]
-            else:
-                add_blocks = [
-                    x for x in add_req_state.block_ids[0] if x not in self.blocks
-                ]
+            add_blocks = [x for x in add_req_state.block_ids[0] if x not in self.blocks]
             self.blocks = [x for x in add_req_state.block_ids[0]]
             self._update_slot(add_blocks)
         else:
             self._get_sparse_and_free_block()
             if len(add_req_state.block_ids[0]) != self.sparse_len:
-                add_blocks = [add_req_state.block_ids[0][-2]]
-                self._update_slot(add_blocks)
+                add_blocks = [add_req_state.block_ids[0][-1]]
                 self.blocks += [add_req_state.block_ids[0][-1]]
                 self.sparse_len = len(add_req_state.block_ids[0])
+                self._update_slot(add_blocks)
             else:
                 self.calc_block_table = []
                 self.calc_repre_slot_mapping = []
 
     def _get_sparse_and_free_block(self):
         if self.num_prompt_tokens == self.num_computed_tokens:
             blocks_len = len(self.blocks)
-            if self.num_prompt_tokens > SEG_PREFILL_THRESHOLD:
+            if self.num_prompt_tokens > SEG_PREFILL_THRESHOLD and PTOPK_PREFETCH_ENABLE:
                 remain_len = compute_topk_len(blocks_len)
                 if remain_len > MAX_TOPK_LEN:
                     prefetch_len = 0
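
A minimal sketch of the decode-path fix above: block_ids[0] grows by one block per allocation, so the newly added block is the last element, not the second-to-last, and the slot update now runs after blocks and sparse_len are refreshed:

# Minimal sketch, not part of the commit: the new block sits at index -1.
block_ids = [[10, 11, 12, 13]]  # block 13 was just allocated for this step
blocks = [10, 11, 12]           # blocks tracked before this step
sparse_len = 3

if len(block_ids[0]) != sparse_len:
    add_blocks = [block_ids[0][-1]]   # [13]; the old [-2] would pick 12 again
    blocks += [block_ids[0][-1]]
    sparse_len = len(block_ids[0])
    # _update_slot() would run here, after the bookkeeping is refreshed.

assert add_blocks == [13] and blocks == [10, 11, 12, 13] and sparse_len == 4
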
@@ -176,10 +169,7 @@ def _get_sparse_and_free_block(self):
                     self.prefetch_idx = remain_blocks_idx[
                         remain_len - LOCAL_WINDOW_SZ : -LOCAL_WINDOW_SZ
                     ]
-                if PTOPK_PREFETCH_ENABLE:
-                    self.sparse_len = remain_len + prefetch_len
-                else:
-                    self.sparse_len = blocks_len
+                self.sparse_len = remain_len + prefetch_len
             else:
                 self.remain_idx = list(range(blocks_len))
                 self.prefetch_idx = []
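
With PTOPK_PREFETCH_ENABLE folded into the outer condition, the disabled case now takes the dense else branch, so the inner fallback to blocks_len is redundant. A simplified model of the resulting control flow, not part of the commit, with compute_topk_len() replaced by a stand-in:

# Minimal sketch, not part of the commit; compute_topk_len() is replaced by
# a stand-in split, and only the sparse_len bookkeeping is modeled.
def sparse_len_for(blocks_len: int, num_prompt_tokens: int,
                   threshold: int, prefetch_enabled: bool) -> int:
    if num_prompt_tokens > threshold and prefetch_enabled:
        remain_len = max(blocks_len // 2, 1)  # stand-in for compute_topk_len
        prefetch_len = blocks_len - remain_len
        return remain_len + prefetch_len      # sparse branch
    return blocks_len                         # dense branch keeps every block

assert sparse_len_for(8, 10000, 4096, prefetch_enabled=False) == 8
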
@@ -190,14 +180,14 @@ def _get_sparse_and_free_block(self):
             self.prefetch_idx = None
 
     def _init_slot(self, offset: int) -> None:
+        self.repre_slot_mapping = list(range(len(self.blocks)))
+        self.repre_slot_mapping = [x + offset for x in self.repre_slot_mapping]
         if self.is_last_chunk():
-            self.repre_slot_mapping = list(range(len(self.blocks) - 1))
             self.calc_block_table = [x for x in self.blocks[:-1]]
+            self.calc_repre_slot_mapping = [x for x in self.repre_slot_mapping[:-1]]
         else:
-            self.repre_slot_mapping = list(range(len(self.blocks)))
             self.calc_block_table = [x for x in self.blocks]
-        self.repre_slot_mapping = [x + offset for x in self.repre_slot_mapping]
-        self.calc_repre_slot_mapping = [x for x in self.repre_slot_mapping]
+            self.calc_repre_slot_mapping = [x for x in self.repre_slot_mapping]
 
         value = len(self.blocks)
         one_mask = [False] * value
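
A small worked example of the restructured _init_slot above, not part of the commit: the full offset-shifted mapping is built first, then the calc views are sliced, dropping only the final block on the last chunk:

# Minimal sketch, not part of the commit: slot-mapping slices by chunk kind.
blocks = [4, 7, 9]
offset = 100
repre_slot_mapping = [i + offset for i in range(len(blocks))]  # [100, 101, 102]

is_last_chunk = True
if is_last_chunk:
    calc_block_table = blocks[:-1]                     # [4, 7]
    calc_repre_slot_mapping = repre_slot_mapping[:-1]  # [100, 101]
else:
    calc_block_table = list(blocks)
    calc_repre_slot_mapping = list(repre_slot_mapping)

# Either way the full mapping keeps one slot per block, last block included.
assert repre_slot_mapping == [100, 101, 102]
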
@@ -224,8 +214,20 @@ def _update_slot(
             self.include_mask.append(True)
             self.exclude_mask.append(False)
         if add_len > 0:
-            self.calc_block_table = [x for x in add_blocks]
-            self.calc_repre_slot_mapping = self.repre_slot_mapping[add_len * -1 :]
+            if self.stage() == SequenceStage.PREFILL:
+                if self.is_last_chunk():
+                    self.calc_block_table = [x for x in add_blocks[:-1]]
+                    self.calc_repre_slot_mapping = self.repre_slot_mapping[
+                        add_len * -1 : -1
+                    ]
+                else:
+                    self.calc_block_table = [x for x in add_blocks]
+                    self.calc_repre_slot_mapping = self.repre_slot_mapping[
+                        add_len * -1 :
+                    ]
+            else:
+                self.calc_block_table = [self.blocks[-1]]
+                self.calc_repre_slot_mapping = [self.repre_slot_mapping[-1]]
         else:
             self.calc_block_table = []
             self.calc_repre_slot_mapping = []
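
The new branches in _update_slot pair each calc_block_table with the matching tail of repre_slot_mapping. A worked example of the negative-index slices, not part of the commit:

# Minimal sketch, not part of the commit: tail slices used by _update_slot.
repre_slot_mapping = [100, 101, 102, 103, 104]
add_blocks = [21, 22]
add_len = len(add_blocks)

# Last prefill chunk: drop the final block from both calc views.
assert add_blocks[:-1] == [21]
assert repre_slot_mapping[add_len * -1 : -1] == [103]

# Intermediate prefill chunk: keep every newly added block.
assert repre_slot_mapping[add_len * -1 :] == [103, 104]

# Decode: only the newest block and its slot are (re)computed.
blocks = [4, 7, 9, 21, 22]
assert [blocks[-1]] == [22] and [repre_slot_mapping[-1]] == [104]
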
@@ -269,15 +271,20 @@ def get_model_input(
     def trans_input_tensor(self, scheduler_output: SchedulerOutput):
         calc_block_table = []
         model_input = {}
+        calc_repre_slot_mappings = []
         query_locals = [0]
         for req_id, _ in scheduler_output.num_scheduled_tokens.items():
             calc_block_table += self.gsa_stats[req_id].calc_block_table
+            calc_repre_slot_mappings += self.gsa_stats[req_id].calc_repre_slot_mapping
             query_locals.append(
                 query_locals[-1] + scheduler_output.num_scheduled_tokens[req_id]
             )
         model_input["calc_block_table"] = torch.tensor(
             calc_block_table, dtype=torch.int32, device="cpu"
         )
+        model_input["calc_repre_slot_mapping"] = torch.tensor(
+            calc_repre_slot_mappings, dtype=torch.int32, device="cpu"
+        )
         model_input["query_locals"] = query_locals
         return model_input
@@ -544,7 +551,7 @@ def copy_q(self, query: torch.Tensor, current_layer_id: int) -> None:
             if req_meta.stage() == SequenceStage.DECODE:
                 index_in_batch = req_meta.index_in_batch
                 ids[index_in_batch] = (
-                    self.model_input["query_locals"][index_in_batch] - 1
+                    self.model_input["query_locals"][index_in_batch + 1] - 1
                 )
                 self.gsa_q_cache[current_layer_id][index_in_batch].copy_(
                     query[ids[index_in_batch]]
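
Given those prefix sums, the last token of request i sits at query_locals[i + 1] - 1; the old expression query_locals[i] - 1 addressed the previous request's final row. A worked example, not part of the commit:

# Minimal sketch, not part of the commit: the decode row for request 1.
query_locals = [0, 3, 4, 9]   # from trans_input_tensor above
index_in_batch = 1

last_row = query_locals[index_in_batch + 1] - 1   # 3, request 1's only row
stale_row = query_locals[index_in_batch] - 1      # 2, last row of request 0
assert (last_row, stale_row) == (3, 2)
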
@@ -560,12 +567,20 @@ def copy_q(self, query: torch.Tensor, current_layer_id: int) -> None:
     def copy_k(self, layer_name: str, forward_context: ForwardContext) -> None:
         current_layer_id = int(layer_name.split(".")[2])
         block_ids = self.model_input["calc_block_table"]
+        calc_repre_slot_mappings = self.model_input["calc_repre_slot_mapping"]
         if len(block_ids) > 0:
             attn = forward_context.no_compile_layers
-            k_needed = attn[layer_name].kv_cache[forward_context.virtual_engine][0]
-            result = self.gsa_offload_ops.add_copy_req(
-                True, current_layer_id, [], k_needed
+            key_cache_mean_out = (
+                attn[layer_name]
+                .kv_cache[forward_context.virtual_engine][0][block_ids]
+                .mean(dim=1, keepdim=True)
+                .cpu()
             )
+            self.prefetch_engine.kpre_caches[current_layer_id][
+                calc_repre_slot_mappings
+            ].copy_(key_cache_mean_out)
+            k_needed = attn[layer_name].kv_cache[forward_context.virtual_engine][0]
+            self.gsa_offload_ops.add_copy_req(True, current_layer_id, [], k_needed)
 
     def attention_begin(
         self,
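
The new copy_k body derives one representative key per block by averaging that block's token keys into the prefetch engine's kpre_caches. A self-contained sketch of the pooling step, not part of the commit, assuming a key cache laid out as [num_blocks, block_size, num_heads, head_dim] (the exact layout is backend-specific):

# Minimal sketch, not part of the commit: per-block mean-pooled key vectors.
import torch

num_blocks, block_size, num_heads, head_dim = 8, 16, 2, 4
key_cache = torch.randn(num_blocks, block_size, num_heads, head_dim)
block_ids = torch.tensor([1, 5])  # blocks scheduled for this step

# Average over the token axis (dim=1) and stage the result on the host,
# mirroring .mean(dim=1, keepdim=True).cpu() in the hunk.
key_repre = key_cache[block_ids].mean(dim=1, keepdim=True).cpu()
assert key_repre.shape == (2, 1, num_heads, head_dim)
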
@@ -588,20 +610,20 @@ def attention_begin(
 
         if isinstance(forward_context.attn_metadata, dict):
             attn_metadata = forward_context.attn_metadata[layer_name]
-            block_tables = attn_metadata.block_table
         else:
             attn_metadata = forward_context.attn_metadata
-            block_tables = attn_metadata.block_tables
         if self.prefetch_engine.atb_gsa_enable:
             if torch.cuda.is_available():
-                block_tables = self.model_input["block_tables_mp"][current_layer_id]
+                attn_metadata.block_table = self.model_input["block_tables_mp"][
+                    current_layer_id
+                ]
                 attn_metadata.seq_lens = self.model_input["gsa_seq_len"][
                     current_layer_id
                 ]
             else:
-                block_tables[: len(self.prefetch_engine.req_ids_bs)].copy_(
-                    self.model_input["block_tables_mp"][current_layer_id]
-                )
+                attn_metadata.block_tables[
+                    : len(self.prefetch_engine.req_ids_bs)
+                ].copy_(self.model_input["block_tables_mp"][current_layer_id])
                 attn_metadata.seq_lens.copy_(
                     self.model_input["gsa_seq_len"][current_layer_id]
                 )
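
The underlying fix in this hunk is a classic aliasing bug: rebinding the local name block_tables never reaches the metadata object the attention backend reads, whereas assigning the attribute (or copying into its tensor in place) does. A minimal illustration, not part of the commit, using a hypothetical stand-in class:

# Minimal sketch, not part of the commit; AttnMetadataSketch is hypothetical.
class AttnMetadataSketch:
    def __init__(self, block_table):
        self.block_table = block_table

meta = AttnMetadataSketch(block_table=[[0, 1]])

block_tables = meta.block_table
block_tables = [[2, 3]]            # rebinds the local; meta is untouched
assert meta.block_table == [[0, 1]]

meta.block_table = [[2, 3]]        # attribute assignment is visible downstream
assert meta.block_table == [[2, 3]]
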
@@ -734,6 +756,7 @@ def execute_begin(self, scheduler_output: SchedulerOutput):
             self.gsa_metadata,
             is_topk_done,
         )
+        self.gsa_stats = self.gsa_metadata.gsa_stats
         self._start_topk_cal()
 
     def execute_finished(self):
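
Caching the stats reference right after the metadata is rebuilt keeps later readers (such as trans_input_tensor, which indexes self.gsa_stats) on the current step's data. A sketch, not part of the commit, assuming (hypothetically) that the metadata object is replaced each step:

# Minimal sketch, not part of the commit; MetadataSketch is hypothetical.
class MetadataSketch:
    def __init__(self, gsa_stats):
        self.gsa_stats = gsa_stats

gsa_metadata = MetadataSketch({"req-0": "old"})
gsa_stats = gsa_metadata.gsa_stats

gsa_metadata = MetadataSketch({"req-0": "new"})  # rebuilt on this step
gsa_stats = gsa_metadata.gsa_stats               # re-sync, as the hunk does
assert gsa_stats["req-0"] == "new"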