2 changes: 1 addition & 1 deletion xllm/core/layers/common/indexer.h
@@ -25,7 +25,7 @@ limitations under the License.
#include "../mlu/attention.h"
#elif defined(USE_CUDA)
#include "../cuda/attention.h"
-#endif #include "framework/kv_cache/kv_cache.h"
+#endif
#include "framework/model/model_input_params.h"
#include "framework/parallel_state/parallel_args.h"
#include "framework/quant_args.h"
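Note on the fix above: a preprocessor directive must start its own line, so the trailing #include after #endif was never actually processed; common compilers merely warn about extra tokens at the end of the #endif directive. A minimal illustration (hypothetical header names):

#if defined(USE_CUDA)
#include "cuda_attention.h"  // hypothetical header
#endif  // any trailing tokens here are ignored, not parsed as a directive
#include "kv_cache.h"  // an include must begin its own line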
1 change: 0 additions & 1 deletion xllm/models/CMakeLists.txt
@@ -1,6 +1,5 @@
include(cc_library)

# Define the library
cc_library(
NAME
models
@@ -20,7 +20,7 @@ limitations under the License.
#include <vector>

#include "core/layers/deepseek_v2_decoder_layer.h"
#include "models/llm/llm_model_base.h"
#include "llm_model_base.h"

// DeepSeek v2 compatible with huggingface weights
// ref to:
File renamed without changes.
240 changes: 240 additions & 0 deletions xllm/models/llm/common/llm_model_base.h
@@ -0,0 +1,240 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/torch.h>

#include <string>
#include <typeinfo>
#include <vector>

#include "core/common/global_flags.h"
#include "core/common/interruption_bus.h"
#include "core/framework/kv_cache/kv_cache.h"
#include "core/framework/model/model_input_params.h"
#include "core/framework/model_context.h"
#include "core/layers/attention_mask.h"
#include "core/layers/common/layer_utils.h"
#include "core/layers/lm_head.h"
#include "core/layers/rms_norm.h"
#include "models/model_registry.h"
#if defined(USE_CUDA)
#include "core/layers/cuda/attention.h"
#endif
#if defined(USE_MLU)
#include "core/layers/mlu/attention.h"
#endif

namespace xllm {

template <typename DecoderType>
class LlmDecoderLayerImplBase : public torch::nn::Module {
public:
LlmDecoderLayerImplBase(const ModelContext& context) {
// register submodules
decoder_layer_ = register_module("decoder_layer", DecoderType(context));
}

virtual torch::Tensor forward(torch::Tensor& x,
torch::Tensor& positions,
const layer::AttentionMetadata& attn_metadata,
KVCache& kv_cache,
const ModelInputParams& input_params) {
return decoder_layer_(x, positions, attn_metadata, kv_cache, input_params);
}

// load the weight from the checkpoint
virtual void load_state_dict(const StateDict& state_dict) {
// call each submodule's load_state_dict function
decoder_layer_->load_state_dict(state_dict);
}

private:
DecoderType decoder_layer_{nullptr};
};

template <typename DecoderLayerType>
class LlmModelImplBase : public torch::nn::Module {
public:
// model type: qwen2, qwen3, etc.
LlmModelImplBase(const std::string& model_type, const ModelArgs& args)
: model_type_(model_type) {
InterruptionBus::get_instance().subscribe([this](bool interrupted) {
this->layer_forward_interrupted_ = interrupted;
});
mrope_section_ = args.rope_scaling_mrope_section();
}

torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
return embed_tokens_(input_ids);
}

// tokens: [num_tokens]
// positions: [num_tokens] token pos in the sequence
virtual torch::Tensor forward(torch::Tensor tokens,
torch::Tensor positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& input_params) {
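// no real tokens this step (e.g. a dummy run on an idle DP rank):
// substitute one placeholder token/position so the pass still executes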
if (tokens.numel() == 0) {
tokens = torch::tensor({1}).to(torch::kInt32).to(tokens.device());
positions = torch::tensor({0}).to(torch::kInt32).to(tokens.device());
}
auto inputs_embeds = input_params.input_embedding;
// prefer caller-provided embeddings; otherwise embed the tokens
torch::Tensor h;
if (inputs_embeds.defined()) {
h = inputs_embeds;
} else {
h = embed_tokens_(tokens);
}

auto modified_input_params = input_params;
auto position = positions;
layer::update_dummy_run_input(dp_rank_, position, modified_input_params);
bool is_prefill = modified_input_params.q_max_seq_len > 1;
auto attn_metadata =
layer::AttentionMetadata::build(modified_input_params, is_prefill);

torch::Tensor h_ret;
for (size_t i = 0; i < layers_.size(); i++) {
auto& layer = layers_[i];
h_ret = layer(
h, position, attn_metadata, kv_caches[i], modified_input_params);
}
return norm_(h_ret);
}

// load the weight from the checkpoint
virtual void load_state_dict(const StateDict& state_dict) {
embed_tokens_->load_state_dict(
state_dict.get_dict_with_prefix("embed_tokens."));

// call each layer's load_state_dict function
for (size_t i = 0; i < layers_.size(); i++) {
layers_[i]->load_state_dict(
state_dict.get_dict_with_prefix("layers." + std::to_string(i) + "."));
}
norm_->load_state_dict(state_dict.get_dict_with_prefix("norm."));
}

virtual layer::WordEmbedding get_word_embedding() { return embed_tokens_; }

virtual void set_word_embedding(layer::WordEmbedding& word_embedding) {
embed_tokens_ = word_embedding;
}

protected:
int max_seq_len_ = 0;
torch::Tensor cos_pos_;
torch::Tensor sin_pos_;
int device_id = 0;
layer::AttentionMask attn_mask_;
int dp_rank_ = 0;

std::vector<int64_t> mrope_section_;
layer::WordEmbedding embed_tokens_{nullptr};
layer::RmsNorm norm_{nullptr};

torch::nn::ModuleList blocks_{nullptr};
// hold same data but different type as blocks_ to avoid type cast
std::vector<DecoderLayerType> layers_;

bool layer_forward_interrupted_ = false;

private:
std::string model_type_;
};

template <typename LlmModelType>
class LlmForCausalLMImplBase : public torch::nn::Module {
public:
LlmForCausalLMImplBase(const ModelContext& context) {
tie_word_embeddings = context.get_model_args().tie_word_embeddings();
// register submodules
model_ = register_module("model", LlmModelType(context));

lm_head_ = register_module("lm_head", layer::LmHead(context));
}

torch::Tensor get_input_embeddings(torch::Tensor input_ids) {
return model_->get_input_embeddings(input_ids);
}

// tokens: [num_tokens]
// positions: [num_tokens] token pos in the sequence
// returns: [num_tokens, hidden_size]
virtual torch::Tensor forward(const torch::Tensor& tokens,
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& input_params) {
return model_(tokens, positions, kv_caches, input_params);
}

// hidden_states: [num_tokens, hidden_size]
// seleted_idxes: [num_tokens]
// returns: [num_tokens, vocab_size]
virtual torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
// select tokens if provided
auto h = hidden_states;
if (seleted_idxes.defined()) {
h = h.index_select(/*dim=*/0, seleted_idxes);
}
return lm_head_(h);
}

void load_model(std::unique_ptr<ModelLoader> loader,
std::string prefix = "model." /*llm model weight prefix*/) {
for (const auto& state_dict : loader->get_state_dicts()) {
model_->load_state_dict(state_dict->get_dict_with_prefix(prefix));
if (tie_word_embeddings) {
lm_head_->load_state_dict(
state_dict->get_dict_with_prefix(prefix + "embed_tokens."));
} else {
lm_head_->load_state_dict(state_dict->get_dict_with_prefix("lm_head."));
}
}
}

virtual void prepare_expert_weight(int32_t layer_id,
const std::vector<int32_t>& expert_ids) {
return;
}
virtual void update_expert_weight(int32_t layer_id) { return; }

virtual layer::LmHead get_lm_head() { return lm_head_; }

virtual void set_lm_head(layer::LmHead& head) { lm_head_ = head; }

virtual layer::WordEmbedding get_word_embedding() {
return model_->get_word_embedding();
}

virtual void set_word_embedding(layer::WordEmbedding& word_embedding) {
model_->set_word_embedding(word_embedding);
}

protected:
// parameter members, must be registered
LlmModelType model_{nullptr};
int device_id = 0;
bool tie_word_embeddings{false};
layer::LmHead lm_head_{nullptr};
};

} // namespace xllm
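A design note on the base class above: the decoder modules are held twice, once in the registered torch::nn::ModuleList (blocks_) so parameters and state dicts are tracked, and once in a concretely typed std::vector (layers_) so the forward loop can dispatch without per-call downcasts (see the comment on blocks_). A standalone sketch of that pattern, with toy names and assuming libtorch:

#include <torch/torch.h>
#include <vector>

struct ToyLayerImpl : torch::nn::Module {
  torch::Tensor forward(const torch::Tensor& x) { return x * 2; }
};
TORCH_MODULE(ToyLayer);

struct ToyModelImpl : torch::nn::Module {
  torch::nn::ModuleList blocks{nullptr};  // registered: tracks parameters/state
  std::vector<ToyLayer> layers;           // same modules, concrete type

  explicit ToyModelImpl(int n) {
    blocks = register_module("layers", torch::nn::ModuleList());
    for (int i = 0; i < n; ++i) {
      ToyLayer l;            // module holder with shared ownership
      layers.push_back(l);   // typed handle for dispatch
      blocks->push_back(l);  // registered handle for state_dict/parameters
    }
  }

  torch::Tensor forward(torch::Tensor x) {
    for (auto& l : layers) {
      x = l(x);  // direct call; no ->as<ToyLayer>() cast needed
    }
    return x;
  }
};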
110 changes: 110 additions & 0 deletions xllm/models/llm/common/qwen2.h
@@ -0,0 +1,110 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
Copyright 2024 The ScaleLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include "core/layers/qwen2_decoder_layer.h"
#include "llm_model_base.h"

// QWen2 model compatible with huggingface weights
// ref to:
// https://github.com/huggingface/transformers/blob/v4.43.3/src/transformers/models/qwen2/modeling_qwen2.py
namespace xllm {

class QWen2DecoderLayerImpl
: public LlmDecoderLayerImplBase<layer::Qwen2DecoderLayer> {
public:
QWen2DecoderLayerImpl(const ModelContext& context)
: LlmDecoderLayerImplBase<layer::Qwen2DecoderLayer>(context) {}
};
TORCH_MODULE(QWen2DecoderLayer);

class QWen2ModelImpl : public LlmModelImplBase<QWen2DecoderLayer> {
public:
QWen2ModelImpl(const ModelContext& context)
: LlmModelImplBase<QWen2DecoderLayer>("qwen2", context.get_model_args()) {
// register submodules
auto model_args = context.get_model_args();
auto options = context.get_tensor_options();
auto parallel_args = context.get_parallel_args();
auto dp_local_tp_size =
parallel_args.world_size() / parallel_args.dp_size();
dp_rank_ = parallel_args.rank() / dp_local_tp_size;

blocks_ = register_module("layers", torch::nn::ModuleList());
layers_.reserve(model_args.n_layers());
norm_ = register_module("norm", layer::RmsNorm(context));
embed_tokens_ =
register_module("embed_tokens", layer::WordEmbedding(context));
int32_t mask_value = FLAGS_enable_chunked_prefill ? -9984 : 1;
attn_mask_ = layer::AttentionMask(options.device(),
options.dtype().toScalarType(),
/*mask_value=*/mask_value);

for (int32_t i = 0; i < model_args.n_layers(); i++) {
auto block = QWen2DecoderLayer(context);
layers_.push_back(block);
blocks_->push_back(block);
}
}
};
TORCH_MODULE(QWen2Model);

class QWen2ForCausalLMImpl : public LlmForCausalLMImplBase<QWen2Model> {
public:
QWen2ForCausalLMImpl(const ModelContext& context)
: LlmForCausalLMImplBase<QWen2Model>(context) {}
};
TORCH_MODULE(QWen2ForCausalLM);

// register the causal model
REGISTER_CAUSAL_MODEL(qwen2, QWen2ForCausalLM);

// register the model args
// example config:
// https://huggingface.co/Qwen/Qwen2-7B-Instruct/blob/main/config.json
REGISTER_MODEL_ARGS(qwen2, [&] {
LOAD_ARG_OR(model_type, "model_type", "qwen2");
LOAD_ARG_OR(dtype, "torch_dtype", "");
LOAD_ARG_OR(vocab_size, "vocab_size", 152064);
LOAD_ARG_OR(hidden_size, "hidden_size", 3584);
LOAD_ARG_OR(n_layers, "num_hidden_layers", 28);
LOAD_ARG_OR(n_heads, "num_attention_heads", 28);
LOAD_ARG(n_kv_heads, "num_key_value_heads");
LOAD_ARG_OR(hidden_act, "hidden_act", "silu");
LOAD_ARG_OR(attention_bias, "attention_bias", true);
// LOAD_ARG_OR(no_bias, "no_bias", true);
LOAD_ARG_OR(intermediate_size, "intermediate_size", 18944);
LOAD_ARG_OR(max_position_embeddings, "max_position_embeddings", 32768);
LOAD_ARG_OR(rms_norm_eps, "rms_norm_eps", 1e-6);
LOAD_ARG_OR(eos_token_id, "eos_token_id", 151643);
LOAD_ARG_OR(rope_theta, "rope_theta", 1000000.0f);

// For Qwen2/2.5 models smaller than 7B, tie_word_embeddings = true
LOAD_ARG_OR(tie_word_embeddings, "tie_word_embeddings", false);

LOAD_ARG_OR(use_sliding_window, "use_sliding_window", false);
LOAD_ARG_OR(sliding_window, "sliding_window", 4096);
LOAD_ARG_OR(max_window_layers, "max_window_layers", 28);

LOAD_ARG_OR_FUNC(head_dim, "head_dim", [&] {
return args->hidden_size() / args->n_heads();
});
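// (with the defaults above, this fallback computes 3584 / 28 = 128)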

SET_ARG(stop_token_ids, std::unordered_set<int32_t>({args->eos_token_id()}));
});

} // namespace xllm
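With the base templates now under models/llm/common/, qwen2.h above is essentially the template for any new dense decoder model: three thin adapter classes plus the registry macros. A hypothetical sketch (layer::FooDecoderLayer and the "foo" key are placeholders, not part of this PR):

#include "llm_model_base.h"

namespace xllm {

class FooDecoderLayerImpl
    : public LlmDecoderLayerImplBase<layer::FooDecoderLayer> {
 public:
  FooDecoderLayerImpl(const ModelContext& context)
      : LlmDecoderLayerImplBase<layer::FooDecoderLayer>(context) {}
};
TORCH_MODULE(FooDecoderLayer);

class FooModelImpl : public LlmModelImplBase<FooDecoderLayer> {
 public:
  FooModelImpl(const ModelContext& context)
      : LlmModelImplBase<FooDecoderLayer>("foo", context.get_model_args()) {
    // build embed_tokens_, norm_, attn_mask_, and the per-layer blocks
    // here, exactly as QWen2ModelImpl does above
  }
};
TORCH_MODULE(FooModel);

class FooForCausalLMImpl : public LlmForCausalLMImplBase<FooModel> {
 public:
  FooForCausalLMImpl(const ModelContext& context)
      : LlmForCausalLMImplBase<FooModel>(context) {}
};
TORCH_MODULE(FooForCausalLM);

REGISTER_CAUSAL_MODEL(foo, FooForCausalLM);

}  // namespace xllm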