Commit 67a2090

refactor: share single WordEmbedding class across NPU and other hardware.
1 parent 4966a97 commit 67a2090

18 files changed: +114 additions, -130 deletions
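For orientation, a minimal sketch of the end state this refactor targets (not code taken from the commit): NPU models construct the ATB-backed layer::NpuWordEmbedding from a ModelContext, while the native layer::WordEmbedding (backed by WordEmbeddingImpl) is now declared unconditionally and shared by GPU and MLU builds. The vocab_size()/hidden_size() accessors below are assumed for illustration; the constructor shapes mirror the diff in word_embedding.h.

#if defined(USE_NPU)
  // ATB path: wraps NpuWordEmbeddingImpl, built from the per-model context.
  embed_tokens_ = layer::NpuWordEmbedding(context);
#else
  // Shared native path: WordEmbeddingImpl, identical for GPU and MLU.
  embed_tokens_ = layer::WordEmbedding(model_args.vocab_size(),
                                       model_args.hidden_size(),
                                       parallel_args,
                                       options);
#endif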

xllm/core/framework/model/causal_lm.h

Lines changed: 4 additions & 4 deletions
@@ -69,9 +69,9 @@ class CausalLM : public torch::nn::Module {
 #if defined(USE_NPU)
   virtual layer::NpuLmHead get_lm_head() = 0;
   virtual void set_lm_head(layer::NpuLmHead& head) = 0;
-  virtual std::vector<layer::WordEmbedding> get_word_embedding() = 0;
+  virtual std::vector<layer::NpuWordEmbedding> get_word_embedding() = 0;
   virtual void set_word_embedding(
-      std::vector<layer::WordEmbedding>& embedding) = 0;
+      std::vector<layer::NpuWordEmbedding>& embedding) = 0;
 #endif
 };

@@ -113,12 +113,12 @@ class CausalLMImpl : public CausalLM {
     model_->set_lm_head(head);
   };

-  std::vector<layer::WordEmbedding> get_word_embedding() override {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() override {
     return model_->get_word_embedding();
   };

   void set_word_embedding(
-      std::vector<layer::WordEmbedding>& embedding) override {
+      std::vector<layer::NpuWordEmbedding>& embedding) override {
     model_->set_word_embedding(embedding);
   };
 #endif

xllm/core/framework/model/causal_vlm.h

Lines changed: 2 additions & 2 deletions
@@ -70,12 +70,12 @@ class CausalVLMImpl : public CausalVLM {
     model_->set_lm_head(head);
   };

-  std::vector<layer::WordEmbedding> get_word_embedding() override {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() override {
     return model_->get_word_embedding();
   };

   void set_word_embedding(
-      std::vector<layer::WordEmbedding>& embedding) override {
+      std::vector<layer::NpuWordEmbedding>& embedding) override {
     model_->set_word_embedding(embedding);
   };
 #endif

xllm/core/layers/word_embedding.h

Lines changed: 3 additions & 25 deletions
@@ -24,35 +24,15 @@ namespace xllm {
 namespace layer {

 #if defined(USE_NPU)
-class WordEmbedding : public torch::nn::ModuleHolder<NpuWordEmbeddingImpl> {
+class NpuWordEmbedding : public torch::nn::ModuleHolder<NpuWordEmbeddingImpl> {
  public:
   using torch::nn::ModuleHolder<NpuWordEmbeddingImpl>::ModuleHolder;
   using Impl __attribute__((__unused__)) = NpuWordEmbeddingImpl;
-  WordEmbedding(const ModelContext& context)
+  NpuWordEmbedding(const ModelContext& context)
       : ModuleHolder(std::make_shared<NpuWordEmbeddingImpl>(context)) {}
 };

-/**
- * TODO: Rename the original WordEmbedding definition to NpuWordEmbedding,
- * and define the current one as WordEmbedding to unify NPU's WordEmbedding
- * related code with MLU and GPU
- */
-
-class WordEmbeddingNative : public torch::nn::ModuleHolder<WordEmbeddingImpl> {
- public:
-  using torch::nn::ModuleHolder<WordEmbeddingImpl>::ModuleHolder;
-  using Impl __attribute__((__unused__)) = WordEmbeddingImpl;
-  WordEmbeddingNative(int64_t num_embeddings,
-                      int64_t embedding_dim,
-                      const ParallelArgs& parallel_args,
-                      const torch::TensorOptions& options)
-      : ModuleHolder(std::make_shared<WordEmbeddingImpl>(num_embeddings,
-                                                         embedding_dim,
-                                                         parallel_args,
-                                                         options)) {}
-};
-
-#else
+#endif

 class WordEmbedding : public torch::nn::ModuleHolder<WordEmbeddingImpl> {
  public:

@@ -68,7 +48,5 @@ class WordEmbedding : public torch::nn::ModuleHolder<WordEmbeddingImpl> {
                                                          options)) {}
 };

-#endif
-
 } // namespace layer
 } // namespace xllm
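The classes touched above are all instances of the libtorch ModuleHolder pattern. Below is a self-contained sketch of that pattern, assuming only libtorch: MyEmbeddingImpl is a hypothetical stand-in for WordEmbeddingImpl (the real implementation also takes ParallelArgs and shards the lookup across ranks), and the holder's nullptr constructor is what makes members such as embed_tokens_{nullptr} in the model headers valid.

#include <iostream>
#include <torch/torch.h>

// Hypothetical stand-in for WordEmbeddingImpl: a plain, unsharded lookup table.
class MyEmbeddingImpl : public torch::nn::Module {
 public:
  MyEmbeddingImpl(int64_t num_embeddings, int64_t embedding_dim)
      : weight_(register_parameter(
            "weight", torch::randn({num_embeddings, embedding_dim}))) {}

  torch::Tensor forward(const torch::Tensor& ids) {
    // Row lookup; the xllm version splits the rows across parallel ranks.
    return weight_.index_select(/*dim=*/0, ids);
  }

 private:
  torch::Tensor weight_;
};

// The ModuleHolder wrapper gives value semantics (copies share the impl) and
// an explicit empty state, mirroring layer::WordEmbedding / NpuWordEmbedding.
class MyEmbedding : public torch::nn::ModuleHolder<MyEmbeddingImpl> {
 public:
  using torch::nn::ModuleHolder<MyEmbeddingImpl>::ModuleHolder;
  using Impl = MyEmbeddingImpl;
  MyEmbedding(int64_t num_embeddings, int64_t embedding_dim)
      : ModuleHolder(std::make_shared<MyEmbeddingImpl>(num_embeddings,
                                                       embedding_dim)) {}
};

int main() {
  MyEmbedding embed(1000, 64);
  MyEmbedding empty(nullptr);  // the "no impl" state used by embed_tokens_{nullptr}
  auto out = embed->forward(torch::tensor({1, 2, 3}, torch::kLong));
  std::cout << out.sizes() << std::endl;  // prints [3, 64]
  return 0;
}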

xllm/core/runtime/acl_graph_executor_test.cpp

Lines changed: 4 additions & 3 deletions
@@ -243,13 +243,14 @@ class SimpleCausalLM : public CausalLM {
     // Simple implementation for testing
   }

-  std::vector<layer::WordEmbedding> get_word_embedding() override {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() override {
     // Simple implementation for testing
-    return std::vector<layer::WordEmbedding>{layer::WordEmbedding(nullptr)};
+    return std::vector<layer::NpuWordEmbedding>{
+        layer::NpuWordEmbedding(nullptr)};
   }

   void set_word_embedding(
-      std::vector<layer::WordEmbedding>& embedding) override {
+      std::vector<layer::NpuWordEmbedding>& embedding) override {
     // Simple implementation for testing
   }

xllm/core/runtime/llm_worker_impl.h

Lines changed: 2 additions & 2 deletions
@@ -49,11 +49,11 @@ class LLMWorkerImpl : public WorkerImpl {

   void set_lm_head(layer::NpuLmHead& head) { model_->set_lm_head(head); };

-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
   };

-  void set_word_embedding(std::vector<layer::WordEmbedding>& embedding) {
+  void set_word_embedding(std::vector<layer::NpuWordEmbedding>& embedding) {
     model_->set_word_embedding(embedding);
   };
 #endif

xllm/models/llm/deepseek_v2.h

Lines changed: 8 additions & 6 deletions
@@ -119,7 +119,7 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
         model_args.rope_scaling_original_max_position_embeddings());
     float sm_scale = 1.0f;
     for (auto i = 0; i < FLAGS_micro_batch_num; i++) {
-      embed_tokens_.push_back(layer::WordEmbedding(context));
+      embed_tokens_.push_back(layer::NpuWordEmbedding(context));
       pos_embs_.push_back(create_rotary_embedding(model_args,
                                                   model_args.rotary_dim(),
                                                   inv_freq,

@@ -264,11 +264,12 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
     layers_[layer_id]->update_expert_weight();
   }

-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return embed_tokens_;
   }

-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     embed_tokens_ = word_embedding;
   }

@@ -285,7 +286,7 @@ class DeepseekV2ModelImpl : public torch::nn::Module {
   int32_t num_speculative_tokens_ = 0;
   at::Device device_;
   torch::Dtype dtype_;
-  std::vector<layer::WordEmbedding> embed_tokens_;
+  std::vector<layer::NpuWordEmbedding> embed_tokens_;
   std::vector<std::shared_ptr<RotaryEmbedding>> pos_embs_;
   std::vector<layer::PosEmbedding> atb_pos_embs_;
   layer::AttentionMask attn_mask_;

@@ -347,11 +348,12 @@ class DeepseekV2ForCausalLMImpl : public torch::nn::Module {

   void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }

-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
   }

-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     model_->set_word_embedding(word_embedding);
   }

xllm/models/llm/deepseek_v2_mtp.h

Lines changed: 7 additions & 5 deletions
@@ -218,11 +218,12 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module {
     final_norm_->merge_loaded_weights();
   }

-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return embed_tokens_;
   }

-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     embed_tokens_ = word_embedding;
   }

@@ -237,7 +238,7 @@ class DeepseekV2MtpModelImpl : public torch::nn::Module {
   nlohmann::json mapping_data_;
   int32_t num_experts_per_tok_;
   at::Device device_;
-  std::vector<layer::WordEmbedding> embed_tokens_;
+  std::vector<layer::NpuWordEmbedding> embed_tokens_;
   std::vector<std::shared_ptr<RotaryEmbedding>> pos_embs_;
   std::vector<layer::PosEmbedding> atb_pos_embs_;
   layer::AttentionMask attn_mask_;

@@ -300,11 +301,12 @@ class DeepseekV2MtpForCausalLMImpl : public torch::nn::Module {

   void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }

-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
   }

-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     model_->set_word_embedding(word_embedding);
   }

xllm/models/llm/embedding_model_base.h

Lines changed: 2 additions & 2 deletions
@@ -78,12 +78,12 @@ class LlmForEmbeddingImplBase : public torch::nn::Module {

   virtual void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }

-  virtual std::vector<layer::WordEmbedding> get_word_embedding() {
+  virtual std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
   }

   virtual void set_word_embedding(
-      std::vector<layer::WordEmbedding>& word_embedding) {
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     model_->set_word_embedding(word_embedding);
   }

xllm/models/llm/glm4_moe.h

Lines changed: 8 additions & 6 deletions
@@ -85,7 +85,7 @@ class Glm4MoeModelImpl : public torch::nn::Module {
     dtype_ = options.dtype().toScalarType();
     num_speculative_tokens_ = model_args.num_speculative_tokens();
     embed_tokens_ =
-        register_module("embed_tokens", layer::WordEmbedding(context));
+        register_module("embed_tokens", layer::NpuWordEmbedding(context));

     atb_pos_emb_ = layer::PosEmbedding(context);
     cos_sin_ = get_concat_rotary_embedding(64,

@@ -221,11 +221,12 @@ class Glm4MoeModelImpl : public torch::nn::Module {
     norm_->merge_loaded_weights();
   }

-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return {embed_tokens_};
   }

-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     embed_tokens_ = word_embedding[0];
   }

@@ -242,7 +243,7 @@ class Glm4MoeModelImpl : public torch::nn::Module {
   int32_t num_speculative_tokens_ = 0;
   at::Device device_;
   torch::Dtype dtype_;
-  layer::WordEmbedding embed_tokens_{nullptr};
+  layer::NpuWordEmbedding embed_tokens_{nullptr};
   layer::AttentionMask attn_mask_;
   layer::NpuRmsNorm norm_{nullptr};
   torch::Tensor cos_sin_;

@@ -301,11 +302,12 @@ class Glm4MoeForCausalLMImpl : public torch::nn::Module {

   void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }

-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
   }

-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     model_->set_word_embedding(word_embedding);
   }

xllm/models/llm/glm4_moe_mtp.h

Lines changed: 8 additions & 6 deletions
@@ -41,7 +41,7 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module {
     dtype_ = options.dtype().toScalarType();
     num_speculative_tokens_ = model_args.num_speculative_tokens();
     embed_tokens_ =
-        register_module("embed_tokens", layer::WordEmbedding(context));
+        register_module("embed_tokens", layer::NpuWordEmbedding(context));

     atb_pos_emb_ = layer::PosEmbedding(context);
     cos_sin_ = get_concat_rotary_embedding(64,

@@ -206,11 +206,12 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module {
     final_norm_->merge_loaded_weights();
   }

-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return {embed_tokens_};
   }

-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     embed_tokens_ = word_embedding[0];
   }

@@ -226,7 +227,7 @@ class Glm4MoeMtpModelImpl : public torch::nn::Module {
   int32_t num_speculative_tokens_ = 0;
   at::Device device_;
   torch::Dtype dtype_;
-  layer::WordEmbedding embed_tokens_{nullptr};
+  layer::NpuWordEmbedding embed_tokens_{nullptr};
   layer::AttentionMask attn_mask_;
   torch::Tensor cos_sin_;
   layer::PosEmbedding atb_pos_emb_{nullptr};

@@ -289,11 +290,12 @@ class Glm4MoeMtpForCausalLMImpl : public torch::nn::Module {

   void set_lm_head(layer::NpuLmHead& head) { lm_head_ = head; }

-  std::vector<layer::WordEmbedding> get_word_embedding() {
+  std::vector<layer::NpuWordEmbedding> get_word_embedding() {
     return model_->get_word_embedding();
   }

-  void set_word_embedding(std::vector<layer::WordEmbedding>& word_embedding) {
+  void set_word_embedding(
+      std::vector<layer::NpuWordEmbedding>& word_embedding) {
     model_->set_word_embedding(word_embedding);
   }