
Commit e20653b

feat: implement chunked prefill and prefix cache for Qwen3 MoE.

1 parent 45f8a0a

3 files changed: +41 -11 lines

CMakeLists.txt (3 additions, 3 deletions)

```diff
@@ -32,20 +32,20 @@ if(USE_NPU)
   if(DEVICE_TYPE STREQUAL "USE_A3")
     message("downloading a3 arm xllm kernels")
     file(DOWNLOAD
-      "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a3.arm.rpm"
+      "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a3.arm.rpm"
       "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
     )
   else()
     if(DEVICE_ARCH STREQUAL "ARM")
       message("downloading a2 arm xllm_kernels")
       file(DOWNLOAD
-        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.arm.rpm"
+        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.arm.rpm"
         "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
       )
     else()
       message("downloading a2 x86 xllm_kernels")
       file(DOWNLOAD
-        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.x86.rpm"
+        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.x86.rpm"
         "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
       )
     endif()
```

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp (13 additions, 1 deletion)

```diff
@@ -317,6 +317,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters(
 
   param.mlpLinearTransposeType = {-1, -1, -1, -1};
 
+  param.enableSplitFuse =
+      (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
   if (quantize_type_.empty()) {
     param.moeLinearTransposeType = std::vector<int>{1, 1, -1, 1};
   } else {
@@ -865,7 +867,9 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
     std::atomic<bool>* event_flag,
     int node_id) {
   atb::Status st;
-  if (input_params.global_empty_kv_cache) {
+  bool is_prefill = input_params.decode_seq_range.second !=
+                    input_params.q_seq_lens.size(0) - 1;
+  if (is_prefill) {
     build_node_variant_pack(prefill_node_,
                             x,
                             cos_pos,
@@ -968,6 +972,14 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack(
         atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots);
   }
 
+  if (is_prefill &&
+      (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) {
+    node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16) =
+        atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens);
+    node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16).hostData =
+        const_cast<int32_t*>(input_params.q_seq_lens_vec.data());
+  }
+
   for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
     CHECK_THROW(node.inTensors.at(i) == nullptr,
                 model_name_ << " inTensor " << i << " is NULL");
```
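Two things are worth noting in this file. First, the prefill/decode dispatch no longer keys off `global_empty_kv_cache`: with chunked prefill or a prefix cache enabled, the KV cache can already hold earlier chunks while the current step is still a prefill, so the layer now derives the phase from `decode_seq_range` instead. Second, when split-fuse is active the kernel consumes the per-request query lengths both as a device tensor and as host-side data, which is why `hostData` is pointed at `q_seq_lens_vec`. The sketch below restates the new phase check in isolation; `BatchInfo` and the precise semantics of `decode_seq_range` (taken here to be an inclusive index range of the decode-phase sequences in the batch) are illustrative assumptions, not xllm's actual types.

```cpp
#include <cstdint>
#include <utility>
#include <vector>

// Illustrative stand-in for the relevant fields of xllm's input params.
struct BatchInfo {
  std::vector<int32_t> q_seq_lens;               // per-sequence query lengths
  std::pair<int32_t, int32_t> decode_seq_range;  // assumed: inclusive index
                                                 // range of decode sequences
};

// Mirrors the check added in forward(): the step runs the prefill graph
// unless the decode range ends at the last batch index, i.e. unless every
// sequence in the batch is in the decode phase.
bool is_prefill(const BatchInfo& b) {
  return b.decode_seq_range.second !=
         static_cast<int32_t>(b.q_seq_lens.size()) - 1;
}
```

Under this reading, an all-decode batch (range covering indices 0 through n-1) takes the decode graph, while any batch that still contains a prefill chunk takes the prefill graph, regardless of whether the KV cache is empty.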

xllm/models/llm/qwen3_moe.h (25 additions, 7 deletions)

```diff
@@ -154,10 +154,9 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
                                   model_args.rope_theta(),
                                   options);
 
-    max_seq_len_ = model_args.max_position_embeddings();
 #if defined(USE_NPU)
     atb_pos_emb_ = layer::PosEmbedding(context);
-    int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
+    int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1;
     attn_mask_ = layer::AttentionMask(options.device(),
                                       options.dtype().toScalarType(),
                                       /*mask_value=*/mask_value);
@@ -248,11 +247,30 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
     }
 
     torch::Tensor attn_mask;
-    if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
-      attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
-    } else {
-      attn_mask = attn_mask_.gen_free_mask(
-          num_speculative_tokens_ + 1, dtype_, device_);
+    max_seq_len_ = FLAGS_enable_chunked_prefill
+                       ? std::max(input_params.kv_max_seq_len, max_seq_len_)
+                       : 128;
+    if (FLAGS_enable_chunked_prefill) {
+      attn_mask = attn_mask_.get_attn_mask(
+          max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device());
+
+      int batch_size = input_params.q_seq_lens_vec.size();
+      if (batch_size > 0) {
+        std::vector<torch::Tensor> req_mask_vec;
+        req_mask_vec.reserve(batch_size);
+
+        for (int j = 0; j < batch_size; j++) {
+          int start =
+              input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j];
+          int end = input_params.kv_seq_lens_vec[j];
+
+          auto req_mask_slice = attn_mask.slice(0, start, end);
+          req_mask_vec.emplace_back(req_mask_slice);
+        }
+        attn_mask = torch::cat(req_mask_vec, 0);
+      }
+    } else if (input_params.global_empty_kv_cache) {
+      attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_);
     }
     auto deep_stacks = input_params.deep_stacks;
     int deep_stack_size = deep_stacks.size();
```
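This mask construction is the core of the chunked-prefill support in the model: instead of a fixed 128-wide causal mask, a square mask sized to the longest KV sequence in the batch is materialized, and each request keeps only the rows for its new query tokens, so a token at absolute position p sees its entire cached prefix plus the causal part of the current chunk. Below is a minimal, self-contained restatement in libtorch; it assumes `get_attn_mask()` yields a square [max_seq_len, max_seq_len] causal mask and uses a 0/1 convention (1 = masked) in place of the layer's `mask_value` handling.

```cpp
#include <torch/torch.h>

#include <vector>

// Sketch of the per-request mask slicing, under the assumptions stated
// above (square causal mask, 1 marks positions that must not be attended).
torch::Tensor chunked_prefill_mask(int64_t max_seq_len,
                                   const std::vector<int64_t>& q_seq_lens,
                                   const std::vector<int64_t>& kv_seq_lens) {
  // Full causal mask: row i may attend columns 0..i.
  auto full = torch::ones({max_seq_len, max_seq_len}).triu(/*diagonal=*/1);

  std::vector<torch::Tensor> per_req;
  per_req.reserve(q_seq_lens.size());
  for (size_t j = 0; j < q_seq_lens.size(); ++j) {
    // The first (kv_len - q_len) tokens are already in the KV cache; the
    // rows for the q_len new query tokens start right after them.
    int64_t start = kv_seq_lens[j] - q_seq_lens[j];
    per_req.emplace_back(full.slice(/*dim=*/0, start, kv_seq_lens[j]));
  }
  // [sum(q_seq_lens), max_seq_len]: one row per new query token.
  return torch::cat(per_req, /*dim=*/0);
}
```

For example, with `q_seq_lens = {2}` and `kv_seq_lens = {5}` the two returned rows let the two new tokens attend positions 0..3 and 0..4 respectively, which is exactly the visibility a second prefill chunk needs over a 3-token cached prefix.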
