@@ -41,7 +41,7 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module {
     dtype_ = options.dtype().toScalarType();
     num_speculative_tokens_ = model_args.num_speculative_tokens();
     embed_tokens_ =
-        register_module("embed_tokens", layer::WordEmbedding(context));
+        register_module("embed_tokens", layer::NpuWordEmbedding(context));
 
     atb_pos_emb_ = layer::PosEmbedding(context);
     cos_sin_ = get_concat_rotary_embedding(64,
@@ -206,11 +206,12 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module {
     final_norm_->merge_loaded_weights();
   }
 
-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return {embed_tokens_};
   }
 
-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     embed_tokens_ = word_embedding[0];
   }
 
@@ -226,7 +227,7 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module {
   int32_t num_speculative_tokens_ = 0;
   at::Device device_;
   torch::Dtype dtype_;
-  layer::WordEmbedding embed_tokens_{nullptr};
+  layer::NpuWordEmbedding embed_tokens_{nullptr};
   layer::AttentionMask attn_mask_;
   torch::Tensor cos_sin_;
   layer::PosEmbedding atb_pos_emb_{nullptr};
@@ -289,11 +290,12 @@ class Glm4MoeMtpForCausalLMImpl : public torch::nn::Module {
 
   void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }
 
-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
   }
 
-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     model_->set_word_embedding(word_embedding);
   }
 