1 file changed: +14 -5 lines changed
Original file line number | Diff line number | Diff line change
@@ -17,6 +17,8 @@ limitations under the License.
1717
1818#include < gflags/gflags.h>
1919
20+ #include < unordered_set>
21+
2022#include " common/global_flags.h"
2123
2224namespace xllm {
@@ -555,11 +557,18 @@ void NpuQwen3MoeDecoderLayerImpl::process_general_weights(
555557 int32_t tp_rank = dp_local_tp_rank_;
556558 int32_t tp_size = dp_local_tp_size_;
557559
558- if (index == IN_QKV_WEIGHT_1 || index == IN_QKV_WEIGHT_2 ||
559- index == IN_QKV_BIAS_1 || index == IN_QKV_BIAS_2 ||
560- index == IN_QKV_DESCALE_1 || index == IN_QKV_DESCALE_2 ||
561- index == IN_QKV_OFFSET_1 || index == IN_QKV_OFFSET_2 ||
562- index == IN_QKV_SCALE_1 || index == IN_QKV_SCALE_2) {
560+ static const std::unordered_set<int > qkv_tensor_indices = {IN_QKV_WEIGHT_1,
561+ IN_QKV_WEIGHT_2,
562+ IN_QKV_BIAS_1,
563+ IN_QKV_BIAS_2,
564+ IN_QKV_DESCALE_1,
565+ IN_QKV_DESCALE_2,
566+ IN_QKV_OFFSET_1,
567+ IN_QKV_OFFSET_2,
568+ IN_QKV_SCALE_1,
569+ IN_QKV_SCALE_2};
570+
571+ if (qkv_tensor_indices.count (index) > 0 ) {
563572 if (n_kv_heads_ < dp_local_tp_size_) {
564573 int32_t repeat_times = (dp_local_tp_size_ / n_kv_heads_);
565574
You can’t perform that action at this time.
0 commit comments