diff --git a/xllm/core/layers/common/indexer.h b/xllm/core/layers/common/indexer.h
index c45788c00..992fc9749 100644
--- a/xllm/core/layers/common/indexer.h
+++ b/xllm/core/layers/common/indexer.h
@@ -25,7 +25,7 @@ limitations under the License.
 #include "../mlu/attention.h"
 #elif defined(USE_CUDA)
 #include "../cuda/attention.h"
-#endif
 #include "framework/kv_cache/kv_cache.h"
+#endif
 #include "framework/model/model_input_params.h"
 #include "framework/parallel_state/parallel_args.h"
 #include "framework/quant_args.h"
diff --git a/xllm/models/CMakeLists.txt b/xllm/models/CMakeLists.txt
index ed638c539..be03d2561 100644
--- a/xllm/models/CMakeLists.txt
+++ b/xllm/models/CMakeLists.txt
@@ -1,6 +1,5 @@
 include(cc_library)
 
-# Define the library
 cc_library(
   NAME
     models
diff --git a/xllm/models/llm/mlu/deepseek_v2.h b/xllm/models/llm/common/deepseek_v2.h
similarity index 99%
rename from xllm/models/llm/mlu/deepseek_v2.h
rename to xllm/models/llm/common/deepseek_v2.h
index 733d3e312..2e552e899 100644
--- a/xllm/models/llm/mlu/deepseek_v2.h
+++ b/xllm/models/llm/common/deepseek_v2.h
@@ -20,7 +20,7 @@ limitations under the License.
 
 #include <torch/torch.h>
 
 #include "core/layers/deepseek_v2_decoder_layer.h"
-#include "models/llm/llm_model_base.h"
+#include "llm_model_base.h"
 
 // DeepSeek v2 compatible with huggingface weights
 // ref to:
diff --git a/xllm/models/llm/mlu/deepseek_v3.h b/xllm/models/llm/common/deepseek_v3.h
similarity index 100%
rename from xllm/models/llm/mlu/deepseek_v3.h
rename to xllm/models/llm/common/deepseek_v3.h
diff --git a/xllm/models/llm/mlu/deepseek_v32.h b/xllm/models/llm/common/deepseek_v32.h
similarity index 100%
rename from xllm/models/llm/mlu/deepseek_v32.h
rename to xllm/models/llm/common/deepseek_v32.h
diff --git a/xllm/models/llm/common/llm_model_base.h b/xllm/models/llm/common/llm_model_base.h
new file mode 100644
index 000000000..885942267
--- /dev/null
+++ b/xllm/models/llm/common/llm_model_base.h
@@ -0,0 +1,240 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/torch.h>
+
+#include <memory>
+#include <string>
+#include <vector>
+
+#include "core/common/global_flags.h"
+#include "core/common/interruption_bus.h"
+#include "core/framework/kv_cache/kv_cache.h"
+#include "core/framework/model/model_input_params.h"
+#include "core/framework/model_context.h"
+#include "core/layers/attention_mask.h"
+#include "core/layers/common/layer_utils.h"
+#include "core/layers/lm_head.h"
+#include "core/layers/rms_norm.h"
+#include "models/model_registry.h"
+#if defined(USE_CUDA)
+#include "core/layers/cuda/attention.h"
+#endif
+#if defined(USE_MLU)
+#include "core/layers/mlu/attention.h"
+#endif
+
+namespace xllm {
+
+template <typename DecoderType>
+class LlmDecoderLayerImplBase : public torch::nn::Module {
+ public:
+  LlmDecoderLayerImplBase(const ModelContext& context) {
+    // register submodules
+    decoder_layer_ = register_module("decoder_layer", DecoderType(context));
+  }
+
+  virtual torch::Tensor forward(torch::Tensor& x,
+                                torch::Tensor& positions,
+                                const layer::AttentionMetadata& attn_metadata,
+                                KVCache& kv_cache,
+                                const ModelInputParams& input_params) {
+    return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
+  }
+
+  // load the weight from the checkpoint
+  virtual void load_state_dict(const StateDict& state_dict) {
+    // call each submodule's load_state_dict function
+    decoder_layer_->load_state_dict(state_dict);
+  }
+
+ private:
+  DecoderType decoder_layer_{nullptr};
+};
+
+template <typename DecoderLayerType>
+class LlmModelImplBase : public torch::nn::Module {
+ public:
+  // model type: qwen2, qwen3, etc.
+  LlmModelImplBase(const std::string& model_type, const ModelArgs& args)
+      : model_type_(model_type) {
+    InterruptionBus::get_instance().subscribe([this](bool interrupted) {
+      this->layer_forward_interrupted_ = interrupted;
+    });
+    mrope_section_ = args.rope_scaling_mrope_section();
+  }
+
+  torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
+    return embed_tokens_(input_ids);
+  }
+
+  // tokens: [num_tokens]
+  // positions: [num_tokens] token pos in the sequence
+  virtual torch::Tensor forward(torch::Tensor tokens,
+                                torch::Tensor positions,
+                                std::vector<KVCache>& kv_caches,
+                                const ModelInputParams& input_params) {
+    if (tokens.numel() == 0) {
+      tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device());
+      positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device());
+    }
+    auto inputs_embeds = input_params.input_embedding;
+    torch::Tensor h;
+    if (inputs_embeds.defined()) {
+      h = inputs_embeds;
+    } else {
+      h = embed_tokens_(tokens);
+    }
+
+    auto modified_input_params = input_params;
+    auto position = positions;
+    layer::update_dummy_run_input(dp_rank_, position, modified_input_params);
+    bool is_prefill = modified_input_params.q_max_seq_len > 1;
+    auto attn_metadata =
+        layer::AttentionMetadata::build(modified_input_params, is_prefill);
+
+    torch::Tensor h_ret;
+    for (size_t i = 0; i < layers_.size(); i++) {
+      auto& layer = layers_[i];
+      h_ret = layer(
+          h, position, attn_metadata, kv_caches[i], modified_input_params);
+    }
+    return norm_(h_ret);
+  }
+
+  // load the weight from the checkpoint
+  virtual void load_state_dict(const StateDict& state_dict) {
+    embed_tokens_->load_state_dict(
+        state_dict.get_dict_with_prefix("embed_tokens."));
+
+    // call each layer's load_state_dict function
+    for (int i = 0; i < layers_.size(); i++) {
+      layers_[i]->load_state_dict(
+          state_dict.get_dict_with_prefix("layers." + std::to_string(i) + "."));
+    }
+    norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
+  }
+
+  virtual layer::WordEmbedding get_word_embedding() { return embed_tokens_; }
+
+  virtual void set_word_embedding(layer::WordEmbedding& word_embedding) {
+    embed_tokens_ = word_embedding;
+  }
+
+ protected:
+  int max_seq_len_ = 0;
+  torch::Tensor cos_pos_;
+  torch::Tensor sin_pos_;
+  int device_id = 0;
+  layer::AttentionMask attn_mask_;
+  int dp_rank_ = 0;
+
+  std::vector<int64_t> mrope_section_;
+  layer::WordEmbedding embed_tokens_{nullptr};
+  layer::RmsNorm norm_{nullptr};
+
+  torch::nn::ModuleList blocks_{nullptr};
+  // hold same data but different type as blocks_ to avoid type cast
+  std::vector<DecoderLayerType> layers_;
+
+  bool layer_forward_interrupted_ = false;
+
+ private:
+  std::string model_type_;
+};
+
+template <typename LlmModelType>
+class LlmForCausalLMImplBase : public torch::nn::Module {
+ public:
+  LlmForCausalLMImplBase(const ModelContext& context) {
+    tie_word_embeddings = context.get_model_args().tie_word_embeddings();
+    // register submodules
+    model_ = register_module("model", LlmModelType(context));
+
+    lm_head_ = register_module("lm_head", layer::LmHead(context));
+  }
+
+  torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
+    return model_->get_input_embeddings(input_ids);
+  }
+
+  // tokens: [num_tokens]
+  // positions: [num_tokens] token pos in the sequence
+  // returns: [num_tokens, hidden_size]
+  virtual torch::Tensor forward(const torch::Tensor& tokens,
+                                const torch::Tensor& positions,
+                                std::vector<KVCache>& kv_caches,
+                                const ModelInputParams& input_params) {
+    return model_(tokens, positions, kv_caches, input_params);
+  }
+
+  // hidden_states: [num_tokens, hidden_size]
+  // seleted_idxes: [num_tokens]
+  // returns: [num_tokens, vocab_size]
+  virtual torch::Tensor logits(const torch::Tensor& hidden_states,
+                               const torch::Tensor& seleted_idxes) {
+    // select tokens if provided
+    auto h = hidden_states;
+    if (seleted_idxes.defined()) {
+      h = h.index_select(/*dim=*/0, seleted_idxes);
+    }
+    return lm_head_(h);
+  }
+
+  void load_model(std::unique_ptr<ModelLoader> loader,
+                  std::string prefix = "model." /*llm model weight prefix*/) {
+    for (const auto& state_dict : loader->get_state_dicts()) {
+      model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
+      if (tie_word_embeddings) {
+        lm_head_->load_state_dict(
+            state_dict->get_dict_with_prefix(prefix + "embed_tokens."));
+      } else {
+        lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
+      }
+    }
+  }
+
+  virtual void prepare_expert_weight(int32_t layer_id,
+                                     const std::vector<int32_t>& expert_ids) {
+    return;
+  }
+  virtual void update_expert_weight(int32_t layer_id) { return; }
+
+  virtual layer::LmHead get_lm_head() { return lm_head_; }
+
+  virtual void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
+
+  virtual layer::WordEmbedding get_word_embedding() {
+    return model_->get_word_embedding();
+  }
+
+  virtual void set_word_embedding(layer::WordEmbedding& word_embedding) {
+    model_->set_word_embedding(word_embedding);
+  }
+
+ protected:
+  // parameter members, must be registered
+  LlmModelType model_{nullptr};
+  int device_id = 0;
+  bool tie_word_embeddings{false};
+  layer::LmHead lm_head_{nullptr};
+};
+
+}  // namespace xllm
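Reviewer note on `load_model` above: when `tie_word_embeddings` is set, the checkpoint stores a single shared matrix under the `embed_tokens.` prefix and has no separate `lm_head.` entry, which is why the LM head is loaded from `prefix + "embed_tokens."`. A self-contained libtorch sketch of what tying means numerically, on toy shapes (xllm's `WordEmbedding`/`LmHead` internals are not part of this diff):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  const int64_t vocab_size = 16, hidden_size = 8, num_tokens = 4;
  // The one shared matrix a tied checkpoint stores under
  // "model.embed_tokens.weight"; there is no separate "lm_head.weight".
  auto embed_weight = torch::randn({vocab_size, hidden_size});

  // Embedding lookup: select rows of the shared matrix by token id.
  auto token_ids = torch::randint(0, vocab_size, {num_tokens});
  auto h = embed_weight.index_select(/*dim=*/0, token_ids);

  // Tied LM head: the same matrix reused as the output projection.
  auto logits = torch::matmul(h, embed_weight.t());
  std::cout << logits.sizes() << std::endl;  // [4, 16]
  return 0;
}
```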
diff --git a/xllm/models/llm/common/qwen2.h b/xllm/models/llm/common/qwen2.h
new file mode 100644
index 000000000..4d56229a6
--- /dev/null
+++ b/xllm/models/llm/common/qwen2.h
@@ -0,0 +1,110 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+Copyright 2024 The ScaleLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include "core/layers/qwen2_decoder_layer.h"
+#include "llm_model_base.h"
+
+// QWen2 model compatible with huggingface weights
+// ref to:
+// https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/qwen2/modeling_qwen2.py
+namespace xllm {
+
+class QWen2DecoderLayerImpl
+    : public LlmDecoderLayerImplBase<layer::Qwen2DecoderLayer> {
+ public:
+  QWen2DecoderLayerImpl(const ModelContext& context)
+      : LlmDecoderLayerImplBase(context) {}
+};
+TORCH_MODULE(QWen2DecoderLayer);
+
+class QWen2ModelImpl : public LlmModelImplBase<QWen2DecoderLayer> {
+ public:
+  QWen2ModelImpl(const ModelContext& context)
+      : LlmModelImplBase("qwen2", context.get_model_args()) {
+    // register submodules
+    auto model_args = context.get_model_args();
+    auto options = context.get_tensor_options();
+    auto parallel_args = context.get_parallel_args();
+    auto dp_local_tp_size = parallel_args.world_size() / parallel_args.dp_size();
+    dp_rank_ = parallel_args.rank() / dp_local_tp_size;
+
+    blocks_ = register_module("layers", torch::nn::ModuleList());
+    layers_.reserve(model_args.n_layers());
+    norm_ = register_module("norm", layer::RmsNorm(context));
+    embed_tokens_ =
+        register_module("embed_tokens", layer::WordEmbedding(context));
+    int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1;
+    attn_mask_ = layer::AttentionMask(options.device(),
+                                      options.dtype().toScalarType(),
+                                      /*mask_value=*/mask_value);
+
+    for (int32_t i = 0; i < model_args.n_layers(); i++) {
+      auto block = QWen2DecoderLayer(context);
+      layers_.push_back(block);
+      blocks_->push_back(block);
+    }
+  }
+};
+TORCH_MODULE(QWen2Model);
+
+class QWen2ForCausalLMImpl : public LlmForCausalLMImplBase<QWen2Model> {
+ public:
+  QWen2ForCausalLMImpl(const ModelContext& context)
+      : LlmForCausalLMImplBase(context) {}
+};
+TORCH_MODULE(QWen2ForCausalLM);
+
+// register the causal model
+REGISTER_CAUSAL_MODEL(qwen2, QWen2ForCausalLM);
+
+// register the model args
+// example config:
+// https://huggingface.co/Qwen/Qwen2-7B-Instruct/blob/main/config.json
+REGISTER_MODEL_ARGS(qwen2, [&] {
+  LOAD_ARG_OR(model_type, "model_type", "qwen2");
+  LOAD_ARG_OR(dtype, "torch_dtype", "");
+  LOAD_ARG_OR(vocab_size, "vocab_size", 152064);
+  LOAD_ARG_OR(hidden_size, "hidden_size", 3584);
+  LOAD_ARG_OR(n_layers, "num_hidden_layers", 28);
+  LOAD_ARG_OR(n_heads, "num_attention_heads", 28);
+  LOAD_ARG(n_kv_heads, "num_key_value_heads");
+  LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
+  LOAD_ARG_OR(attention_bias, "attention_bias", true);
+  // LOAD_ARG_OR(no_bias, "no_bias", true);
+  LOAD_ARG_OR(intermediate_size, "intermediate_size", 18944);
+  LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768);
+  LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6);
+  LOAD_ARG_OR(eos_token_id, "eos_token_id", 151643);
+  LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f);
+
+  // For Qwen2/2.5 models < 7B, tie_word_embeddings = true
+  LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false);
+
+  LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false);
+  LOAD_ARG_OR(sliding_window, "sliding_window", 4096);
+  LOAD_ARG_OR(max_window_layers, "max_window_layers", 28);
+
+  LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
+    return args->hidden_size() / args->n_heads();
+  });
+
+  SET_ARG(stop_token_ids, std::unordered_set<int32_t>({args->eos_token_id()}));
+});
+
+}  // namespace xllm
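The `REGISTER_MODEL_ARGS(qwen2, ...)` block maps HuggingFace `config.json` keys onto model-args fields, falling back to Qwen2-7B-Instruct defaults when a key is absent. A hypothetical stand-in for that behavior (the `load_or` lambda below is illustrative only; the real `LOAD_ARG_OR`/`LOAD_ARG_OR_FUNC` macros are defined elsewhere in xllm and also record the values into `args`):

```cpp
#include <iostream>
#include <nlohmann/json.hpp>

int main() {
  // A trimmed config.json, as shipped with a HuggingFace checkpoint.
  nlohmann::json config = nlohmann::json::parse(R"({
    "model_type": "qwen2",
    "num_hidden_layers": 40,
    "tie_word_embeddings": true
  })");

  // Read a field if present, otherwise keep the hard-coded default.
  auto load_or = [&](const char* key, auto fallback) {
    using T = decltype(fallback);
    return config.contains(key) ? config[key].get<T>() : fallback;
  };

  int n_layers = load_or("num_hidden_layers", 28);     // from config: 40
  bool tied = load_or("tie_word_embeddings", false);   // from config: true
  int hidden_size = load_or("hidden_size", 3584);      // default kept
  int n_heads = load_or("num_attention_heads", 28);    // default kept
  // head_dim falls back to hidden_size / n_heads when absent, mirroring
  // the LOAD_ARG_OR_FUNC entry above.
  int head_dim = load_or("head_dim", hidden_size / n_heads);

  std::cout << n_layers << " " << tied << " " << head_dim << "\n";  // 40 1 128
}
```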
diff --git a/xllm/models/llm/common/qwen3.h b/xllm/models/llm/common/qwen3.h
new file mode 100644
index 000000000..4decdc100
--- /dev/null
+++ b/xllm/models/llm/common/qwen3.h
@@ -0,0 +1,148 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include "core/layers/qwen3_decoder_layer.h"
+#include "llm_model_base.h"
+
+namespace xllm {
+
+class QWen3DecoderLayerImpl
+    : public LlmDecoderLayerImplBase<layer::Qwen3DecoderLayer> {
+ public:
+  QWen3DecoderLayerImpl(const ModelContext& context)
+      : LlmDecoderLayerImplBase(context) {}
+};
+TORCH_MODULE(QWen3DecoderLayer);
+
+class QWen3ModelImpl : public LlmModelImplBase<QWen3DecoderLayer> {
+ public:
+  QWen3ModelImpl(const ModelContext& context)
+      : LlmModelImplBase("qwen3", context.get_model_args()) {
+    // register submodules
+    auto model_args = context.get_model_args();
+    auto options = context.get_tensor_options();
+    auto parallel_args = context.get_parallel_args();
+    auto dp_local_tp_size = parallel_args.world_size() / parallel_args.dp_size();
+    dp_rank_ = parallel_args.rank() / dp_local_tp_size;
+
+    blocks_ = register_module("layers", torch::nn::ModuleList());
+    layers_.reserve(model_args.n_layers());
+    norm_ = register_module("norm", layer::RmsNorm(context));
+    embed_tokens_ =
+        register_module("embed_tokens", layer::WordEmbedding(context));
+
+    for (int32_t i = 0; i < model_args.n_layers(); i++) {
+      auto block = QWen3DecoderLayer(context);
+      layers_.push_back(block);
+      blocks_->push_back(block);
+    }
+  }
+
+  torch::Tensor deepstack_process(torch::Tensor hidden_states,
+                                  torch::Tensor visual_pos_masks,
+                                  torch::Tensor visual_embeds) {
+    visual_pos_masks = visual_pos_masks.to(hidden_states.device());
+    auto selected = hidden_states.index({visual_pos_masks});
+    auto local_this = selected + visual_embeds;
+    hidden_states.index_put_({visual_pos_masks}, local_this);
+    return hidden_states;
+  }
+
+  virtual torch::Tensor forward(torch::Tensor tokens,
+                                torch::Tensor positions,
+                                std::vector<KVCache>& kv_caches,
+                                const ModelInputParams& input_params) {
+    bool use_deepstack = input_params.deep_stacks.size() > 0;
+    ModelInputParams& input_params_new =
+        const_cast<ModelInputParams&>(input_params);
+    std::vector<torch::Tensor> deep_stacks;
+
+    if (tokens.numel() == 0) {
+      tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device());
+      positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device());
+    }
+    auto inputs_embeds = input_params.input_embedding;
+    torch::Tensor h;
+    if (inputs_embeds.defined()) {
+      h = inputs_embeds;
+    } else {
+      h = embed_tokens_(tokens);
+    }
+
+    auto modified_input_params = input_params;
+    auto position = positions;
+    layer::update_dummy_run_input(dp_rank_, position, modified_input_params);
+    bool is_prefill = modified_input_params.q_max_seq_len > 1;
+    auto attn_metadata =
+        layer::AttentionMetadata::build(modified_input_params, is_prefill);
+
+    torch::Tensor h_ret;
+    for (size_t i = 0; i < layers_.size(); i++) {
+      auto& layer = layers_[i];
+      h_ret = layer(
+          h, position, attn_metadata, kv_caches[i], modified_input_params);
+    }
+    return norm_(h_ret);
+  }
+
+ private:
+  torch::Tensor visual_pos_mask_;
+};
+TORCH_MODULE(QWen3Model);
+
+class QWen3ForCausalLMImpl : public LlmForCausalLMImplBase<QWen3Model> {
+ public:
+  QWen3ForCausalLMImpl(const ModelContext& context)
+      : LlmForCausalLMImplBase(context) {}
+};
+TORCH_MODULE(QWen3ForCausalLM);
+
+// register the causal model
+REGISTER_CAUSAL_MODEL(qwen3, QWen3ForCausalLM);
+
+// register the model args
+REGISTER_MODEL_ARGS(qwen3, [&] {
+  LOAD_ARG_OR(model_type, "model_type", "qwen3");
+  LOAD_ARG_OR(dtype, "torch_dtype", "");
+  LOAD_ARG_OR(vocab_size, "vocab_size", 152064);
+  LOAD_ARG_OR(hidden_size, "hidden_size", 3584);
+  LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
"num_hidden_layers", 28); + LOAD_ARG_OR(n_heads, "num_attention_heads", 28); + LOAD_ARG(n_kv_heads, "num_key_value_heads"); + // LOAD_ARG_OR(no_bias, "no_bias", true); + LOAD_ARG_OR(intermediate_size, "intermediate_size", 18944); + LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768); + LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6); + LOAD_ARG_OR(eos_token_id, "eos_token_id", 151643); + LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f); + + // For qwen3/2.5 model < 7B, tie_word_embeddings = true + LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false); + + LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false); + LOAD_ARG_OR(max_window_layers, "max_window_layers", 28); + + LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] { + return args->hidden_size() / args->n_heads(); + }); + + SET_ARG(stop_token_ids, std::unordered_set({args->eos_token_id()})); +}); + +} // namespace xllm diff --git a/xllm/models/llm/common/qwen3_moe.h b/xllm/models/llm/common/qwen3_moe.h new file mode 100644 index 000000000..ed59f73c0 --- /dev/null +++ b/xllm/models/llm/common/qwen3_moe.h @@ -0,0 +1,315 @@ +/* Copyright 2025 The xLLM Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + https://github.com/jd-opensource/xllm/blob/main/LICENSE + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#pragma once + +#include + +#include "core/framework/model_context.h" +#include "core/layers/common/layer_utils.h" +#include "core/layers/qwen3_moe_decoder_layer.h" +#include "llm_model_base.h" + +namespace xllm { + +using torch::indexing::None; +using ISlice = torch::indexing::Slice; + +class Qwen3MoeDecoderLayerImpl : public torch::nn::Module { + public: + Qwen3MoeDecoderLayerImpl(const ModelContext& context, const int32_t i) { + // register submodules + decoder_layer_ = register_module("decoder_layer", + layer::Qwen3MoeDecoderLayer(context, i)); + } + + torch::Tensor forward(torch::Tensor& x, + torch::Tensor& positions, + const layer::AttentionMetadata& attn_metadata, + KVCache& kv_cache, + const ModelInputParams& input_params) { + return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params); + } + + void load_state_dict(const StateDict& state_dict) { + auto experts_state_dict = state_dict.get_dict_with_prefix("mlp.experts."); + auto fused_gate_up = experts_state_dict.get_tensor("gate_up_proj"); + auto fused_down = experts_state_dict.get_tensor("down_proj"); + + bool is_fused = fused_gate_up.defined() && fused_down.defined(); + + if (is_fused) { + torch::Tensor expert_gate_up = fused_gate_up; + torch::Tensor expert_down = fused_down; + + const int num_experts = expert_gate_up.size(0); + + auto chunks = expert_gate_up.chunk(2, /*dim=*/-1); + auto expert_gate = chunks[0].contiguous(); + auto expert_up = chunks[1].contiguous(); + + std::unordered_map out_state_dict; + for (const auto& [name, tensor] : state_dict) { + if (name.find("self_attn.") == 0 || name.find("mlp.gate.") == 0 || + name.find("input_layernorm.") == 0 || + name.find("post_attention_layernorm.") == 0) { + 
diff --git a/xllm/models/llm/common/qwen3_moe.h b/xllm/models/llm/common/qwen3_moe.h
new file mode 100644
index 000000000..ed59f73c0
--- /dev/null
+++ b/xllm/models/llm/common/qwen3_moe.h
@@ -0,0 +1,315 @@
+/* Copyright 2025 The xLLM Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    https://github.com/jd-opensource/xllm/blob/main/LICENSE
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+
+#pragma once
+
+#include <torch/torch.h>
+
+#include "core/framework/model_context.h"
+#include "core/layers/common/layer_utils.h"
+#include "core/layers/qwen3_moe_decoder_layer.h"
+#include "llm_model_base.h"
+
+namespace xllm {
+
+using torch::indexing::None;
+using ISlice = torch::indexing::Slice;
+
+class Qwen3MoeDecoderLayerImpl : public torch::nn::Module {
+ public:
+  Qwen3MoeDecoderLayerImpl(const ModelContext& context, const int32_t i) {
+    // register submodules
+    decoder_layer_ = register_module("decoder_layer",
+                                     layer::Qwen3MoeDecoderLayer(context, i));
+  }
+
+  torch::Tensor forward(torch::Tensor& x,
+                        torch::Tensor& positions,
+                        const layer::AttentionMetadata& attn_metadata,
+                        KVCache& kv_cache,
+                        const ModelInputParams& input_params) {
+    return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
+  }
+
+  void load_state_dict(const StateDict& state_dict) {
+    auto experts_state_dict = state_dict.get_dict_with_prefix("mlp.experts.");
+    auto fused_gate_up = experts_state_dict.get_tensor("gate_up_proj");
+    auto fused_down = experts_state_dict.get_tensor("down_proj");
+
+    bool is_fused = fused_gate_up.defined() && fused_down.defined();
+
+    if (is_fused) {
+      torch::Tensor expert_gate_up = fused_gate_up;
+      torch::Tensor expert_down = fused_down;
+
+      const int num_experts = expert_gate_up.size(0);
+
+      auto chunks = expert_gate_up.chunk(2, /*dim=*/-1);
+      auto expert_gate = chunks[0].contiguous();
+      auto expert_up = chunks[1].contiguous();
+
+      std::unordered_map<std::string, torch::Tensor> out_state_dict;
+      for (const auto& [name, tensor] : state_dict) {
+        if (name.find("self_attn.") == 0 || name.find("mlp.gate.") == 0 ||
+            name.find("input_layernorm.") == 0 ||
+            name.find("post_attention_layernorm.") == 0) {
+          out_state_dict.emplace(name, tensor);
+        }
+      }
+
+      for (int i = 0; i < num_experts; ++i) {
+        auto gate_i = expert_gate[i].transpose(0, 1);
+        auto up_i = expert_up[i].transpose(0, 1);
+        auto down_i = expert_down[i].transpose(0, 1);
+
+        const std::string base = "mlp.experts." + std::to_string(i) + ".";
+        out_state_dict.emplace(base + "gate_proj.weight", gate_i);
+        out_state_dict.emplace(base + "up_proj.weight", up_i);
+        out_state_dict.emplace(base + "down_proj.weight", down_i);
+      }
+      decoder_layer_->load_state_dict(StateDict(std::move(out_state_dict)));
+    } else {
+      decoder_layer_->load_state_dict(state_dict);
+    }
+  }
+
+ private:
+  layer::Qwen3MoeDecoderLayer decoder_layer_{nullptr};
+};
+TORCH_MODULE(Qwen3MoeDecoderLayer);
+
+class Qwen3MoeModelImpl : public torch::nn::Module {
+ public:
+  Qwen3MoeModelImpl(const ModelContext& context)
+      : device_(context.get_tensor_options().device()) {
+    auto options = context.get_tensor_options();
+    auto model_args = context.get_model_args();
+    auto parallel_args = context.get_parallel_args();
+    mrope_section_ = model_args.rope_scaling_mrope_section();
+    blocks_ = register_module("layers", torch::nn::ModuleList());
+    layers_.reserve(model_args.n_layers());
+    // register submodules
+    device_ = options.device();
+    dtype_ = options.dtype().toScalarType();
+    num_speculative_tokens_ = model_args.num_speculative_tokens();
+    embed_tokens_ =
+        register_module("embed_tokens", layer::WordEmbedding(context));
+
+    max_seq_len_ = model_args.max_position_embeddings();
+    norm_ = register_module("norm", layer::RmsNorm(context));
+    mapping_data_ = parallel_args.mapping_data();
+
+    for (int32_t i = 0; i < model_args.n_layers(); ++i) {
+      auto block = Qwen3MoeDecoderLayer(context, i);
+      layers_.push_back(block);
+      blocks_->push_back(block);
+    }
+
+    dp_size_ = parallel_args.dp_size();
+    std::vector<int32_t> indices;
+    dp_local_tp_size_ = parallel_args.world_size() / dp_size_;
+    dp_rank_ = parallel_args.rank() / dp_local_tp_size_;
+    rank_ = parallel_args.rank();
+    num_experts_per_tok_ = model_args.num_experts_per_tok();
+    for (int i = 0; i < parallel_args.world_size(); i += dp_local_tp_size_) {
+      indices.push_back(i);
+    }
+  }
+
+  torch::Tensor deepstack_process(torch::Tensor hidden_states,
+                                  torch::Tensor visual_pos_masks,
+                                  torch::Tensor visual_embeds) {
+    visual_pos_masks = visual_pos_masks.to(hidden_states.device());
+    auto selected = hidden_states.index({visual_pos_masks});
+    auto local_this = selected + visual_embeds;
+    hidden_states.index_put_({visual_pos_masks}, local_this);
+    return hidden_states;
+  }
+
+  // tokens: [num_tokens]
+  // positions: [num_tokens] token pos in the sequence
+  torch::Tensor forward(torch::Tensor tokens,
+                        torch::Tensor positions,
+                        std::vector<KVCache>& kv_caches,
+                        const ModelInputParams& input_params) {
+    if (dp_size_ > 1) {
+      if (tokens.numel() == 0) {
+        tokens = torch::tensor({1}).to(torch::kInt32).to(device_);
+        positions = torch::tensor({0}).to(torch::kInt32).to(device_);
+      }
+    }
+
+    ModelInputParams modified_input_params = input_params;
+    layer::update_dummy_run_input(dp_rank_, positions, modified_input_params);
+    bool is_prefill = modified_input_params.q_max_seq_len > 1;
+    auto attn_metadata =
+        layer::AttentionMetadata::build(modified_input_params, is_prefill);
+    torch::Tensor h = embed_tokens_(tokens);
+    for (size_t i = 0; i < layers_.size(); i++) {
+      auto& layer = layers_[i];
+      h = layer(
+          h, positions, attn_metadata, kv_caches[i], modified_input_params);
+    }
+    return norm_(h);
+  }
+
+  // load the weight from the checkpoint
+  void load_state_dict(const StateDict& state_dict) {
+    embed_tokens_->load_state_dict(
+        state_dict.get_dict_with_prefix("embed_tokens."));
+    // call each layer's load_state_dict function
+    for (int i = 0; i < layers_.size(); i++) {
+      layers_[i]->load_state_dict(
+          state_dict.get_dict_with_prefix("layers." + std::to_string(i) + "."));
+    }
+    norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
+  }
+
+  layer::WordEmbedding get_word_embedding() { return embed_tokens_; }
+
+  void set_word_embedding(layer::WordEmbedding& word_embedding) {
+    embed_tokens_ = word_embedding;
+  }
+  torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
+    return embed_tokens_(input_ids);
+  }
+
+ private:
+  torch::nn::ModuleList blocks_{nullptr};
+  std::vector<Qwen3MoeDecoderLayer> layers_;
+  int32_t max_seq_len_ = 0;
+  int32_t dp_rank_;
+  int32_t rank_;
+  int32_t dp_size_;
+  int32_t dp_local_tp_size_;
+  nlohmann::json mapping_data_;
+  int32_t num_experts_per_tok_;
+  int32_t num_speculative_tokens_ = 0;
+  at::Device device_;
+  torch::Dtype dtype_;
+  layer::WordEmbedding embed_tokens_{nullptr};
+  layer::AttentionMask attn_mask_;
+  layer::RmsNorm norm_{nullptr};
+  std::vector<int64_t> mrope_section_;
+};
+TORCH_MODULE(Qwen3MoeModel);
+
+class Qwen3MoeForCausalLMImpl : public torch::nn::Module {
+ public:
+  Qwen3MoeForCausalLMImpl(const ModelContext& context) {
+    model_ = register_module("model", Qwen3MoeModel(context));
+    lm_head_ = register_module("lm_head", layer::LmHead(context));
+  }
+
+  // tokens: [num_tokens]
+  // positions: [num_tokens] token pos in the sequence
+  // returns: [num_tokens, hidden_size]
+  torch::Tensor forward(const torch::Tensor& tokens,
+                        const torch::Tensor& positions,
+                        std::vector<KVCache>& kv_caches,
+                        const ModelInputParams& input_params) {
+    return model_(tokens, positions, kv_caches, input_params);
+  }
+
+  // hidden_states: [num_tokens, hidden_size]
+  // seleted_idxes: [num_tokens]
+  // returns: [num_tokens, vocab_size]
+  torch::Tensor logits(const torch::Tensor& hidden_states,
+                       const torch::Tensor& seleted_idxes) {
+    // select tokens if provided
+    auto h = hidden_states;
+    if (seleted_idxes.defined()) {
+      h = h.index_select(/*dim=*/0, seleted_idxes);
+    }
+    return lm_head_(h);
+  }
+
+  torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
+    return model_->get_input_embeddings(input_ids);
+  }
+
+  void load_model(std::unique_ptr<ModelLoader> loader,
+                  std::string prefix = "model." /*llm model weight prefix*/) {
+    for (const auto& state_dict : loader->get_state_dicts()) {
+      model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
+      lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
+    }
+  }
+
+  virtual void prepare_expert_weight(int32_t layer_id,
+                                     const std::vector<int32_t>& expert_ids) {
+    return;
+  }
+  virtual void update_expert_weight(int32_t layer_id) { return; }
+
+  layer::LmHead get_lm_head() { return lm_head_; }
+
+  void set_lm_head(layer::LmHead& head) { lm_head_ = head; }
+
+  layer::WordEmbedding get_word_embedding() {
+    return model_->get_word_embedding();
+  }
+
+  void set_word_embedding(layer::WordEmbedding& word_embedding) {
+    model_->set_word_embedding(word_embedding);
+  }
+
+ private:
+  Qwen3MoeModel model_{nullptr};
+  layer::LmHead lm_head_{nullptr};
+};
+TORCH_MODULE(Qwen3MoeForCausalLM);
+
+// register the causal model
+REGISTER_CAUSAL_MODEL(qwen3_moe, Qwen3MoeForCausalLM);
+
+// register the model args
+// example config:
+// https://huggingface.co/Qwen/Qwen3-30B-A3B/blob/main/config.json
+// https://huggingface.co/Qwen/Qwen3-235B-A22B/blob/main/config.json
+REGISTER_MODEL_ARGS(qwen3_moe, [&] {
+  LOAD_ARG_OR(model_type, "model_type", "qwen3_moe");
+  LOAD_ARG_OR(dtype, "torch_dtype", "");
+  LOAD_ARG_OR(attention_bias, "attention_bias", false);
+  LOAD_ARG_OR(attention_dropout, "attention_dropout", 0.0f);
+  LOAD_ARG_OR(bos_token_id, "bos_token_id", 151643);
+  LOAD_ARG_OR(decoder_sparse_step, "decoder_sparse_step", 1);
+  LOAD_ARG_OR(eos_token_id, "eos_token_id", 151645);
+  LOAD_ARG_OR(head_dim, "head_dim", 128);
+  LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
+  LOAD_ARG_OR(hidden_size, "hidden_size", 2048);
+  LOAD_ARG_OR(initializer_range, "initializer_range", 0.02f);
+  LOAD_ARG_OR(intermediate_size, "intermediate_size", 6144);
+  LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 40960);
+  LOAD_ARG_OR(max_window_layers, "max_window_layers", 48);
+  LOAD_ARG_OR(moe_intermediate_size, "moe_intermediate_size", 768);
+  LOAD_ARG_OR(norm_topk_prob, "norm_topk_prob", true);
+  LOAD_ARG_OR(n_heads, "num_attention_heads", 32);
+  LOAD_ARG_OR(num_experts, "num_experts", 128);
+  LOAD_ARG_OR(num_experts_per_tok, "num_experts_per_tok", 8);
+  LOAD_ARG_OR(n_layers, "num_hidden_layers", 48);
+  LOAD_ARG_OR(n_kv_heads, "num_key_value_heads", 4);
+  LOAD_ARG_OR(output_router_logits, "output_router_logits", false);
+  LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6);
+  LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f);
+  LOAD_ARG_OR(router_aux_loss_coef, "router_aux_loss_coef", 0.001f);
+  LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false);
+  LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false);
+  LOAD_ARG_OR(vocab_size, "vocab_size", 151936);
+  LOAD_ARG_OR(mlp_only_layers, "mlp_only_layers", std::vector<int32_t>());
+
+  SET_ARG(stop_token_ids, std::unordered_set<int32_t>({args->eos_token_id()}));
+});
+}  // namespace xllm
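The fused-expert branch of `Qwen3MoeDecoderLayerImpl::load_state_dict` converts one packed checkpoint layout into the per-expert `mlp.experts.<i>.{gate,up,down}_proj.weight` names the decoder layer expects. The shapes implied by the `chunk`/`transpose` calls: `gate_up_proj` is `[num_experts, hidden, 2*inter]`, and each per-expert slice is transposed into the `[out_features, in_features]` layout of an unfused Linear weight. A toy walk-through (shapes inferred from the code above, not from a separate spec):

```cpp
#include <torch/torch.h>

#include <iostream>

int main() {
  const int64_t num_experts = 2, hidden = 4, inter = 3;
  // Fused checkpoint tensors: gate and up are packed along the last dim.
  auto gate_up = torch::randn({num_experts, hidden, 2 * inter});
  auto down = torch::randn({num_experts, inter, hidden});

  // Split the packed projection into its gate and up halves.
  auto chunks = gate_up.chunk(2, /*dim=*/-1);
  auto gate = chunks[0].contiguous();  // [num_experts, hidden, inter]
  auto up = chunks[1].contiguous();    // [num_experts, hidden, inter]

  for (int64_t i = 0; i < num_experts; ++i) {
    // Per-expert slices, transposed to Linear's [out_features, in_features].
    auto gate_i = gate[i].transpose(0, 1);  // [inter, hidden]
    auto up_i = up[i].transpose(0, 1);      // [inter, hidden]
    auto down_i = down[i].transpose(0, 1);  // [hidden, inter]
    std::cout << "expert " << i << ": gate " << gate_i.sizes() << ", down "
              << down_i.sizes() << "\n";
  }
}
```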
#include "core/layers/pos_embedding.h" #include "core/layers/rms_norm.h" #include "models/model_registry.h" -#if defined(USE_NPU) #include "xllm_kernels/core/include/atb_speed/log.h" -#else -#include "core/layers/common/layer_utils.h" -#endif -#if defined(USE_CUDA) -#include "core/layers/cuda/attention.h" -#endif -#if defined(USE_MLU) -#include "core/layers/mlu/attention.h" -#endif namespace xllm { @@ -87,12 +75,9 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { LlmDecoderLayerImplBase(const ModelContext& context) { // register submodules decoder_layer_ = register_module("decoder_layer", DecoderType(context)); -#if defined(USE_NPU) block_copy_ = register_module("block_copy", layer::BlockCopy(context)); -#endif } -#if defined(USE_NPU) virtual torch::Tensor forward(torch::Tensor& x, torch::Tensor& cos_pos, torch::Tensor& sin_pos, @@ -129,15 +114,6 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { decoder_layer_->merge_loaded_weights(); block_copy_->merge_loaded_weights(); } -#else - virtual torch::Tensor forward(torch::Tensor& x, - torch::Tensor& positions, - const layer::AttentionMetadata& attn_metadata, - KVCache& kv_cache, - const ModelInputParams& input_params) { - return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params); - } -#endif // load the weight from the checkpoint virtual void load_state_dict(const StateDict& state_dict) { @@ -147,9 +123,7 @@ class LlmDecoderLayerImplBase : public torch::nn::Module { private: DecoderType decoder_layer_{nullptr}; -#if defined(USE_NPU) layer::BlockCopy block_copy_{nullptr}; -#endif }; template @@ -165,11 +139,7 @@ class LlmModelImplBase : public torch::nn::Module { } torch::Tensor get_input_embeddings(torch::Tensor input_ids) { -#if defined(USE_NPU) return embed_tokens_(input_ids, 0); -#else - return embed_tokens_(input_ids); -#endif } // tokens: [num_tokens] @@ -188,14 +158,9 @@ class LlmModelImplBase : public torch::nn::Module { if (inputs_embeds.defined()) { h = inputs_embeds; } else { -#if defined(USE_NPU) h = embed_tokens_(tokens, 0); -#else - h = embed_tokens_(tokens); -#endif } -#if defined(USE_NPU) auto target_cos_sin = atb_pos_emb_(cos_sin_, positions, 0); auto target_cos_sin_chunks = target_cos_sin.chunk(/*chunks=*/2, /*dim=*/-1); auto cos_pos = target_cos_sin_chunks[0].contiguous(); @@ -257,9 +222,7 @@ class LlmModelImplBase : public torch::nn::Module { max_seq_len_, cos_pos.dtype().toScalarType(), cos_pos.device()); } } -#endif -#if defined(USE_NPU) for (size_t i = 0; i < layers_.size(); i++) { aclrtEvent* event = nullptr; std::atomic* event_flag = nullptr; @@ -292,22 +255,6 @@ class LlmModelImplBase : public torch::nn::Module { } return norm_(h, 0); -#else - auto modified_input_params = input_params; - auto position = positions; - layer::update_dummy_run_input(dp_rank_, position, modified_input_params); - bool is_prefill = modified_input_params.q_max_seq_len > 1; - auto attn_metadata = - layer::AttentionMetadata::build(modified_input_params, is_prefill); - - torch::Tensor h_ret; - for (size_t i = 0; i < layers_.size(); i++) { - auto& layer = layers_[i]; - h_ret = layer( - h, position, attn_metadata, kv_caches[i], modified_input_params); - } - return norm_(h_ret); -#endif } // load the weight from the checkpoint @@ -323,7 +270,6 @@ class LlmModelImplBase : public torch::nn::Module { norm_->load_state_dict(state_dict.get_dict_with_prefix("norm.")); } -#if defined(USE_NPU) virtual void verify_loaded_weights(const std::string& prefix) const { embed_tokens_->verify_loaded_weights(prefix + 
"embed_tokens."); @@ -342,7 +288,6 @@ class LlmModelImplBase : public torch::nn::Module { } norm_->merge_loaded_weights(); } -#endif virtual layer::WordEmbedding get_word_embedding() { return embed_tokens_; } @@ -358,9 +303,7 @@ class LlmModelImplBase : public torch::nn::Module { int device_id = 0; layer::AttentionMask attn_mask_; int dp_rank_ = 0; -#if defined(USE_NPU) layer::PosEmbedding atb_pos_emb_{nullptr}; -#endif std::vector mrope_section_; // test @@ -410,15 +353,7 @@ class LlmForCausalLMImplBase : public torch::nn::Module { const torch::Tensor& seleted_idxes) { // select tokens if provided auto h = hidden_states; - // test -#if defined(USE_NPU) return lm_head_(hidden_states, seleted_idxes, 0); -#else - if (seleted_idxes.defined()) { - h = h.index_select(/*dim=*/0, seleted_idxes); - } - return lm_head_(h); -#endif } void load_model(std::unique_ptr loader, @@ -432,15 +367,11 @@ class LlmForCausalLMImplBase : public torch::nn::Module { lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head.")); } } -#if defined(USE_NPU) - // verify model_->verify_loaded_weights(prefix); lm_head_->verify_loaded_weights("lm_head."); model_->merge_loaded_weights(); - // test lm_head_->merge_loaded_weights(); -#endif } virtual void prepare_expert_weight(int32_t layer_id, @@ -466,7 +397,6 @@ class LlmForCausalLMImplBase : public torch::nn::Module { LlmModelType model_{nullptr}; int device_id = 0; bool tie_word_embeddings{false}; - // test layer::LmHead lm_head_{nullptr}; }; diff --git a/xllm/models/llm/qwen2.h b/xllm/models/llm/qwen2.h index af5bfacf4..fe3cae507 100644 --- a/xllm/models/llm/qwen2.h +++ b/xllm/models/llm/qwen2.h @@ -49,9 +49,7 @@ class QWen2ModelImpl : public LlmModelImplBase { norm_ = register_module("norm", layer::RmsNorm(context)); embed_tokens_ = register_module("embed_tokens", layer::WordEmbedding(context)); -#if defined(USE_NPU) atb_pos_emb_ = layer::PosEmbedding(context); -#endif cos_sin_ = get_concat_rotary_embedding( model_args.hidden_size() / model_args.n_heads(), model_args.max_position_embeddings(), diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h index dec5d0159..59876b228 100644 --- a/xllm/models/llm/qwen3_moe.h +++ b/xllm/models/llm/qwen3_moe.h @@ -17,10 +17,8 @@ limitations under the License. 
diff --git a/xllm/models/llm/qwen3_moe.h b/xllm/models/llm/qwen3_moe.h
index dec5d0159..59876b228 100644
--- a/xllm/models/llm/qwen3_moe.h
+++ b/xllm/models/llm/qwen3_moe.h
@@ -17,10 +17,8 @@ limitations under the License.
 
 #pragma once
 
 #include <torch/torch.h>
-#include <nlohmann/json.hpp>
-#if defined(USE_NPU)
+// #include <nlohmann/json.hpp>
 #include "core/framework/model/npu_dp_ep_padding.h"
-#endif
 #include "core/framework/model_context.h"
 #include "core/layers/common/layer_utils.h"
 #include "core/layers/qwen3_moe_decoder_layer.h"
@@ -39,7 +37,6 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module {
                                      layer::Qwen3MoeDecoderLayer(context, i));
   }
 
-#if defined(USE_NPU)
   torch::Tensor forward(torch::Tensor x,
                         torch::Tensor cos_pos,
                         torch::Tensor sin_pos,
@@ -59,15 +56,7 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module {
                           event,
                           event_flag);
   }
-#else
-  torch::Tensor forward(torch::Tensor& x,
-                        torch::Tensor& positions,
-                        const layer::AttentionMetadata& attn_metadata,
-                        KVCache& kv_cache,
-                        const ModelInputParams& input_params) {
-    return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
-  }
-#endif
+
   void load_state_dict(const StateDict& state_dict) {
     auto experts_state_dict = state_dict.get_dict_with_prefix("mlp.experts.");
     auto fused_gate_up = experts_state_dict.get_tensor("gate_up_proj");
@@ -110,13 +99,11 @@ class Qwen3MoeDecoderLayerImpl : public torch::nn::Module {
     }
   }
 
-#if defined(USE_NPU)
   void verify_loaded_weights(const std::string& prefix) const {
     decoder_layer_->verify_loaded_weights(prefix);
   }
 
   void merge_loaded_weights() { decoder_layer_->merge_loaded_weights(); }
-#endif
 
  private:
   layer::Qwen3MoeDecoderLayer decoder_layer_{nullptr};
@@ -155,13 +142,11 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
                                          options);
 
     max_seq_len_ = model_args.max_position_embeddings();
-#if defined(USE_NPU)
     atb_pos_emb_ = layer::PosEmbedding(context);
     int32_t mask_value = model_args.dtype() == "bfloat16" ? 1 : -9984;
     attn_mask_ = layer::AttentionMask(options.device(),
                                       options.dtype().toScalarType(),
                                       /*mask_value=*/mask_value);
-#endif
     norm_ = register_module("norm", layer::RmsNorm(context));
     mapping_data_ = parallel_args.mapping_data();
 
@@ -204,7 +189,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
         positions = torch::tensor({0}).to(torch::kInt32).to(device_);
       }
     }
-#if defined(USE_NPU)
+
     auto inputs_embeds = input_params.input_embedding;
     torch::Tensor h;
     if (inputs_embeds.defined()) {
@@ -284,20 +269,6 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
       }
     }
     return norm_(h, 0);
-#else
-    ModelInputParams modified_input_params = input_params;
-    layer::update_dummy_run_input(dp_rank_, positions, modified_input_params);
-    bool is_prefill = modified_input_params.q_max_seq_len > 1;
-    auto attn_metadata =
-        layer::AttentionMetadata::build(modified_input_params, is_prefill);
-    torch::Tensor h = embed_tokens_(tokens);
-    for (size_t i = 0; i < layers_.size(); i++) {
-      auto& layer = layers_[i];
-      h = layer(
-          h, positions, attn_metadata, kv_caches[i], modified_input_params);
-    }
-    return norm_(h);
-#endif
   }
 
   // load the weight from the checkpoint
@@ -312,7 +283,6 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
     norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
   }
 
-#if defined(USE_NPU)
   void verify_loaded_weights(const std::string& prefix) const {
     embed_tokens_->verify_loaded_weights(prefix + "embed_tokens.");
     for (int i = 0; i < layers_.size(); i++) {
@@ -329,7 +299,6 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
     }
     norm_->merge_loaded_weights();
   }
-#endif
 
   layer::WordEmbedding get_word_embedding() { return embed_tokens_; }
 
@@ -337,13 +306,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
     embed_tokens_ = word_embedding;
   }
   torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
-#if defined(USE_NPU)
     return embed_tokens_(input_ids, 0);
-#elif defined(USE_MLU)
-    return embed_tokens_(input_ids);
-#else
-    LOG(FATAL) << "Backend not supported: enable USE_NPU or USE_MLU.";
-#endif
   }
 
  private:
@@ -363,9 +326,7 @@ class Qwen3MoeModelImpl : public torch::nn::Module {
   layer::AttentionMask attn_mask_;
   layer::RmsNorm norm_{nullptr};
   torch::Tensor cos_sin_;
-#if defined(USE_NPU)
   layer::PosEmbedding atb_pos_emb_{nullptr};
-#endif
 
   std::vector<int64_t> mrope_section_;
 };
 TORCH_MODULE(Qwen3MoeModel);
@@ -392,16 +353,7 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module {
   // returns: [num_tokens, vocab_size]
   torch::Tensor logits(const torch::Tensor& hidden_states,
                        const torch::Tensor& seleted_idxes) {
-#if defined(USE_NPU)
     return lm_head_(hidden_states, seleted_idxes, 0);
-#else
-    // select tokens if provided
-    auto h = hidden_states;
-    if (seleted_idxes.defined()) {
-      h = h.index_select(/*dim=*/0, seleted_idxes);
-    }
-    return lm_head_(h);
-#endif
   }
 
   torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
@@ -415,14 +367,11 @@ class Qwen3MoeForCausalLMImpl : public torch::nn::Module {
       lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
     }
 
-#if defined(USE_NPU)
-    // verify
     model_->verify_loaded_weights(prefix);
     lm_head_->verify_loaded_weights("lm_head.");
 
     model_->merge_loaded_weights();
     lm_head_->merge_loaded_weights();
-#endif
   }
 
   virtual void prepare_expert_weight(int32_t layer_id,
diff --git a/xllm/models/models.h b/xllm/models/models.h
index 0460d6ff5..1d4286c6b 100644
--- a/xllm/models/models.h
+++ b/xllm/models/models.h
@@ -15,11 +15,6 @@ limitations under the License.
 
 #pragma once
 
-#include "llm/llm_model_base.h"  // IWYU pragma: keep
-#include "llm/qwen2.h"           // IWYU pragma: keep
-#include "llm/qwen3.h"           // IWYU pragma: keep
-#include "llm/qwen3_moe.h"       // IWYU pragma: keep
-
 #if defined(USE_NPU)
 #include "dit/pipeline_flux.h"          // IWYU pragma: keep
 #include "dit/pipeline_flux_control.h"  // IWYU pragma: keep
@@ -32,13 +27,22 @@ limitations under the License.
 #include "llm/kimi_k2.h"          // IWYU pragma: keep
 #include "llm/llama.h"            // IWYU pragma: keep
 #include "llm/llama3.h"           // IWYU pragma: keep
+#include "llm/llm_model_base.h"   // IWYU pragma: keep
+#include "llm/qwen2.h"            // IWYU pragma: keep
+#include "llm/qwen3.h"            // IWYU pragma: keep
 #include "llm/qwen3_embedding.h"  // IWYU pragma: keep
+#include "llm/qwen3_moe.h"        // IWYU pragma: keep
 #include "vlm/minicpmv.h"         // IWYU pragma: keep
 #include "vlm/qwen2_5_vl.h"       // IWYU pragma: keep
 #include "vlm/qwen3_vl.h"         // IWYU pragma: keep
 #include "vlm/qwen3_vl_moe.h"     // IWYU pragma: keep
 #elif defined(USE_MLU)
-#include "llm/mlu/deepseek_v2.h"   // IWYU pragma: keep
-#include "llm/mlu/deepseek_v3.h"   // IWYU pragma: keep
-#include "llm/mlu/deepseek_v32.h"  // IWYU pragma: keep
+#include "llm/common/deepseek_v2.h"   // IWYU pragma: keep
+#include "llm/common/deepseek_v3.h"   // IWYU pragma: keep
+#include "llm/common/deepseek_v32.h"  // IWYU pragma: keep
+#else
+#include "llm/common/llm_model_base.h"  // IWYU pragma: keep
+#include "llm/common/qwen2.h"           // IWYU pragma: keep
+#include "llm/common/qwen3.h"           // IWYU pragma: keep
+#include "llm/common/qwen3_moe.h"       // IWYU pragma: keep
 #endif
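The include reshuffle in `models.h` works because each model header self-registers at static-initialization time through `REGISTER_CAUSAL_MODEL`; a header that a backend's branch does not include simply never enters the registry, so moving `llm/common/qwen2.h` and friends into the `#else` branch is what makes them the default for non-NPU, non-MLU builds. A minimal sketch of that self-registration pattern (all names hypothetical; xllm's real `model_registry.h` is not part of this diff):

```cpp
#include <functional>
#include <iostream>
#include <string>
#include <unordered_map>

// Hypothetical registry: maps a config "model_type" string to a factory.
struct ModelRegistry {
  using Factory = std::function<void()>;  // stand-in for a model factory
  static std::unordered_map<std::string, Factory>& map() {
    static std::unordered_map<std::string, Factory> m;
    return m;
  }
};

// Self-registration at static-init time: merely including the header that
// contains such a macro invocation is enough to make the model available.
#define DEMO_REGISTER_CAUSAL_MODEL(name)          \
  static bool _reg_##name = [] {                  \
    ModelRegistry::map()[#name] = [] {            \
      std::cout << "building " #name "\n";        \
    };                                            \
    return true;                                  \
  }()

DEMO_REGISTER_CAUSAL_MODEL(qwen2);
DEMO_REGISTER_CAUSAL_MODEL(qwen3_moe);

int main() {
  // Dispatch by the model_type read from config.json.
  ModelRegistry::map().at("qwen2")();  // prints: building qwen2
}
```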