Skip to content

Commit a1d9058

Browse files
authored
[bug fix]kvstar delta kvcache block select bugfix (#341)
* kvstar delta kvcache block select bugfix
* clean code
* suitable inner attn_begin api
* suitable inner attn_finish api
1 parent 95d9a23 commit a1d9058

File tree

1 file changed

+7
-2
lines changed

1 file changed

+7
-2
lines changed

ucm/sparse/kvstar/multistep.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -1,7 +1,7 @@
11
import enum
22
import math
33
from dataclasses import dataclass, field
4-
from typing import Dict, List, Union
4+
from typing import Dict, List, Optional, Union
55

66
import torch
77
from vllm.config import VllmConfig
@@ -352,6 +352,7 @@ def attention_begin(
352352
key: torch.Tensor,
353353
value: torch.Tensor,
354354
forward_context: ForwardContext,
355+
phase: Optional[str] = None,
355356
) -> None:
356357
index_in_batch = self.req_meta.index_in_batch
357358
query_start_loc = self.req_meta.query_start_loc
@@ -446,6 +447,9 @@ def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids):
446447
retrieve_result_hash_list = self.step_group_retrieve_result.get(
447448
need_retrieve_record
448449
).copy()
450+
fixed_origin_candidate_swap_vllm_block_ids = (
451+
candidate_swap_vllm_block_ids.copy()
452+
)
449453
if need_retrieve_record != "prefill" or load_step == 1:
450454
if len(self.layer_wise_pre_swap_area_block_hashes) == 0:
451455
self.layer_wise_pre_swap_area_block_hashes = {
@@ -456,7 +460,7 @@ def load_retrieve_result_async(self, load_step, candidate_swap_vllm_block_ids):
456460
}
457461
else:
458462
already_matched_record = {}
459-
for logic_blk_id in candidate_swap_vllm_block_ids:
463+
for logic_blk_id in fixed_origin_candidate_swap_vllm_block_ids:
460464
if (
461465
logic_blk_id in self.layer_wise_pre_swap_area_block_hashes
462466
and self.layer_wise_pre_swap_area_block_hashes[logic_blk_id]
@@ -540,6 +544,7 @@ def attention_finished(
540544
value: torch.Tensor,
541545
attn_output: torch.Tensor,
542546
forward_context: ForwardContext,
547+
phase: Optional[str] = None,
543548
) -> None:
544549
if self.req_meta.stage != ReqStage.PREFILL:
545550
if (

0 commit comments

Comments
 (0)