Skip to content

Commit 25e16fa

Browse files
authored
refactor: replace decode_seq_range with batch_forward_type. (#451)
1 parent fe88df9 commit 25e16fa

10 files changed

+14
-38
lines changed

xllm/core/framework/batch/batch_input_builder.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -554,8 +554,6 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
554554
input_params.q_seq_lens_vec = std::move(state_.q_seq_lens);
555555
input_params.new_cache_slots =
556556
torch::tensor(state_.new_token_slot_ids, torch::kInt);
557-
input_params.decode_seq_range =
558-
util::find_ones_indices(input_params.q_seq_lens_vec);
559557

560558
// for flashinfer
561559
input_params.paged_kv_indptr =

xllm/core/framework/model/model_input_params.h

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,6 @@ struct ModelInputParams {
101101
params.block_tables = safe_to(block_tables, device, true);
102102
params.kv_seq_lens_vec = kv_seq_lens_vec;
103103
params.q_seq_lens_vec = q_seq_lens_vec;
104-
params.decode_seq_range = decode_seq_range;
105104

106105
params.input_embedding = safe_to(input_embedding, device);
107106

@@ -153,7 +152,8 @@ struct ModelInputParams {
153152
<< " , q_max_seq_len is " << q_max_seq_len;
154153
LOG(INFO) << "ModelInputParams: kv_seq_lens_vec is " << kv_seq_lens_vec;
155154
LOG(INFO) << "ModelInputParams: q_seq_lens_vec is " << q_seq_lens_vec;
156-
LOG(INFO) << "ModelInputParams: decode_seq_range is " << decode_seq_range;
155+
LOG(INFO) << "ModelInputParams: batch_forward_type is "
156+
<< batch_forward_type.to_string();
157157
print_tensor(kv_seq_lens, "ModelInputParams: kv_seq_lens", 4);
158158
print_tensor(q_seq_lens, "ModelInputParams: q_seq_lens", 4);
159159
print_tensor(new_cache_slots, "ModelInputParams: new_cache_slots", 4);
@@ -172,15 +172,7 @@ struct ModelInputParams {
172172
torch::Tensor kv_seq_lens;
173173
std::vector<int> kv_seq_lens_vec;
174174
std::vector<int> q_seq_lens_vec;
175-
// Range of decode sequence indices in the batch [start, end].
176-
// Decode sequences are identified by q_seq_lens == 1,
177-
// prefill sequences by q_seq_lens > 1 .
178-
// Used to determine whether to use prefill_node_ or
179-
// decode_node_ in NPU layers
180-
// Values: {-1, -1} if no decode requests (all prefill),
181-
// {0, batch_size-1} if all decode requests,
182-
// {start_idx, end_idx} if mixed prefill/decode requests
183-
std::pair<int, int> decode_seq_range;
175+
184176
// max length for qkv.
185177
int32_t kv_max_seq_len = 0;
186178
int32_t q_max_seq_len = 0;

xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1090,9 +1090,7 @@ torch::Tensor Glm4MoeDecoderImpl::forward(torch::Tensor& x,
10901090
std::atomic<bool>* event_flag,
10911091
int node_id) {
10921092
atb::Status st;
1093-
bool is_prefill = input_params.decode_seq_range.second !=
1094-
input_params.q_seq_lens.size(0) - 1;
1095-
if (is_prefill) {
1093+
if (!input_params.batch_forward_type.is_decode()) {
10961094
build_node_variant_pack(prefill_node_,
10971095
x,
10981096
cos_pos,

xllm/core/layers/npu/npu_llama_decoder_layer_impl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -277,8 +277,7 @@ torch::Tensor NpuLlamaDecoderLayerImpl::forward(torch::Tensor& x,
277277
int node_id) {
278278
atb::Status st;
279279

280-
if (input_params.decode_seq_range.second !=
281-
input_params.q_seq_lens.size(0) - 1) {
280+
if (!input_params.batch_forward_type.is_decode()) {
282281
build_node_variant_pack(prefill_node_,
283282
x,
284283
cos_pos,

xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -404,8 +404,7 @@ torch::Tensor NpuQwen2DecoderLayerImpl::forward(torch::Tensor& x,
404404
std::atomic<bool>* event_flag,
405405
int node_id) {
406406
atb::Status st;
407-
if (input_params.decode_seq_range.second !=
408-
input_params.q_seq_lens.size(0) - 1) {
407+
if (!input_params.batch_forward_type.is_decode()) {
409408
// mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr);
410409
build_node_variant_pack(prefill_node_,
411410
x,

xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -519,8 +519,7 @@ torch::Tensor NpuQwen3DecoderLayerImpl::forward(torch::Tensor& x,
519519
std::atomic<bool>* event_flag,
520520
int node_id) {
521521
atb::Status st;
522-
if (input_params.decode_seq_range.second !=
523-
input_params.q_seq_lens.size(0) - 1) {
522+
if (!input_params.batch_forward_type.is_decode()) {
524523
// if (input_params.empty_kv_cache) {
525524
// mstxRangeId id = mstxRangeStartA("prefill build variant", nullptr);
526525
build_node_variant_pack(prefill_node_,

xllm/core/runtime/acl_graph_executor_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -629,7 +629,7 @@ bool AclGraph::capture(CausalLM* model,
629629
graph_params.q_seq_lens_vec[i] = 1;
630630
}
631631
graph_params.num_sequences = num_tokens_;
632-
graph_params.decode_seq_range = {0, num_tokens_ - 1};
632+
graph_params.batch_forward_type = BatchForwardType::DECODE;
633633

634634
graph_params.new_cache_slots =
635635
persistent_param_.persistent_new_cache_slots(num_tokens_);

xllm/core/runtime/forward_shared_memory_manager.cpp

Lines changed: 2 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -816,12 +816,7 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
816816
forward_input.positions =
817817
create_2d_tensor(std::move(raw_input.m_positions_vec), torch::kInt);
818818
}
819-
std::pair<int, int> decode_seq_range{0, 0};
820-
#if defined(USE_NPU)
821-
if (raw_input.q_seq_lens.size() >= 1) {
822-
decode_seq_range = util::find_ones_indices(raw_input.q_seq_lens);
823-
}
824-
#endif
819+
825820
auto& input_params = forward_input.input_params;
826821
input_params.empty_kv_cache = raw_input.empty_kv_cache;
827822
input_params.global_empty_kv_cache = raw_input.global_empty_kv_cache;
@@ -841,7 +836,7 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
841836

842837
input_params.new_cache_slots =
843838
torch::tensor(std::move(raw_input.new_token_slot_ids), tensor_options);
844-
input_params.decode_seq_range = decode_seq_range;
839+
845840
util::pad_2d_vector(raw_input.block_tables_vec, 0);
846841
input_params.block_tables =
847842
create_2d_tensor(std::move(raw_input.block_tables_vec), torch::kInt);

xllm/core/runtime/params_utils.cpp

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -177,12 +177,7 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
177177
forward_inputs.acc_logprob = torch::tensor(
178178
acc_logprob_vec,
179179
torch::dtype(torch::kFloat32).device(torch::kCPU).pinned_memory(true));
180-
std::pair<int, int> decode_seq_range{0, 0};
181-
#if defined(USE_NPU)
182-
if (q_seq_lens.size() >= 1) {
183-
decode_seq_range = util::find_ones_indices(q_seq_lens);
184-
}
185-
#endif
180+
186181
auto& input_params = forward_inputs.input_params;
187182
input_params.empty_kv_cache = pb_forward_input->empty_kv_cache();
188183
input_params.global_empty_kv_cache =
@@ -206,7 +201,6 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input,
206201

207202
input_params.new_cache_slots =
208203
torch::tensor(new_token_slot_ids, tensor_options);
209-
input_params.decode_seq_range = decode_seq_range;
210204

211205
util::pad_2d_vector(block_tables_vec, /*pad_value=*/0);
212206
input_params.block_tables =

xllm/core/runtime/speculative_worker_impl.cpp

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -559,11 +559,13 @@ void SpeculativeWorkerImpl::prepare_validate_inputs(
559559
input_params.q_seq_lens = torch::tensor(q_seq_lens_vec, int_options);
560560
input_params.new_cache_slots = torch::tensor(new_token_slot_ids, int_options);
561561
if (!FLAGS_enable_atb_spec_kernel) {
562+
input_params.batch_forward_type = BatchForwardType::CHUNKED_PREFILL;
562563
util::pad_2d_vector(block_tables_vec, /*pad_value=*/0);
563564
input_params.block_tables =
564565
create_2d_tensor(block_tables_vec, torch::kInt).to(device_);
566+
} else {
567+
input_params.batch_forward_type = BatchForwardType::DECODE;
565568
}
566-
input_params.decode_seq_range.second = input_params.num_sequences - 1;
567569

568570
// update the sampling_params
569571
update_sampling_params(

0 commit comments

Comments
 (0)