diff --git a/CMakeLists.txt b/CMakeLists.txt index f18bca2f..1bdcb25e 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -32,20 +32,20 @@ if(USE_NPU) if(DEVICE_TYPE STREQUAL "USE_A3") message("downloading a3 arm xllm kernels") file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a3.arm.rpm" + "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a3.arm.rpm" "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" ) else() if(DEVICE_ARCH STREQUAL "ARM") message("downloading a2 arm xllm_kernels") file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.arm.rpm" + "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.arm.rpm" "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" ) else() message("downloading a2 x86 xllm_kernels") file(DOWNLOAD - "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.x86.rpm" + "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.x86.rpm" "${CMAKE_BINARY_DIR}/xllm_kernels.rpm" ) endif() diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp index ffdf3792..b9c791a5 100755 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp @@ -321,6 +321,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters( param.mlpLinearTransposeType = {-1, -1, -1, -1}; + param.enableSplitFuse = + (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill; if (quantize_type_.empty()) { param.moeLinearTransposeType = std::vector{1, 1, -1, 1}; } else { @@ -894,7 +896,7 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward( std::atomic<bool>* event_flag, int node_id) { atb::Status st; - if 
(input_params.global_empty_kv_cache) { + if (!input_params.batch_forward_type.is_decode()) { build_node_variant_pack(prefill_node_, x, cos_pos, @@ -997,6 +999,14 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack( atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots); } + if (is_prefill && + (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) { + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16) = + atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens); + node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16).hostData = + const_cast<int32_t*>(input_params.q_seq_lens_vec.data()); + } + for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) { CHECK_THROW(node.inTensors.at(i) == nullptr, model_name_ << " inTensor " << i << " is NULL"); diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index 88e19f33..8ae06904 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -154,10 +154,9 @@ class Qwen3MoeModelImpl : public torch::nn::Module { model_args.rope_theta(), options); - max_seq_len_ = model_args.max_position_embeddings(); #if defined(USE_NPU) atb_pos_emb_ = layer::PosEmbedding(context); - int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984; + int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1; attn_mask_ = layer::AttentionMask(options.device(), options.dtype().toScalarType(), /*mask_value=*/mask_value); @@ -251,11 +250,30 @@ class Qwen3MoeModelImpl : public torch::nn::Module { } torch::Tensor attn_mask; - if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) { - attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_); - } else { - attn_mask = attn_mask_.gen_free_mask( - num_speculative_tokens_ + 1, dtype_, device_); + max_seq_len_ = FLAGS_enable_chunked_prefill + ? 
std::max(input_params.kv_max_seq_len, max_seq_len_) + : 128; + if (FLAGS_enable_chunked_prefill) { + attn_mask = attn_mask_.get_attn_mask( + max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); + + int batch_size = input_params.q_seq_lens_vec.size(); + if (batch_size > 0) { + std::vector<torch::Tensor> req_mask_vec; + req_mask_vec.reserve(batch_size); + + for (int j = 0; j < batch_size; j++) { + int start = + input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j]; + int end = input_params.kv_seq_lens_vec[j]; + + auto req_mask_slice = attn_mask.slice(0, start, end); + req_mask_vec.emplace_back(req_mask_slice); + } + attn_mask = torch::cat(req_mask_vec, 0); + } + } else if (input_params.global_empty_kv_cache) { + attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_); } auto deep_stacks = input_params.deep_stacks; int deep_stack_size = deep_stacks.size();