Skip to content

Commit 059931d

Browse files
[bugfix] fix whl install gsa error and gsa kpre reslotmapping out of range (#204)
* md max seq len bug * clean code --------- Co-authored-by: zbb200819 <1130072360@qq.com>
1 parent 2714d15 commit 059931d

File tree

3 files changed

+11
-8
lines changed

3 files changed

+11
-8
lines changed

setup.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -130,7 +130,7 @@ def _copy_so_files(self, ext: CMakeExtension):
130130
build_install_dir = "ucm/store"
131131
else:
132132
install_dir = GSA_INSTALL_DIR
133-
build_install_dir = "ucm_sparse"
133+
build_install_dir = "ucm/ucm_sparse"
134134

135135
for so_file in so_files:
136136
src_path = os.path.join(so_search_dir, so_file)

ucm/ucm_sparse/gsa.py

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -457,9 +457,7 @@ def __init__(self, vllm_config: VllmConfig, role: UcmSparseRole):
457457
self.prefetch_engine = GSAPrefetchBase(
458458
vllm_config, 16, True, True, False, 1
459459
)
460-
self.topk_kpre_manger = TopKAndKpreManger(
461-
vllm_config.scheduler_config.max_num_seqs
462-
)
460+
self.topk_kpre_manger = TopKAndKpreManger(MAX_BS)
463461
self.k_cache = {}
464462
self.v_cache = {}
465463
self.tasks_dump = {}
@@ -505,7 +503,7 @@ def init_topk_cal(
505503
self.gsa_q_cache = torch.zeros(
506504
(
507505
self.layer_num,
508-
vllm_config.scheduler_config.max_num_seqs,
506+
MAX_BS,
509507
att_num_heads,
510508
head_size,
511509
),

ucm/ucm_sparse/prefetch_engine.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -169,9 +169,14 @@ def _topk_tmp_deal(self, gsa_metadata, topk_buf_tmp):
169169
for index, topk_info in enumerate(self.topk_bs):
170170
if topk_info[1]:
171171
if topk_info[0] in gsa_metadata.gsa_stats:
172-
gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = (
173-
self.topk_buf_tmp[:, index, : topk_info[2]].clone()
174-
)
172+
if not self.is_cpu_topk:
173+
gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = (
174+
self.topk_buf_tmp[:, index, : topk_info[2]].cpu()
175+
)
176+
else:
177+
gsa_metadata.gsa_stats[topk_info[0]].topk_buf_tmp = (
178+
self.topk_buf_tmp[:, index, : topk_info[2]].clone()
179+
)
175180
self.topk_bs = []
176181
for index, req_id in enumerate(self.req_ids_bs):
177182
one_block_len = len(gsa_metadata.gsa_stats[req_id].blocks)

0 commit comments

Comments
 (0)