11import enum
22import math
33from dataclasses import dataclass , field
4- from typing import Dict , List , Union
4+ from typing import Dict , List , Optional , Union
55
66import torch
77from vllm .config import VllmConfig
@@ -352,6 +352,7 @@ def attention_begin(
352352 key : torch .Tensor ,
353353 value : torch .Tensor ,
354354 forward_context : ForwardContext ,
355+ phase : Optional [str ] = None ,
355356 ) -> None :
356357 index_in_batch = self .req_meta .index_in_batch
357358 query_start_loc = self .req_meta .query_start_loc
@@ -446,6 +447,9 @@ def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids):
446447 retrieve_result_hash_list = self .step_group_retrieve_result .get (
447448 need_retrieve_record
448449 ).copy ()
450+ fixed_origin_candidate_swap_vllm_block_ids = (
451+ candidate_swap_vllm_block_ids .copy ()
452+ )
449453 if need_retrieve_record != "prefill" or load_step == 1 :
450454 if len (self .layer_wise_pre_swap_area_block_hashes ) == 0 :
451455 self .layer_wise_pre_swap_area_block_hashes = {
@@ -456,7 +460,7 @@ def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids):
456460 }
457461 else :
458462 already_matched_record = {}
459- for logic_blk_id in candidate_swap_vllm_block_ids :
463+ for logic_blk_id in fixed_origin_candidate_swap_vllm_block_ids :
460464 if (
461465 logic_blk_id in self .layer_wise_pre_swap_area_block_hashes
462466 and self .layer_wise_pre_swap_area_block_hashes [logic_blk_id ]
@@ -540,6 +544,7 @@ def attention_finished(
540544 value : torch .Tensor ,
541545 attn_output : torch .Tensor ,
542546 forward_context : ForwardContext ,
547+ phase : Optional [str ] = None ,
543548 ) -> None :
544549 if self .req_meta .stage != ReqStage .PREFILL :
545550 if (
0 commit comments