Commit fb410ee

feat: implement chunked prefill and prefix cache for Qwen3 MoE.
1 parent 25e16fa commit fb410ee

File tree

CMakeLists.txt
xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
xllm/models/llm/qwen3_moe.h

3 files changed: +41 -11 lines

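Context for the changes below: with chunked prefill, a long prompt is run through the model in several passes, each handling only a slice of new tokens; with prefix cache, tokens whose KV entries are already cached are skipped. In both cases a prefill pass sees fewer query tokens (q_seq_len) than context tokens (kv_seq_len), and that distinction is what this commit threads through the NPU layers. A minimal standalone sketch of the accounting, with prompt length and chunk size as assumed values (not from this commit):

// chunked_prefill_accounting.cpp - illustration only, not part of this commit.
#include <algorithm>
#include <cstdio>

int main() {
  const int prompt_len = 10000;  // prompt tokens (assumed value)
  const int chunk_size = 4096;   // per-pass token budget (assumed value)
  for (int cached = 0; cached < prompt_len; cached += chunk_size) {
    // q_seq_len: new tokens in this pass; kv_seq_len: everything attended to,
    // i.e. the cached prefix plus the new tokens.
    int q_seq_len = std::min(chunk_size, prompt_len - cached);
    int kv_seq_len = cached + q_seq_len;
    std::printf("pass: q_seq_len=%d kv_seq_len=%d\n", q_seq_len, kv_seq_len);
  }
  return 0;
}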

CMakeLists.txt

Lines changed: 3 additions & 3 deletions
@@ -32,20 +32,20 @@ if(USE_NPU)
   if(DEVICE_TYPE STREQUAL "USE_A3")
     message("downloading a3 arm xllm kernels")
     file(DOWNLOAD
-      "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a3.arm.rpm"
+      "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a3.arm.rpm"
       "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
     )
   else()
     if(DEVICE_ARCH STREQUAL "ARM")
       message("downloading a2 arm xllm_kernels")
       file(DOWNLOAD
-        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.arm.rpm"
+        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.arm.rpm"
         "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
       )
     else()
       message("downloading a2 x86 xllm_kernels")
       file(DOWNLOAD
-        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.x86.rpm"
+        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.x86.rpm"
         "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
       )
     endif()

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp

Lines changed: 13 additions & 1 deletion
@@ -321,6 +321,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters(
 
   param.mlpLinearTransposeType = {-1, -1, -1, -1};
 
+  param.enableSplitFuse =
+      (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
   if (quantize_type_.empty()) {
    param.moeLinearTransposeType = std::vector<int>{1, 1, -1, 1};
  } else {

@@ -894,7 +896,9 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
     std::atomic<bool>* event_flag,
     int node_id) {
   atb::Status st;
-  if (input_params.global_empty_kv_cache) {
+  bool is_prefill = input_params.decode_seq_range.second !=
+                    input_params.q_seq_lens.size(0) - 1;
+  if (is_prefill) {
     build_node_variant_pack(prefill_node_,
                             x,
                             cos_pos,

@@ -997,6 +1001,14 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack(
         atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots);
   }
 
+  if (is_prefill &&
+      (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) {
+    node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16) =
+        atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens);
+    node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16).hostData =
+        const_cast<int32_t*>(input_params.q_seq_lens_vec.data());
+  }
+
   for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
     CHECK_THROW(node.inTensors.at(i) == nullptr,
                 model_name_ << " inTensor " << i << " is NULL");
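Two things to note in this diff. First, prefill is no longer inferred from global_empty_kv_cache: with prefix cache enabled the KV cache can be non-empty during a prefill pass, so the forward path instead checks whether the batch still contains sequences outside the decode range. Second, the split-fuse path feeds the per-request query lengths to the layer both as a device tensor and via hostData at input slot WEIGHT_COUNT_PER_LAYER + 16, apparently so the kernel can read the lengths host-side. A minimal sketch of the new prefill test, assuming decode_seq_range = [first, last] marks the contiguous tail of decode-phase sequences in the batch (that convention is inferred from the diff, not stated in it):

// is_prefill sketch - illustration only. If the decode tail does not reach
// the final sequence (index n_seqs - 1), at least one sequence in the batch
// still has prompt tokens to prefill.
#include <cstdint>
#include <utility>

bool batch_is_prefill(std::pair<int64_t, int64_t> decode_seq_range,
                      int64_t n_seqs) {
  return decode_seq_range.second != n_seqs - 1;
}

Under that reading, a pure-decode batch has decode_seq_range.second == n_seqs - 1 and takes the decode node path; any mixed or prefill batch takes prefill_node_.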

xllm/models/llm/qwen3_moe.h

Lines changed: 25 additions & 7 deletions
@@ -154,10 +154,9 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
                                      model_args.rope_theta(),
                                      options);
 
-    max_seq_len_ = model_args.max_position_embeddings();
 #if defined(USE_NPU)
     atb_pos_emb_ = layer::PosEmbedding(context);
-    int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
+    int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1;
     attn_mask_ = layer::AttentionMask(options.device(),
                                       options.dtype().toScalarType(),
                                       /*mask_value=*/mask_value);

@@ -251,11 +250,30 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
     }
 
     torch::Tensor attn_mask;
-    if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
-      attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
-    } else {
-      attn_mask = attn_mask_.gen_free_mask(
-          num_speculative_tokens_ + 1, dtype_, device_);
+    max_seq_len_ = FLAGS_enable_chunked_prefill
+                       ? std::max(input_params.kv_max_seq_len, max_seq_len_)
+                       : 128;
+    if (FLAGS_enable_chunked_prefill) {
+      attn_mask = attn_mask_.get_attn_mask(
+          max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device());
+
+      int batch_size = input_params.q_seq_lens_vec.size();
+      if (batch_size > 0) {
+        std::vector<torch::Tensor> req_mask_vec;
+        req_mask_vec.reserve(batch_size);
+
+        for (int j = 0; j < batch_size; j++) {
+          int start =
+              input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j];
+          int end = input_params.kv_seq_lens_vec[j];
+
+          auto req_mask_slice = attn_mask.slice(0, start, end);
+          req_mask_vec.emplace_back(req_mask_slice);
+        }
+        attn_mask = torch::cat(req_mask_vec, 0);
+      }
+    } else if (input_params.global_empty_kv_cache) {
+      attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_);
     }
     auto deep_stacks = input_params.deep_stacks;
     int deep_stack_size = deep_stacks.size();
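The mask logic here: with chunked prefill enabled, one causal mask of size max_seq_len_ (grown to cover kv_max_seq_len) is built, and each request contributes the row range [kv_len - q_len, kv_len). Row p of a causal mask lets a token attend to positions 0..p, so a chunk at those rows sees the entire cached prefix plus a causal view of itself. (The constructor change above makes the mask fill value -9984 whenever chunked prefill is on, rather than keying it off the model dtype.) A toy libtorch reproduction of the slicing, with all lengths as assumed values:

// Toy reproduction of the per-request mask slicing - not from this commit.
#include <torch/torch.h>
#include <iostream>
#include <vector>

int main() {
  const int max_seq_len = 8;
  // 1 marks a masked (future) position; row p leaves columns 0..p visible.
  auto causal =
      torch::triu(torch::ones({max_seq_len, max_seq_len}), /*diagonal=*/1);

  std::vector<int> q_seq_lens = {2, 3};   // new tokens per request (assumed)
  std::vector<int> kv_seq_lens = {5, 3};  // cached + new tokens (assumed)

  std::vector<torch::Tensor> rows;
  for (size_t j = 0; j < q_seq_lens.size(); ++j) {
    int start = kv_seq_lens[j] - q_seq_lens[j];
    rows.push_back(causal.slice(/*dim=*/0, start, kv_seq_lens[j]));
  }
  // Result is [sum(q_seq_lens), max_seq_len], matching the batch layout in
  // which all requests' query tokens are concatenated along dim 0.
  std::cout << torch::cat(rows, /*dim=*/0) << std::endl;
  return 0;
}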
