Skip to content

Commit fe88df9

Browse files
yingxudengliutongxuan
authored andcommitted
refactor: optimize QKV tensor index lookup using std::unordered_set.
1 parent b44d392 commit fe88df9

File tree

1 file changed

+14
-5
lines changed

1 file changed

+14
-5
lines changed

xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,8 @@ limitations under the License.
1717

1818
#include <gflags/gflags.h>
1919

20+
#include <unordered_set>
21+
2022
#include "common/global_flags.h"
2123

2224
namespace xllm {
@@ -555,11 +557,18 @@ void NpuQwen3MoeDecoderLayerImpl::process_general_weights(
555557
int32_t tp_rank = dp_local_tp_rank_;
556558
int32_t tp_size = dp_local_tp_size_;
557559

558-
if (index == IN_QKV_WEIGHT_1 || index == IN_QKV_WEIGHT_2 ||
559-
index == IN_QKV_BIAS_1 || index == IN_QKV_BIAS_2 ||
560-
index == IN_QKV_DESCALE_1 || index == IN_QKV_DESCALE_2 ||
561-
index == IN_QKV_OFFSET_1 || index == IN_QKV_OFFSET_2 ||
562-
index == IN_QKV_SCALE_1 || index == IN_QKV_SCALE_2) {
560+
static const std::unordered_set<int> qkv_tensor_indices = {IN_QKV_WEIGHT_1,
561+
IN_QKV_WEIGHT_2,
562+
IN_QKV_BIAS_1,
563+
IN_QKV_BIAS_2,
564+
IN_QKV_DESCALE_1,
565+
IN_QKV_DESCALE_2,
566+
IN_QKV_OFFSET_1,
567+
IN_QKV_OFFSET_2,
568+
IN_QKV_SCALE_1,
569+
IN_QKV_SCALE_2};
570+
571+
if (qkv_tensor_indices.count(index) > 0) {
563572
if (n_kv_heads_ < dp_local_tp_size_) {
564573
int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_);
565574

0 commit comments

Comments
 (0)