6 changes: 3 additions & 3 deletions CMakeLists.txt
@@ -32,20 +32,20 @@ if(USE_NPU)
   if(DEVICE_TYPE STREQUAL "USE_A3")
     message("downloading a3 arm xllm kernels")
     file(DOWNLOAD
-      "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a3.arm.rpm"
+      "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a3.arm.rpm"
       "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
     )
   else()
     if(DEVICE_ARCH STREQUAL "ARM")
       message("downloading a2 arm xllm_kernels")
       file(DOWNLOAD
-        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.arm.rpm"
+        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.arm.rpm"
         "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
       )
     else()
       message("downloading a2 x86 xllm_kernels")
       file(DOWNLOAD
-        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.3-Linux.a2.x86.rpm"
+        "https://9n-das-tools.s3.cn-north-1.jdcloud-oss.com/xllm-ai/xllm_kernels/0.7.0/xllm_kernels-1.3.4-Linux.a2.x86.rpm"
         "${CMAKE_BINARY_DIR}/xllm_kernels.rpm"
       )
     endif()
12 changes: 11 additions & 1 deletion xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
@@ -321,6 +321,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters(

   param.mlpLinearTransposeType = {-1, -1, -1, -1};
 
+  param.enableSplitFuse =
+      (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache) && is_prefill;
   if (quantize_type_.empty()) {
     param.moeLinearTransposeType = std::vector<int>{1, 1, -1, 1};
   } else {
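For context, split-fuse (chunked prefill) processes a long prompt a fixed number of tokens at a time, so prefill work can be interleaved with decode steps instead of monopolizing the device. A minimal sketch of the chunking idea, independent of the ATB kernel API (`split_prompt` and `chunk_size` are illustrative assumptions, not xllm names):

```cpp
#include <algorithm>
#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical illustration of split-fuse / chunked prefill: each forward
// pass appends one chunk of at most `chunk_size` tokens to the KV cache,
// rather than prefilling the whole prompt in a single pass.
std::vector<std::vector<int32_t>> split_prompt(
    const std::vector<int32_t>& prompt_tokens, std::size_t chunk_size) {
  std::vector<std::vector<int32_t>> chunks;
  for (std::size_t off = 0; off < prompt_tokens.size(); off += chunk_size) {
    const std::size_t len = std::min(chunk_size, prompt_tokens.size() - off);
    chunks.emplace_back(prompt_tokens.begin() + off,
                        prompt_tokens.begin() + off + len);
  }
  return chunks;
}
```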
@@ -894,7 +896,7 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
     std::atomic<bool>* event_flag,
     int node_id) {
   atb::Status st;
-  if (input_params.global_empty_kv_cache) {
+  if (!input_params.batch_forward_type.is_decode()) {
     build_node_variant_pack(prefill_node_,
                             x,
                             cos_pos,
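This dispatch change matters because, with chunked prefill or prefix cache enabled, a prefill step can run while the KV cache is already partially populated, so `global_empty_kv_cache` no longer identifies prefill reliably. A hedged sketch of what a batch forward type could look like (only `is_decode()` appears in the diff; the enum values are assumptions and the real xllm definition may differ):

```cpp
// Hypothetical sketch of a batch-level forward type. The prefill graph is
// selected for anything that is not a pure decode step.
class BatchForwardType {
 public:
  enum Value { kPrefill, kChunkedPrefill, kDecode };
  explicit BatchForwardType(Value v) : value_(v) {}
  bool is_decode() const { return value_ == kDecode; }

 private:
  Value value_;
};
```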
@@ -997,6 +999,14 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack(
         atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots);
   }
 
+  if (is_prefill &&
+      (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) {
+    node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16) =
+        atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens);
+    node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 16).hostData =
+        const_cast<int32_t*>(input_params.q_seq_lens_vec.data());
+  }
+
   for (size_t i = 0; i < WEIGHT_COUNT_PER_LAYER; ++i) {
     CHECK_THROW(node.inTensors.at(i) == nullptr,
                 model_name_ << " inTensor " << i << " is NULL");
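The extra input at `WEIGHT_COUNT_PER_LAYER + 16` carries the per-request query lengths; the device tensor is the kernel input, while `hostData` exposes the same values host-side, presumably so the op can read the lengths without a device sync. A sketch of how such a host/device pair might be assembled (the helper names are illustrative, not the xllm API; note the host vector must outlive the op run because `hostData` aliases it):

```cpp
#include <cstdint>
#include <vector>

#include <torch/torch.h>

// Hypothetical helper mirroring the pairing above: keep the host vector
// alive so hostData can alias it, and hand a matching int32 device tensor
// to the kernel as the actual input.
struct QSeqLens {
  std::vector<int32_t> host;  // aliased via hostData; must outlive the op run
  torch::Tensor device;       // passed as the extra input tensor
};

QSeqLens make_q_seq_lens(const std::vector<int32_t>& q_lens,
                         const torch::Device& device) {
  QSeqLens out;
  out.host = q_lens;
  out.device =
      torch::tensor(out.host, torch::dtype(torch::kInt32)).to(device);
  return out;
}
```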
32 changes: 25 additions & 7 deletions xllm/models/llm/qwen3_moe.h
@@ -154,10 +154,9 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
                          model_args.rope_theta(),
                          options);
 
-    max_seq_len_ = model_args.max_position_embeddings();
 #if defined(USE_NPU)
     atb_pos_emb_ = layer::PosEmbedding(context);
-    int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
+    int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1;
     attn_mask_ = layer::AttentionMask(options.device(),
                                       options.dtype().toScalarType(),
                                       /*mask_value=*/mask_value);
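The mask value now depends on the chunked-prefill flag rather than the dtype: with chunked prefill the mask appears to be additive, filling masked positions with a large negative value (-9984) instead of the boolean-style 1. A worked illustration of the general additive causal-mask pattern with libtorch (this shows the pattern only, not the xllm `layer::AttentionMask` implementation):

```cpp
#include <torch/torch.h>

// General pattern for an additive causal mask: future positions hold a large
// negative value, visible positions hold zero, and the mask is added to the
// attention scores before softmax.
torch::Tensor make_causal_mask(int64_t seq_len,
                               double mask_value,
                               torch::ScalarType dtype,
                               const torch::Device& device) {
  auto opts = torch::TensorOptions().dtype(dtype).device(device);
  auto mask = torch::full({seq_len, seq_len}, mask_value, opts);
  return torch::triu(mask, /*diagonal=*/1);  // zero on and below the diagonal
}
```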
@@ -251,11 +250,30 @@
     }
 
     torch::Tensor attn_mask;
-    if (num_speculative_tokens_ == 0 || input_params.global_empty_kv_cache) {
-      attn_mask = attn_mask_.get_attn_mask(128, dtype_, device_);
-    } else {
-      attn_mask = attn_mask_.gen_free_mask(
-          num_speculative_tokens_ + 1, dtype_, device_);
+    max_seq_len_ = FLAGS_enable_chunked_prefill
+                       ? std::max(input_params.kv_max_seq_len, max_seq_len_)
+                       : 128;
+    if (FLAGS_enable_chunked_prefill) {
+      attn_mask = attn_mask_.get_attn_mask(
+          max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device());
+
+      int batch_size = input_params.q_seq_lens_vec.size();
+      if (batch_size > 0) {
+        std::vector<torch::Tensor> req_mask_vec;
+        req_mask_vec.reserve(batch_size);
+
+        for (int j = 0; j < batch_size; j++) {
+          int start =
+              input_params.kv_seq_lens_vec[j] - input_params.q_seq_lens_vec[j];
+          int end = input_params.kv_seq_lens_vec[j];
+
+          auto req_mask_slice = attn_mask.slice(0, start, end);
+          req_mask_vec.emplace_back(req_mask_slice);
+        }
+        attn_mask = torch::cat(req_mask_vec, 0);
+      }
+    } else if (input_params.global_empty_kv_cache) {
+      attn_mask = attn_mask_.get_attn_mask(max_seq_len_, dtype_, device_);
     }
     auto deep_stacks = input_params.deep_stacks;
     int deep_stack_size = deep_stacks.size();
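The per-request slicing above cuts each request's rows out of one shared causal mask: a request with `kv_len` total positions and `q_len` new query tokens owns rows [kv_len - q_len, kv_len), so a chunk attends to every cached position plus the causal prefix of itself. For example, with kv_len = 7 and q_len = 3 the slice keeps rows 4..6. A standalone sketch of the same logic (names and shapes are illustrative):

```cpp
#include <cstddef>
#include <vector>

#include <torch/torch.h>

// Standalone sketch of the per-request mask assembly: slice each request's
// query rows out of one shared [max_seq_len, max_seq_len] causal mask and
// concatenate them along dim 0.
torch::Tensor gather_request_masks(const torch::Tensor& full_mask,
                                   const std::vector<int>& q_seq_lens,
                                   const std::vector<int>& kv_seq_lens) {
  std::vector<torch::Tensor> rows;
  rows.reserve(q_seq_lens.size());
  for (std::size_t j = 0; j < q_seq_lens.size(); ++j) {
    const int start = kv_seq_lens[j] - q_seq_lens[j];
    const int end = kv_seq_lens[j];
    rows.push_back(full_mask.slice(/*dim=*/0, start, end));
  }
  return torch::cat(rows, /*dim=*/0);
}
```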