
Commit aef5f3a

HaoLi980405, zbb200819, and HelenJia98 authored

[bugfix]gsa fix reslotmapping bug (#194)

* deal bug
* gpu kpre and bug fixed
* deal bug
* deal bug
* deal bug
* clean code
* CI
* ci

Co-authored-by: zbb200819 <1130072360@qq.com>
Co-authored-by: xujia <42216276@qq.com>

1 parent 3ac5926 commit aef5f3a

File tree: 5 files changed, +70 −47 lines

examples/offline_inference.py (4 additions, 4 deletions)
```diff
@@ -27,7 +27,7 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
         kv_connector_extra_config={
             "ucm_connector_name": "UcmDram",
             "ucm_connector_config": {
-                "max_cache_size": 5368709120,
+                "max_cache_size": 53687091200,
                 "kv_block_size": 262144,
             },
             "ucm_sparse_method": "GSA",
@@ -37,8 +37,8 @@ def build_llm_with_uc(module_path: str, name: str, model: str):
     llm_args = EngineArgs(
         model=model,
         kv_transfer_config=ktc,
-        max_model_len=8000,
-        gpu_memory_utilization=0.8,
+        max_model_len=40960,
+        gpu_memory_utilization=0.87,
         block_size=128,
     )

@@ -81,7 +81,7 @@ def main():
         "Write a detailed letter to the leaders of Earth, explaining the most urgent global issue of the 21st "
         "century, the root sauses behind it, and a set of scientifically grounded, morally sound, and globally "
         "cooperative solutions that transcend culturak and national boundaries. Include both immediate actions "
-        "and long-term strategies."
+        "and long-term strategies." * 200
     ]

     sampling_params = SamplingParams(temperature=0, top_p=0.95, max_tokens=100)
```
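For scale, both cache-size constants are exact GiB multiples, so the connector's DRAM cache grows tenfold to accommodate the much longer context and the 200x-repeated prompt. A quick sanity check (illustrative only):

```python
# Sanity check on the constants in the diff: both are exact GiB multiples,
# so the DRAM connector's cache grows from 5 GiB to 50 GiB.
GiB = 1024**3
assert 5368709120 == 5 * GiB     # old max_cache_size
assert 53687091200 == 50 * GiB   # new max_cache_size
```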

ucm/csrc/gsaoffloadops/src/cal_kpre_and_topk.cpp (9 additions, 9 deletions)
```diff
@@ -32,10 +32,10 @@ void CalKpreAndTopk::SetKpreMethodParam(uint32_t maxBlockNum, uint32_t numHeads,
     m_kNumHeads = numHeads;
     m_numKpre = numKpre;
     auto optionsForKCache = torch::TensorOptions().device("cpu").dtype(torch::kFloat32);
-    for (uint32_t i = 0; i < m_layerNum; i++) {
-        torch::Tensor layerKCache = torch::zeros({maxBlockNum, m_kNumHeads, m_blockSize, m_headSize}, optionsForKCache);
-        m_kCache.push_back(layerKCache);
-    }
+    // for (uint32_t i = 0; i < m_layerNum; i++) {
+    //     torch::Tensor layerKCache = torch::zeros({maxBlockNum, m_kNumHeads, m_blockSize, m_headSize}, optionsForKCache);
+    //     m_kCache.push_back(layerKCache);
+    // }
 }

 void CalKpreAndTopk::SetKpreCache(std::vector<torch::Tensor>& kpreCache)
@@ -152,10 +152,10 @@ void CalKpreAndTopk::CopyData()
         }
         SetTopkDataReady(curReq.layerId);
     } else {
-        torch::Tensor kNeeded = curReq.srcTensor.index({curReq.ids}).cpu();
-        torch::Tensor kCache = kNeeded.to(torch::kFloat32).permute({0, 2, 1, 3});
-        auto targetTensor = m_kCache[curReq.layerId].slice(0, 0, curReq.ids.size(0));
-        targetTensor.copy_(kCache);
+        // torch::Tensor kNeeded = curReq.srcTensor.index({curReq.ids}).cpu();
+        // torch::Tensor kCache = kNeeded.to(torch::kFloat32).permute({0, 2, 1, 3});
+        // auto targetTensor = m_kCache[curReq.layerId].slice(0, 0, curReq.ids.size(0));
+        // targetTensor.copy_(kCache);
         SetKpreDataReady(curReq.layerId);
     }
     if (!m_running) {
@@ -195,7 +195,7 @@ void CalKpreAndTopk::CalForOneLayer(uint32_t curLayer)
 {
     if (m_needCalPre) {
         while(!m_kReady[curLayer].load(std::memory_order_acquire));
-        CalculateKpre(curLayer);
+        // CalculateKpre(curLayer);
     }
     if (m_needCalTopk) {
         while(!m_qReady[curLayer].load(std::memory_order_acquire));
```

ucm/csrc/gsaoffloadops/src/select_topk_block.cpp (1 addition, 1 deletion)
```diff
@@ -5,7 +5,6 @@
 #include <cmath>
 #include "select_topk_block.h"

-
 namespace SelectTopkBlock {
 #define OMP_THREAD_NUM 16u

@@ -48,6 +47,7 @@ void TopkBlockSelector::TopKImpl(const float* scores, uint32_t numScores, uint32
     for (uint32_t i = 0; i < endWindow_; ++i) {
         topkIndices[idx++] = numScores - endWindow_ + i;
     }
+    std::sort(topkIndices, topkIndices + k);
 }

 float TopkBlockSelector::ComputeBlockScore(float* qMean, const float* blockBase,
```
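The added `std::sort` puts the selected indices back into ascending block order after the unordered top-k fill, so downstream block tables stay monotonic. A minimal NumPy sketch of the same pattern (function name and values are illustrative, not from the repo):

```python
import numpy as np

def topk_sorted(scores: np.ndarray, k: int) -> np.ndarray:
    """Illustrative only: unordered top-k via argpartition, then sorted ids."""
    idx = np.argpartition(scores, -k)[-k:]  # k best blocks, arbitrary order
    return np.sort(idx)                     # ascending, as std::sort ensures

scores = np.array([0.1, 0.9, 0.3, 0.8, 0.2])
print(topk_sorted(scores, 2))  # [1 3]
```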

ucm/ucm_sparse/gsa.py (55 additions, 32 deletions)
```diff
@@ -117,14 +117,14 @@ def add_req_new(
     ) -> None:
         self.blocks = [x for x in add_req_state.block_ids[0]]
         self.index_in_batch = index_in_batch
-        self._init_slot(offset)
         self.num_computed_tokens = add_req_state.num_computed_tokens
         self.num_scheduled_tokens = num_scheduled_tokens
         self.num_prompt_tokens = len(add_req_state.prompt_token_ids)
         self.num_output_tokens = len(add_req_state.output_token_ids)
         self.is_use_gsa = (
             True if self.num_prompt_tokens > SEG_PREFILL_THRESHOLD else False
         )
+        self._init_slot(offset)

     def updata_req_state(
         self, num_scheduled_tokens, add_req_state, index_in_batch
@@ -134,31 +134,24 @@ def updata_req_state(
         self.num_output_tokens = len(add_req_state.output_token_ids)
         self.index_in_batch = index_in_batch
         if self.stage() == SequenceStage.PREFILL:
-            if self.is_last_chunk():
-                add_blocks = [
-                    x for x in add_req_state.block_ids[0][:-1] if x not in self.blocks
-                ]
-            else:
-                add_blocks = [
-                    x for x in add_req_state.block_ids[0] if x not in self.blocks
-                ]
+            add_blocks = [x for x in add_req_state.block_ids[0] if x not in self.blocks]
             self.blocks = [x for x in add_req_state.block_ids[0]]
             self._update_slot(add_blocks)
         else:
             self._get_sparse_and_free_block()
             if len(add_req_state.block_ids[0]) != self.sparse_len:
-                add_blocks = [add_req_state.block_ids[0][-2]]
-                self._update_slot(add_blocks)
+                add_blocks = [add_req_state.block_ids[0][-1]]
                 self.blocks += [add_req_state.block_ids[0][-1]]
                 self.sparse_len = len(add_req_state.block_ids[0])
+                self._update_slot(add_blocks)
             else:
                 self.calc_block_table = []
                 self.calc_repre_slot_mapping = []

     def _get_sparse_and_free_block(self):
         if self.num_prompt_tokens == self.num_computed_tokens:
             blocks_len = len(self.blocks)
-            if self.num_prompt_tokens > SEG_PREFILL_THRESHOLD:
+            if self.num_prompt_tokens > SEG_PREFILL_THRESHOLD and PTOPK_PREFETCH_ENABLE:
                 remain_len = compute_topk_len(blocks_len)
                 if remain_len > MAX_TOPK_LEN:
                     prefetch_len = 0
@@ -176,10 +169,7 @@ def _get_sparse_and_free_block(self):
                 self.prefetch_idx = remain_blocks_idx[
                     remain_len - LOCAL_WINDOW_SZ : -LOCAL_WINDOW_SZ
                 ]
-                if PTOPK_PREFETCH_ENABLE:
-                    self.sparse_len = remain_len + prefetch_len
-                else:
-                    self.sparse_len = blocks_len
+                self.sparse_len = remain_len + prefetch_len
             else:
                 self.remain_idx = list(range(blocks_len))
                 self.prefetch_idx = []
```
```diff
@@ -190,14 +180,14 @@ def _get_sparse_and_free_block(self):
             self.prefetch_idx = None

     def _init_slot(self, offset: int) -> None:
+        self.repre_slot_mapping = list(range(len(self.blocks)))
+        self.repre_slot_mapping = [x + offset for x in self.repre_slot_mapping]
         if self.is_last_chunk():
-            self.repre_slot_mapping = list(range(len(self.blocks) - 1))
             self.calc_block_table = [x for x in self.blocks[:-1]]
+            self.calc_repre_slot_mapping = [x for x in self.repre_slot_mapping[:-1]]
         else:
-            self.repre_slot_mapping = list(range(len(self.blocks)))
             self.calc_block_table = [x for x in self.blocks]
-        self.repre_slot_mapping = [x + offset for x in self.repre_slot_mapping]
-        self.calc_repre_slot_mapping = [x for x in self.repre_slot_mapping]
+            self.calc_repre_slot_mapping = [x for x in self.repre_slot_mapping]

         value = len(self.blocks)
         one_mask = [False] * value
```
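After this change, every block, including the trailing partial one, gets an offset-adjusted representation slot up front; only the calc lists skip the last block on the final chunk. A minimal sketch of the resulting values (block ids and offset are made up):

```python
# Hypothetical values: three blocks, slot offset 100, last chunk of prefill.
blocks = [10, 11, 12]
offset = 100
repre_slot_mapping = [i + offset for i in range(len(blocks))]  # [100, 101, 102]

is_last_chunk = True
if is_last_chunk:
    calc_block_table = blocks[:-1]                     # [10, 11]: skip partial tail
    calc_repre_slot_mapping = repre_slot_mapping[:-1]  # [100, 101]
else:
    calc_block_table = blocks[:]
    calc_repre_slot_mapping = repre_slot_mapping[:]
```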
```diff
@@ -224,8 +214,20 @@ def _update_slot(
             self.include_mask.append(True)
             self.exclude_mask.append(False)
         if add_len > 0:
-            self.calc_block_table = [x for x in add_blocks]
-            self.calc_repre_slot_mapping = self.repre_slot_mapping[add_len * -1 :]
+            if self.stage() == SequenceStage.PREFILL:
+                if self.is_last_chunk():
+                    self.calc_block_table = [x for x in add_blocks[:-1]]
+                    self.calc_repre_slot_mapping = self.repre_slot_mapping[
+                        add_len * -1 : -1
+                    ]
+                else:
+                    self.calc_block_table = [x for x in add_blocks]
+                    self.calc_repre_slot_mapping = self.repre_slot_mapping[
+                        add_len * -1 :
+                    ]
+            else:
+                self.calc_block_table = [self.blocks[-1]]
+                self.calc_repre_slot_mapping = [self.repre_slot_mapping[-1]]
         else:
            self.calc_block_table = []
            self.calc_repre_slot_mapping = []
```
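The reworked branch mirrors `_init_slot`: on the last prefill chunk the partial tail block is excluded from the calc tables, while decode recomputes only the newest block/slot pair. A worked example with hypothetical values:

```python
# Hypothetical prefill step: two new blocks arrive on the last chunk.
repre_slot_mapping = [100, 101, 102, 103, 104]
add_blocks = [13, 14]
add_len = len(add_blocks)

calc_block_table = add_blocks[:-1]                         # [13]: partial tail skipped
calc_repre_slot_mapping = repre_slot_mapping[-add_len:-1]  # [103]

# Decode appends exactly one block, so only the newest pair is touched:
blocks = [10, 11, 12, 13, 14]
decode_table = [blocks[-1]]              # [14]
decode_slots = [repre_slot_mapping[-1]]  # [104]
```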
```diff
@@ -269,15 +271,20 @@ def get_model_input(
     def trans_input_tensor(self, scheduler_output: SchedulerOutput):
         calc_block_table = []
         model_input = {}
+        calc_repre_slot_mappings = []
         query_locals = [0]
         for req_id, _ in scheduler_output.num_scheduled_tokens.items():
             calc_block_table += self.gsa_stats[req_id].calc_block_table
+            calc_repre_slot_mappings += self.gsa_stats[req_id].calc_repre_slot_mapping
             query_locals.append(
                 query_locals[-1] + scheduler_output.num_scheduled_tokens[req_id]
             )
         model_input["calc_block_table"] = torch.tensor(
             calc_block_table, dtype=torch.int32, device="cpu"
         )
+        model_input["calc_repre_slot_mapping"] = torch.tensor(
+            calc_repre_slot_mappings, dtype=torch.int32, device="cpu"
+        )
         model_input["query_locals"] = query_locals
         return model_input

```
```diff
@@ -544,7 +551,7 @@ def copy_q(self, query: torch.Tensor, current_layer_id: int) -> None:
             if req_meta.stage() == SequenceStage.DECODE:
                 index_in_batch = req_meta.index_in_batch
                 ids[index_in_batch] = (
-                    self.model_input["query_locals"][index_in_batch] - 1
+                    self.model_input["query_locals"][index_in_batch + 1] - 1
                 )
                 self.gsa_q_cache[current_layer_id][index_in_batch].copy_(
                     query[ids[index_in_batch]]
```
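`query_locals` is a cumulative prefix of scheduled token counts, so a request's last query row is its own end offset minus one; indexing with `index_in_batch` alone landed on the previous request's boundary. A worked example with made-up counts:

```python
# Hypothetical batch: three requests scheduled for 5, 1, and 1 tokens.
query_locals = [0, 5, 6, 7]  # cumulative offsets, as built in trans_input_tensor

index_in_batch = 1           # the second (decode) request
last_row = query_locals[index_in_batch + 1] - 1  # 6 - 1 = 5: its own last token
old_row = query_locals[index_in_batch] - 1       # 5 - 1 = 4: previous request's token
assert (last_row, old_row) == (5, 4)
```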
```diff
@@ -560,12 +567,27 @@ def copy_q(self, query: torch.Tensor, current_layer_id: int) -> None:
     def copy_k(self, layer_name: str, forward_context: ForwardContext) -> None:
         current_layer_id = int(layer_name.split(".")[2])
         block_ids = self.model_input["calc_block_table"]
+        calc_repre_slot_mappings = self.model_input["calc_repre_slot_mapping"]
         if len(block_ids) > 0:
             attn = forward_context.no_compile_layers
-            k_needed = attn[layer_name].kv_cache[forward_context.virtual_engine][0]
-            result = self.gsa_offload_ops.add_copy_req(
-                True, current_layer_id, [], k_needed
+            key_cache_mean_out = (
+                attn[layer_name]
+                .kv_cache[forward_context.virtual_engine][0][block_ids]
+                .mean(dim=1, keepdim=True)
+                .cpu()
             )
+            self.prefetch_engine.kpre_caches[current_layer_id][
+                calc_repre_slot_mappings
+            ].copy_(key_cache_mean_out)
+            k_needed = attn[layer_name].kv_cache[forward_context.virtual_engine][0]
+            self.gsa_offload_ops.add_copy_req(True, current_layer_id, [], k_needed)
+
+        # if len(block_ids) > 0:
+        #     attn = forward_context.no_compile_layers
+        #     k_needed = attn[layer_name].kv_cache[forward_context.virtual_engine][0]
+        #     self.gsa_offload_ops.add_copy_req(
+        #         True, current_layer_id, [], k_needed
+        #     )

     def attention_begin(
         self,
```
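With the CPU staging in `cal_kpre_and_topk.cpp` commented out, the per-block key representative is now a mean taken directly on the GPU cache and written into the prefetch engine's kpre slots. A self-contained sketch of that reduction, assuming a kv-cache layout of `[num_blocks, block_size, num_heads, head_size]` (the layout and sizes here are assumptions, not stated in the diff):

```python
import torch

# Hypothetical shapes; the real layout comes from the attention backend.
num_blocks, block_size, num_heads, head_size = 8, 128, 4, 64
key_cache = torch.randn(num_blocks, block_size, num_heads, head_size)

block_ids = torch.tensor([0, 3, 5])  # plays the role of calc_block_table
kpre = key_cache[block_ids].mean(dim=1, keepdim=True).cpu()
print(kpre.shape)  # torch.Size([3, 1, 4, 64]): one representative row per block
```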
```diff
@@ -588,20 +610,20 @@ def attention_begin(

         if isinstance(forward_context.attn_metadata, dict):
             attn_metadata = forward_context.attn_metadata[layer_name]
-            block_tables = attn_metadata.block_table
         else:
             attn_metadata = forward_context.attn_metadata
-            block_tables = attn_metadata.block_tables
         if self.prefetch_engine.atb_gsa_enable:
             if torch.cuda.is_available():
-                block_tables = self.model_input["block_tables_mp"][current_layer_id]
+                attn_metadata.block_table = self.model_input["block_tables_mp"][
+                    current_layer_id
+                ]
                 attn_metadata.seq_lens = self.model_input["gsa_seq_len"][
                     current_layer_id
                 ]
             else:
-                block_tables[: len(self.prefetch_engine.req_ids_bs)].copy_(
-                    self.model_input["block_tables_mp"][current_layer_id]
-                )
+                attn_metadata.block_tables[
+                    : len(self.prefetch_engine.req_ids_bs)
+                ].copy_(self.model_input["block_tables_mp"][current_layer_id])
                 attn_metadata.seq_lens.copy_(
                     self.model_input["gsa_seq_len"][current_layer_id]
                 )
```
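The CUDA branch previously assigned the per-layer table to the local name `block_tables`, which only rebinds the variable and never touches `attn_metadata`; writing through the attribute (or via `copy_` in the in-place branch) is what actually reaches the kernel. A minimal illustration of the difference (the class and values are hypothetical):

```python
class AttnMetadata:  # stand-in for the real metadata object
    def __init__(self):
        self.block_table = [1, 2, 3]

meta = AttnMetadata()
bt = meta.block_table
bt = [7, 8, 9]                        # rebinds the local name only
assert meta.block_table == [1, 2, 3]  # metadata unchanged: the old bug

meta.block_table = [7, 8, 9]          # the fix: assign through the attribute
assert meta.block_table == [7, 8, 9]
```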
```diff
@@ -734,6 +756,7 @@ def execute_begin(self, scheduler_output: SchedulerOutput):
             self.gsa_metadata,
             is_topk_done,
         )
+        self.gsa_stats = self.gsa_metadata.gsa_stats
         self._start_topk_cal()

     def execute_finished(self):
```

ucm/ucm_sparse/prefetch_engine.py (1 addition, 1 deletion)
```diff
@@ -153,7 +153,7 @@ def model_input_del(
         list_topk_buf = list(topk_buf_tmp.unbind(dim=0))
         list_block_table = list(block_table_tmp.unbind(dim=0))
         gsa_len_list = list(gen_len_tmp.unbind(dim=0))
-        self.is_topk_cal = is_topk_done and self.prefetch_space == 3
+        self.is_topk_cal = is_topk_done and self.num_token % 3 == 0
         gsa_model_input["topk_caches"] = list_topk_buf
         gsa_model_input["kpre_caches"] = self.kpre_caches
         gsa_model_input["is_topk"] = self.is_topk_cal
```
