Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xllm/core/framework/model/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -29,6 +29,7 @@ cc_library(
dit_model.h
embedding_lm.h
embedding_vlm.h
mm_embedding_vlm.h
model_args.h
npu_dp_ep_padding.h
model_input_params.h
Expand Down
86 changes: 86 additions & 0 deletions xllm/core/framework/model/mm_embedding_vlm.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,86 @@
/* Copyright 2025 The xLLM Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
https://github.com/jd-opensource/xllm/blob/main/LICENSE
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <c10/core/Device.h>
#include <torch/torch.h>

#include <vector>

#include "causal_vlm.h"
#include "core/framework/kv_cache/kv_cache.h"
#include "core/framework/quant_args.h"
#include "core/framework/state_dict/state_dict.h"
#include "model_args.h"
#include "model_input_params.h"

namespace xllm {

class MMEmbeddingVLM : public CausalVLM {
public:
~MMEmbeddingVLM() override = default;

virtual std::vector<torch::Tensor> encode(
const ModelInputParams& input_params) = 0;

virtual torch::Tensor logits(const torch::Tensor& hidden_states,
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

no need to define these interface in MMEmbeddingVLM, move these to MMEmbeddingVLMImpl

const torch::Tensor& selected_idxes) {
return torch::Tensor();
}

virtual torch::Tensor forward(const torch::Tensor& tokens,
const torch::Tensor& positions,
std::vector<KVCache>& kv_caches,
const ModelInputParams& input_params) {
return torch::Tensor{};
}
virtual void prepare_expert_weight(int32_t layer_id,
const std::vector<int32_t>& expert_ids) {
return;
}
virtual void update_expert_weight(int32_t layer_id) { return; }
virtual void set_lm_head(layer::LmHead& head) { return; }
virtual layer::LmHead get_lm_head() { return nullptr; }
virtual layer::WordEmbedding get_word_embedding() { return nullptr; }
virtual void set_word_embedding(layer::WordEmbedding& embedding) { return; }
};

template <typename Model>
class MMEmbeddingVLMImpl : public MMEmbeddingVLM {
public:
MMEmbeddingVLMImpl(Model model, const torch::TensorOptions& options)
: model_(std::move(model)), options_(options) {}

virtual std::vector<torch::Tensor> encode(
const ModelInputParams& input_params) override {
return model_->encode(input_params);
};

void load_model(std::unique_ptr<ModelLoader> loader) override {
model_->load_model(std::move(loader));
}

torch::Device device() const override { return options_.device(); }

const torch::TensorOptions& options() const override { return options_; }

private:
Model model_;

torch::TensorOptions options_;
};

} // namespace xllm
36 changes: 36 additions & 0 deletions xllm/models/model_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -115,6 +115,20 @@ void ModelRegistry::register_vlm_embedding_factory(
}
}

void ModelRegistry::register_vlm_mm_embedding_factory(
const std::string& name,
MMEmbeddingVLMFactory factory) {
ModelRegistry* instance = get_instance();

if (instance->model_registry_[name].mm_embedding_vlm_factory != nullptr) {
SAFE_LOG_WARNING("mm embedding vlm factory for " << name
<< " already registered.");
} else {
instance->model_registry_[name].mm_embedding_vlm_factory = factory;
instance->model_backend_[name] = "vlm";
}
}

void ModelRegistry::register_dit_model_factory(const std::string& name,
DiTModelFactory factory) {
ModelRegistry* instance = get_instance();
Expand Down Expand Up @@ -216,6 +230,13 @@ EmbeddingVLMFactory ModelRegistry::get_embeddingvlm_factory(
return instance->model_registry_[name].embedding_vlm_factory;
}

MMEmbeddingVLMFactory ModelRegistry::get_mm_embedding_vlm_factory(
const std::string& name) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: maybe need yo rename :(

register_vlm_mm_embedding_factory ----> vlm_mm_embedding

and

get_mm_embedding_vlm_factory ----> mm_embedding_vlm

ModelRegistry* instance = get_instance();

return instance->model_registry_[name].mm_embedding_vlm_factory;
}

DiTModelFactory ModelRegistry::get_dit_model_factory(const std::string& name) {
ModelRegistry* instance = get_instance();
return instance->model_registry_[name].dit_model_factory;
Expand Down Expand Up @@ -317,6 +338,21 @@ std::unique_ptr<EmbeddingVLM> create_vlm_embedding_model(
return nullptr;
}

std::unique_ptr<MMEmbeddingVLM> create_vlm_mm_embedding_model(
const ModelContext& context) {
// get the factory function for the model type from model registry
auto factory = ModelRegistry::get_mm_embedding_vlm_factory(
context.get_model_args().model_type());
if (factory) {
return factory(context);
}

LOG(ERROR) << "Unsupported model type: "
<< context.get_model_args().model_type();

return nullptr;
}

std::unique_ptr<DiTModel> create_dit_model(const DiTModelContext& context) {
// get the factory function for the model type from model registry
auto factory = ModelRegistry::get_dit_model_factory(context.model_type());
Expand Down
30 changes: 30 additions & 0 deletions xllm/models/model_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ limitations under the License.
#include "core/framework/model/dit_model.h"
#include "core/framework/model/embedding_lm.h"
#include "core/framework/model/embedding_vlm.h"
#include "core/framework/model/mm_embedding_vlm.h"
#include "core/framework/model_context.h"
#include "core/framework/tokenizer/tokenizer_args.h"
#include "core/util/json_reader.h"
Expand All @@ -47,6 +48,9 @@ using EmbeddingLMFactory =
using EmbeddingVLMFactory =
std::function<std::unique_ptr<EmbeddingVLM>(const ModelContext& context)>;

using MMEmbeddingVLMFactory =
std::function<std::unique_ptr<MMEmbeddingVLM>(const ModelContext& context)>;

using DiTModelFactory =
std::function<std::unique_ptr<DiTModel>(const DiTModelContext& context)>;

Expand All @@ -71,6 +75,7 @@ struct ModelMeta {
CausalVLMFactory causal_vlm_factory;
EmbeddingLMFactory embedding_lm_factory;
EmbeddingVLMFactory embedding_vlm_factory;
MMEmbeddingVLMFactory mm_embedding_vlm_factory;
DiTModelFactory dit_model_factory;
InputProcessorFactory input_processor_factory;
ImageProcessorFactory image_processor_factory;
Expand All @@ -97,6 +102,9 @@ class ModelRegistry {
static void register_vlm_embedding_factory(const std::string& name,
EmbeddingVLMFactory factory);

static void register_vlm_mm_embedding_factory(const std::string& name,
MMEmbeddingVLMFactory factory);

static void register_dit_model_factory(const std::string& name,
DiTModelFactory factory);

Expand All @@ -122,6 +130,9 @@ class ModelRegistry {

static EmbeddingVLMFactory get_embeddingvlm_factory(const std::string& name);

static MMEmbeddingVLMFactory get_mm_embedding_vlm_factory(
const std::string& name);

static DiTModelFactory get_dit_model_factory(const std::string& name);

static ModelArgsLoader get_model_args_loader(const std::string& name);
Expand Down Expand Up @@ -153,6 +164,9 @@ std::unique_ptr<EmbeddingLM> create_lm_embedding_model(
std::unique_ptr<EmbeddingVLM> create_vlm_embedding_model(
const ModelContext& context);

std::unique_ptr<MMEmbeddingVLM> create_vlm_mm_embedding_model(
const ModelContext& context);

std::unique_ptr<DiTModel> create_dit_model(const DiTModelContext& context);

// Macro to register a model with the ModelRegistry
Expand Down Expand Up @@ -218,6 +232,22 @@ std::unique_ptr<DiTModel> create_dit_model(const DiTModelContext& context);
#define REGISTER_EMBEDDING_VLM_MODEL(ModelType, ModelClass) \
REGISTER_EMBEDDING_VLM_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass)

#define REGISTER_MM_EMBEDDING_VLM_MODEL_WITH_VARNAME( \
VarName, ModelType, ModelClass) \
const bool VarName##_registered = []() { \
ModelRegistry::register_vlm_mm_embedding_factory( \
#ModelType, [](const ModelContext& context) { \
ModelClass model(context); \
model->eval(); \
return std::make_unique<xllm::MMEmbeddingVLMImpl<ModelClass>>( \
std::move(model), context.get_tensor_options()); \
}); \
return true; \
}()

#define REGISTER_MM_EMBEDDING_VLM_MODEL(ModelType, ModelClass) \
REGISTER_MM_EMBEDDING_VLM_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass)

#define REGISTER_DIT_MODEL_WITH_VARNAME(VarName, ModelType, ModelClass) \
const bool VarName##_registered = []() { \
ModelRegistry::register_dit_model_factory( \
Expand Down
Loading