Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion xllm/core/framework/model/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ cc_library(
causal_lm.h
causal_vlm.h
dit_model.h
embedding_vlm.h
mm_embedding_vlm.h
model_args.h
npu_cp_ep_padding.h
Expand Down
5 changes: 5 additions & 0 deletions xllm/core/framework/model/causal_vlm.h
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,11 @@ class CausalVLMImpl : public CausalVLM {
return model_->forward(tokens, positions, kv_caches, parameters);
}

// Pools hidden states into embedding vectors by delegating to the wrapped
// model implementation. `seleted_idxes` (sic — spelling matches the
// codebase-wide interface) selects which token rows participate in pooling;
// presumably an undefined tensor means "pool all rows" — confirm in model_.
torch::Tensor pooler(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) override {
return model_->pooler(hidden_states, seleted_idxes);
}

torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) override {
return model_->logits(hidden_states, seleted_idxes);
Expand Down
73 changes: 0 additions & 73 deletions xllm/core/framework/model/embedding_vlm.h

This file was deleted.

7 changes: 3 additions & 4 deletions xllm/core/runtime/embed_vlm_worker_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,8 @@ EmbedVLMWorkerImpl::EmbedVLMWorkerImpl(const ParallelArgs& parallel_args,
bool EmbedVLMWorkerImpl::init_model(ModelContext& context) {
CHECK(model_ == nullptr) << "Model is already initialized.";

context.set_image_embedding_mode(true);
model_ = create_vlm_embedding_model(context);
context.set_image_embedding_mode(false);
model_ = create_vlm_model(context);
CHECK(model_ != nullptr) << "Failed to create model.";
model_executor_ = std::make_unique<Executor>(
model_.get(), context.get_model_args(), device_, options_);
Expand Down Expand Up @@ -84,9 +84,8 @@ std::optional<ForwardOutput> EmbedVLMWorkerImpl::step(

if (sampling_params.selected_token_idxes.defined() &&
input.sampling_params.is_embeddings) {
EmbeddingVLM* em_model = dynamic_cast<EmbeddingVLM*>(model_.get());
auto embeddings =
em_model->pooler(hidden_states, sampling_params.selected_token_idxes);
model_->pooler(hidden_states, sampling_params.selected_token_idxes);
sample_output.embeddings = embeddings;
output.sample_output = sample_output;
output.embedding = embeddings;
Expand Down
10 changes: 10 additions & 0 deletions xllm/models/llm/npu/qwen2.h
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,16 @@ class QWen2ForCausalLMImpl : public LlmForCausalLMImplBase<QWen2Model> {
public:
QWen2ForCausalLMImpl(const ModelContext& context)
: LlmForCausalLMImplBase<QWen2Model>(context) {}

// Pools hidden states into sentence/sequence embeddings: optionally gathers
// the rows named by `seleted_idxes` (sic — interface-wide spelling), then
// L2-normalizes each row along the feature dimension.
torch::Tensor pooler(const torch::Tensor& hidden_states,
                     const torch::Tensor& seleted_idxes) {
  namespace F = torch::nn::functional;
  // An undefined index tensor means "pool every row of hidden_states".
  const torch::Tensor pooled =
      seleted_idxes.defined()
          ? hidden_states.index_select(/*dim=*/0, seleted_idxes)
          : hidden_states;
  // Unit-length (p=2) normalization over dim 1, the embedding dimension.
  return F::normalize(pooled, F::NormalizeFuncOptions().p(2).dim(1));
}
};
TORCH_MODULE(QWen2ForCausalLM);

Expand Down
10 changes: 10 additions & 0 deletions xllm/models/llm/qwen2.h
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,16 @@ class QWen2ForCausalLMImpl : public LlmForCausalLMImplBase<QWen2Model> {
public:
QWen2ForCausalLMImpl(const ModelContext& context)
: LlmForCausalLMImplBase<QWen2Model>(context) {}

// Pools hidden states into sequence embeddings. When `seleted_idxes`
// (sic — spelling matches the rest of the codebase) is defined, only those
// token rows are kept; each surviving row is then L2-normalized along the
// feature dimension so downstream similarity scores are cosine-based.
torch::Tensor pooler(const torch::Tensor& hidden_states,
                     const torch::Tensor& seleted_idxes) {
  // Undefined index tensor => pool all rows unchanged.
  const auto pooled = seleted_idxes.defined()
                          ? hidden_states.index_select(/*dim=*/0, seleted_idxes)
                          : hidden_states;
  return torch::nn::functional::normalize(
      pooled,
      torch::nn::functional::NormalizeFuncOptions().p(2).dim(1));
}
};
TORCH_MODULE(QWen2ForCausalLM);

Expand Down
43 changes: 0 additions & 43 deletions xllm/models/model_registry.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -200,20 +200,6 @@ void ModelRegistry::register_causalvlm_factory(const std::string& name,
}
}

void ModelRegistry::register_vlm_embedding_factory(
const std::string& name,
EmbeddingVLMFactory factory) {
ModelRegistry* instance = get_instance();

if (instance->model_registry_[name].embedding_vlm_factory != nullptr) {
SAFE_LOG_WARNING("embedding vlm factory for " << name
<< " already registered.");
} else {
instance->model_registry_[name].embedding_vlm_factory = factory;
instance->model_backend_[name] = "vlm";
}
}

void ModelRegistry::register_mm_embedding_vlm_factory(
const std::string& name,
MMEmbeddingVLMFactory factory) {
Expand Down Expand Up @@ -321,13 +307,6 @@ CausalVLMFactory ModelRegistry::get_causalvlm_factory(const std::string& name) {
return instance->model_registry_[name].causal_vlm_factory;
}

EmbeddingVLMFactory ModelRegistry::get_embeddingvlm_factory(
const std::string& name) {
ModelRegistry* instance = get_instance();

return instance->model_registry_[name].embedding_vlm_factory;
}

MMEmbeddingVLMFactory ModelRegistry::get_mm_embedding_vlm_factory(
const std::string& name) {
ModelRegistry* instance = get_instance();
Expand Down Expand Up @@ -441,28 +420,6 @@ std::unique_ptr<CausalVLM> create_vlm_model(const ModelContext& context) {
return nullptr;
}

std::unique_ptr<EmbeddingVLM> create_vlm_embedding_model(
const ModelContext& context) {
std::string resolved_name;
std::string error_message;
if (!resolve_model_registration_name(context.get_model_args().model_type(),
&resolved_name,
&error_message)) {
LOG(ERROR) << error_message;
return nullptr;
}

auto factory = ModelRegistry::get_embeddingvlm_factory(resolved_name);
if (factory) {
return factory(context);
}

LOG(ERROR) << "Unsupported model type: "
<< context.get_model_args().model_type();

return nullptr;
}

std::unique_ptr<MMEmbeddingVLM> create_vlm_mm_embedding_model(
const ModelContext& context) {
std::string resolved_name;
Expand Down
29 changes: 0 additions & 29 deletions xllm/models/model_registry.h
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@ limitations under the License.
#include "core/framework/model/causal_lm.h"
#include "core/framework/model/causal_vlm.h"
#include "core/framework/model/dit_model.h"
#include "core/framework/model/embedding_vlm.h"
#include "core/framework/model/mm_embedding_vlm.h"
#include "core/framework/model_context.h"
#include "core/framework/tokenizer/tokenizer_args.h"
Expand All @@ -44,9 +43,6 @@ using RecModelFactory =
using CausalVLMFactory =
std::function<std::unique_ptr<CausalVLM>(const ModelContext& context)>;

using EmbeddingVLMFactory =
std::function<std::unique_ptr<EmbeddingVLM>(const ModelContext& context)>;

using MMEmbeddingVLMFactory =
std::function<std::unique_ptr<MMEmbeddingVLM>(const ModelContext& context)>;

Expand All @@ -73,7 +69,6 @@ struct ModelMeta {
CausalLMFactory causal_lm_factory;
RecModelFactory rec_model_factory;
CausalVLMFactory causal_vlm_factory;
EmbeddingVLMFactory embedding_vlm_factory;
MMEmbeddingVLMFactory mm_embedding_vlm_factory;
DiTModelFactory dit_model_factory;
InputProcessorFactory input_processor_factory;
Expand All @@ -98,9 +93,6 @@ class ModelRegistry {
static void register_causalvlm_factory(const std::string& name,
CausalVLMFactory factory);

static void register_vlm_embedding_factory(const std::string& name,
EmbeddingVLMFactory factory);

static void register_mm_embedding_vlm_factory(const std::string& name,
MMEmbeddingVLMFactory factory);

Expand All @@ -127,8 +119,6 @@ class ModelRegistry {

static CausalVLMFactory get_causalvlm_factory(const std::string& name);

static EmbeddingVLMFactory get_embeddingvlm_factory(const std::string& name);

static MMEmbeddingVLMFactory get_mm_embedding_vlm_factory(
const std::string& name);

Expand Down Expand Up @@ -169,9 +159,6 @@ std::unique_ptr<CausalLM> create_rec_model(const ModelContext& context);

std::unique_ptr<CausalVLM> create_vlm_model(const ModelContext& context);

std::unique_ptr<EmbeddingVLM> create_vlm_embedding_model(
const ModelContext& context);

std::unique_ptr<MMEmbeddingVLM> create_vlm_mm_embedding_model(
const ModelContext& context);

Expand Down Expand Up @@ -223,22 +210,6 @@ std::unique_ptr<DiTModel> create_dit_model(const DiTModelContext& context);
#define REGISTER_CAUSAL_VLM_MODEL(ModelType, ModelClass) \
REGISTER_CAUSAL_VLM_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass)

#define REGISTER_EMBEDDING_VLM_MODEL_WITH_VARNAME( \
VarName, ModelType, ModelClass) \
const bool VarName##_registered = []() { \
ModelRegistry::register_vlm_embedding_factory( \
#ModelType, [](const ModelContext& context) { \
ModelClass model(context); \
model->eval(); \
return std::make_unique<xllm::EmbeddingVLMImpl<ModelClass>>( \
std::move(model), context.get_tensor_options()); \
}); \
return true; \
}()

#define REGISTER_EMBEDDING_VLM_MODEL(ModelType, ModelClass) \
REGISTER_EMBEDDING_VLM_MODEL_WITH_VARNAME(ModelType, ModelType, ModelClass)

#define REGISTER_MM_EMBEDDING_VLM_MODEL_WITH_VARNAME( \
VarName, ModelType, ModelClass) \
const bool VarName##_registered = []() { \
Expand Down
33 changes: 15 additions & 18 deletions xllm/models/models.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,27 +50,25 @@ limitations under the License.
#include "vlm/npu/qwen2_5_vl.h" // IWYU pragma: keep
#include "vlm/npu/qwen2_5_vl_mm_embedding.h" // IWYU pragma: keep
#include "vlm/npu/qwen2_vl.h" // IWYU pragma: keep
#include "vlm/npu/qwen2_vl_embedding.h" // IWYU pragma: keep
#include "vlm/npu/qwen3_vl.h" // IWYU pragma: keep
#include "vlm/npu/qwen3_vl_mm_embedding.h" // IWYU pragma: keep
#include "vlm/npu/qwen3_vl_moe.h" // IWYU pragma: keep
#elif defined(USE_MLU)
#include "llm/deepseek_mtp.h" // IWYU pragma: keep
#include "llm/deepseek_v2.h" // IWYU pragma: keep
#include "llm/deepseek_v3.h" // IWYU pragma: keep
#include "llm/deepseek_v32.h" // IWYU pragma: keep
#include "llm/glm5.h" // IWYU pragma: keep
#include "llm/glm5_mtp.h" // IWYU pragma: keep
#include "llm/joyai_llm_flash.h" // IWYU pragma: keep
#include "llm/mtp_model_base.h" // IWYU pragma: keep
#include "llm/qwen2.h" // IWYU pragma: keep
#include "llm/qwen3.h" // IWYU pragma: keep
#include "llm/qwen3_moe.h" // IWYU pragma: keep
#include "vlm/qwen2_5_vl.h" // IWYU pragma: keep
#include "vlm/qwen2_vl.h" // IWYU pragma: keep
#include "vlm/qwen2_vl_embedding.h" // IWYU pragma: keep
#include "vlm/qwen3_vl.h" // IWYU pragma: keep
#include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep
#include "llm/deepseek_mtp.h" // IWYU pragma: keep
#include "llm/deepseek_v2.h" // IWYU pragma: keep
#include "llm/deepseek_v3.h" // IWYU pragma: keep
#include "llm/deepseek_v32.h" // IWYU pragma: keep
#include "llm/glm5.h" // IWYU pragma: keep
#include "llm/glm5_mtp.h" // IWYU pragma: keep
#include "llm/joyai_llm_flash.h" // IWYU pragma: keep
#include "llm/mtp_model_base.h" // IWYU pragma: keep
#include "llm/qwen2.h" // IWYU pragma: keep
#include "llm/qwen3.h" // IWYU pragma: keep
#include "llm/qwen3_moe.h" // IWYU pragma: keep
#include "vlm/qwen2_5_vl.h" // IWYU pragma: keep
#include "vlm/qwen2_vl.h" // IWYU pragma: keep
#include "vlm/qwen3_vl.h" // IWYU pragma: keep
#include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep
#elif defined(USE_ILU)
#include "llm/qwen2.h" // IWYU pragma: keep
#include "llm/qwen3.h" // IWYU pragma: keep
Expand All @@ -83,7 +81,6 @@ limitations under the License.
#include "llm/qwen3_moe.h" // IWYU pragma: keep
#include "vlm/qwen2_5_vl.h" // IWYU pragma: keep
#include "vlm/qwen2_vl.h" // IWYU pragma: keep
#include "vlm/qwen2_vl_embedding.h" // IWYU pragma: keep
#include "vlm/qwen3_vl.h" // IWYU pragma: keep
#include "vlm/qwen3_vl_moe.h" // IWYU pragma: keep
#elif defined(USE_MUSA)
Expand Down
5 changes: 5 additions & 0 deletions xllm/models/vlm/npu/glm4v.h
Original file line number Diff line number Diff line change
Expand Up @@ -1006,6 +1006,11 @@ class Glm4vForConditionalGenerationImpl : public torch::nn::Module {
return language_model_(tokens, positions, kv_caches, input_params);
}

// Forwards embedding pooling to the underlying language model; the vision
// tower plays no part here. `seleted_idxes` (sic) picks the token rows to
// pool — see the language model's pooler for the exact semantics.
torch::Tensor pooler(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
return language_model_->pooler(hidden_states, seleted_idxes);
}

torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
return language_model_->logits(hidden_states, seleted_idxes);
Expand Down
5 changes: 5 additions & 0 deletions xllm/models/vlm/npu/glm4v_moe.h
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,11 @@ class Glm4vMoeForConditionalGenerationImpl : public torch::nn::Module {
return language_model_(tokens, positions, kv_caches, input_params);
}

// Forwards embedding pooling to the underlying language model; the MoE
// vision/generation wrapper adds nothing of its own. `seleted_idxes` (sic)
// selects the token rows to pool.
torch::Tensor pooler(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
return language_model_->pooler(hidden_states, seleted_idxes);
}

torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
return language_model_->logits(hidden_states, seleted_idxes);
Expand Down
6 changes: 6 additions & 0 deletions xllm/models/vlm/npu/minicpmv.h
100755 → 100644
Original file line number Diff line number Diff line change
Expand Up @@ -1214,6 +1214,12 @@ class MiniCPMV2_6Impl : public torch::nn::Module {
const ModelInputParams& input_params) {
return language_model_(tokens, positions, kv_caches, input_params);
}

// Delegates embedding pooling to the wrapped language model. `seleted_idxes`
// (sic — codebase-wide spelling) selects which token rows are pooled.
torch::Tensor pooler(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
return language_model_->pooler(hidden_states, seleted_idxes);
}

torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
return language_model_->logits(hidden_states, seleted_idxes);
Expand Down
5 changes: 5 additions & 0 deletions xllm/models/vlm/npu/oxygen_vlm.h
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,11 @@ class OxygenvlmForConditionalGenerationImpl : public torch::nn::Module {
return emb;
}

// Delegates embedding pooling to the wrapped language model. `seleted_idxes`
// (sic) selects the token rows to pool; semantics are defined by the
// language model's own pooler implementation.
torch::Tensor pooler(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
return language_model_->pooler(hidden_states, seleted_idxes);
}

torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) {
return language_model_->logits(hidden_states, seleted_idxes);
Expand Down
Loading
Loading