File tree Expand file tree Collapse file tree 3 files changed +6
-3
lines changed
Expand file tree Collapse file tree 3 files changed +6
-3
lines changed Original file line number Diff line number Diff line change @@ -330,7 +330,7 @@ else()
330330endif ()
331331
332332if (USE_NPU)
333- add_definitions (-DUSE_NPU_TORCH)
333+ # add_definitions(-DUSE_NPU_TORCH)
334334 add_definitions (-DUSE_NPU)
335335 add_definitions (-DBUILD_LIBTORCH)
336336 add_definitions (-DTORCH_SETCUSTOMHANDLER=ON )
Original file line number Diff line number Diff line change @@ -45,7 +45,7 @@ class QWen3ModelImpl : public LlmModelImplBase<QWen3DecoderLayer> {
4545 xllm::layer::RmsNorm (
4646 model_args.hidden_size (), model_args.rms_norm_eps (), options));
4747#else
48- norm_ = register_module (" norm" , layer::RmsNorm (context));
48+ norm_ = register_module (" norm" , layer::NpuRmsNorm (context));
4949#endif
5050 for (auto i = 0 ; i < FLAGS_micro_batch_num; i++) {
5151#if defined(USE_NPU_TORCH)
Original file line number Diff line number Diff line change @@ -367,10 +367,13 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
367367 torch::Dtype dtype_;
368368 layer::WordEmbedding embed_tokens_{nullptr };
369369 layer::AttentionMask attn_mask_;
370- layer::RmsNorm norm_{ nullptr };
370+
371371#if defined(USE_NPU)
372372 torch::Tensor cos_sin_;
373373 layer::PosEmbedding atb_pos_emb_{nullptr };
374+ layer::NpuRmsNorm norm_{nullptr };
375+ #else
376+ layer::RmsNorm norm_{nullptr };
374377#endif
375378 std::vector<int64_t > mrope_section_;
376379};
You can’t perform that action at this time.
0 commit comments