Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions xllm/core/framework/model/model_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,9 @@ struct ModelArgs {
// Vision model's mm_projection_dim
PROPERTY(int64_t, mm_projection_dim) = 0;

// Vision model's mm_projector_hidden_size
PROPERTY(int64_t, mm_projector_hidden_size) = 0;

PROPERTY(int64_t, mm_spatial_merge_size) = 0;
PROPERTY(int64_t, mm_spatial_patch_size) = 0;

Expand Down
2 changes: 2 additions & 0 deletions xllm/core/layers/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -37,13 +37,15 @@ cc_library(
layers
HDRS
onerec_block_layer.h
oxygen_vision_layer.h
qwen2_decoder_layer.h
qwen2_vision_layer.h
qwen2_5_vision_layer.h
qwen3_vision_layer.h
qwen3_decoder_layer.h
qwen3_moe_decoder_layer.h
SRCS
oxygen_vision_layer.cpp
qwen2_vision_layer.cpp
qwen2_decoder_layer.cpp
qwen2_5_vision_layer.cpp
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/layers/common/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ cc_library(
NAME
common_layers
HDRS
oxygen_vision_attention.h
qwen2_attention.h
qwen2_vision_attention.h
qwen3_next_rms_norm.h
Expand All @@ -24,6 +25,7 @@ cc_library(
add_matmul.h
moe_fused_topk.h
SRCS
oxygen_vision_attention.cpp
qwen2_attention.cpp
qwen2_vision_attention.cpp
qwen3_next_rms_norm.cpp
Expand Down
117 changes: 117 additions & 0 deletions xllm/core/layers/common/oxygen_vision_attention.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "oxygen_vision_attention.h"

#if defined(USE_MLU)
#include "kernels/mlu/mlu_ops_api.h"
#endif
#include "kernels/ops_api.h"
namespace xllm {
namespace layer {

// Oxygen's vision attention reuses the Qwen2 vision attention weight layout
// and projections, but with bias disabled on both the QKV and output
// projections (the `false` passed to the delegated base constructor).
OxygenVisionAttentionImpl::OxygenVisionAttentionImpl(
    const ModelContext& context)
    : Qwen2VisionAttentionImpl(context, false) {}

// Runs packed (variable-length) self-attention over vision tokens.
//
// hidden_states   : packed token embeddings, laid out [s, b, ...] on entry
//                   (see the transpose in step 3).
// m_cos_pos/m_sin_pos : rotary position embedding tables applied to q and k.
// cu_seq_len      : device tensor of sequence-length boundaries consumed by
//                   the RoPE and attention kernels.
// cu_seq_len_vec  : host-side copy of the same values; used only to compute
//                   max_seqlen below.
// params          : unused here; kept for signature parity with the base
//                   class interface.
// Returns the output-projected attention result with shape [S*B, hidden].
//
// NOTE(review): `output` is zero-initialized and only filled inside the
// USE_MLU block — on non-MLU builds this returns the projection of zeros.
// Confirm other backends are handled elsewhere or intentionally unsupported.
torch::Tensor OxygenVisionAttentionImpl::forward(
    torch::Tensor& hidden_states,
    torch::Tensor& m_cos_pos,
    torch::Tensor& m_sin_pos,
    torch::Tensor& cu_seq_len,
    std::vector<int32_t>& cu_seq_len_vec,
    ModelInputParams& params) {
  // 1. qkv projection
  auto qkv = qkv_proj_->forward(hidden_states);
  // 2. split qkv
  auto qkv_split = split_qkv(qkv);
  // 3. transpose [s, b, h, d] -> [b, s, h, d]
  for (auto& tensor : qkv_split) {
    tensor = tensor.transpose(0, 1).contiguous();
  }
  auto q = qkv_split[0];
  auto k = qkv_split[1];
  auto v = qkv_split[2];
  int64_t B = q.size(0);
  int64_t S = q.size(1);
  int64_t head_dim = q.size(3);
  CHECK_EQ(head_dim, hidden_size_per_attention_head_) << "head_dim mismatch";
  // NOTE(review): if cu_seq_len_vec holds cumulative offsets, its maximum is
  // the total token count rather than the longest individual sequence —
  // confirm this matches what the RoPE/attention kernels expect for
  // max_query_len / max_seqlen.
  int32_t max_seqlen =
      *std::max_element(cu_seq_len_vec.begin(), cu_seq_len_vec.end());

  // 4. rope
  // Reshape q, k from [B, S, H, D] to [B*S, H, D] before applying RoPE so
  // that the RoPE kernel sees the correct total token count (B*S = seq_len),
  // not just the batch dimension (B=1).
  q = q.reshape({B * S, num_attention_heads_per_partition_, head_dim});
  k = k.reshape({B * S, num_attention_heads_per_partition_, head_dim});

  // Apply rotary position embedding to both q and k separately: the kernel
  // is invoked once with q, then the same parameter struct is reused with k.
  xllm::kernel::RotaryParams rotary_params;
  rotary_params.q = q;
  rotary_params.sin = m_sin_pos;
  rotary_params.cos = m_cos_pos;
  rotary_params.interleaved = false;
  rotary_params.discrete = false;
  rotary_params.cu_query_lens = cu_seq_len;
  rotary_params.max_query_len = max_seqlen;
  xllm::kernel::apply_rotary(rotary_params);
  rotary_params.q = k;
  xllm::kernel::apply_rotary(rotary_params);

  // q, k, v = (rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v])
  // q and k are already [B*S, H, D] after the reshape above; just
  // flatten v to the same shape.
  v = v.view({B * S, v.size(2), v.size(3)});
  // Pre-allocated result buffer the attention kernel writes into.
  torch::Tensor output = torch::zeros_like(q);

  // 5. store k/v cache and do attention
#if defined(USE_MLU)
  std::optional<torch::Tensor> output_lse = std::nullopt;

  // Non-causal varlen prefill attention on MLU; quantization, ALiBi and
  // paged-KV (block_table) features are all disabled for the vision tower.
  xllm::kernel::mlu::batch_prefill(q,
                                   k,
                                   v,
                                   output,
                                   output_lse,
                                   cu_seq_len,
                                   cu_seq_len,
                                   /*alibi_slope=*/std::nullopt,
                                   /*alibi_bias=*/std::nullopt,
                                   /*q_quant_scale=*/std::nullopt,
                                   /*k_quant_scale=*/std::nullopt,
                                   /*v_quant_scale=*/std::nullopt,
                                   /*out_quant_scale=*/std::nullopt,
                                   /*block_table=*/std::nullopt,
                                   max_seqlen,
                                   max_seqlen,
                                   scale_,
                                   /*is_causal=*/false,
                                   /*window_size_left=*/-1,
                                   /*window_size_right=*/-1,
                                   /*compute_dtype=*/"half",
                                   /*return_lse=*/false);
#endif

  // context_layer = rearrange(output, "(b s) h d -> s b (h d)", b=batch_size)
  output = output.view({B, S, -1});
  // [B, S, ...] -> [S, B, ...]
  output = output.transpose(0, 1).reshape({-1, output.size(-1)});
  // 6. output projection
  return proj_->forward(output);
}

} // namespace layer
} // namespace xllm
42 changes: 42 additions & 0 deletions xllm/core/layers/common/oxygen_vision_attention.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/torch.h>

#include "framework/model/model_input_params.h"
#include "framework/model_context.h"
#include "qwen2_vision_attention.h"

namespace xllm {
namespace layer {

// Vision-transformer self-attention for the Oxygen VLM.
// Reuses the Qwen2 vision attention implementation, but constructs the
// QKV and output projections without bias (see the .cpp, which delegates
// with has_bias = false).
class OxygenVisionAttentionImpl : public Qwen2VisionAttentionImpl {
 public:
  OxygenVisionAttentionImpl() = default;
  // Builds the attention layer from the model/quantization/parallel
  // configuration carried by `context`. `explicit` prevents accidental
  // implicit conversion from ModelContext.
  explicit OxygenVisionAttentionImpl(const ModelContext& context);

  // Packed (varlen) self-attention over vision tokens.
  //   hidden_states        : packed token embeddings.
  //   m_cos_pos/m_sin_pos  : rotary embedding tables.
  //   cu_seq_len           : device tensor of sequence boundaries.
  //   cu_seq_len_vec       : host copy of the same values.
  //   input_params         : unused; kept for interface parity.
  // Returns the output-projected attention result.
  torch::Tensor forward(torch::Tensor& hidden_states,
                        torch::Tensor& m_cos_pos,
                        torch::Tensor& m_sin_pos,
                        torch::Tensor& cu_seq_len,
                        std::vector<int32_t>& cu_seq_len_vec,
                        ModelInputParams& input_params) override;
};
TORCH_MODULE(OxygenVisionAttention);

} // namespace layer
} // namespace xllm
2 changes: 1 addition & 1 deletion xllm/core/layers/common/qwen2_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ limitations under the License.
namespace {
// Returns true for model types that take the Qwen3-style attention path
// (the Oxygen VLM shares it).
// Uses count() rather than the C++20-only unordered_set::contains so the
// helper also builds under C++17; behavior is identical.
inline bool is_qwen3_model(const std::string& model_type) {
  static const std::unordered_set<std::string> qwen3_type_set = {
      "qwen3", "qwen3_vl", "qwen3_moe", "qwen3_vl_moe", "oxygenvlm"};
  return qwen3_type_set.count(model_type) > 0;
}

Expand Down
8 changes: 4 additions & 4 deletions xllm/core/layers/common/qwen2_vision_attention.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,8 +24,8 @@ limitations under the License.
namespace xllm {
namespace layer {

Qwen2VisionAttentionImpl::Qwen2VisionAttentionImpl(
const ModelContext& context) {
Qwen2VisionAttentionImpl::Qwen2VisionAttentionImpl(const ModelContext& context,
bool has_bias) {
const auto& args = context.get_model_args();
const auto& quant_args = context.get_quant_args();
const auto& parallel_args = context.get_parallel_args();
Expand All @@ -47,15 +47,15 @@ Qwen2VisionAttentionImpl::Qwen2VisionAttentionImpl(
num_attention_heads_per_partition_,
hidden_size_per_attention_head_,
/*num_kv_head_replicas=*/1,
/*bias=*/true,
/*bias=*/has_bias,
/*gather_output=*/false,
parallel_args,
options));

proj_ = register_module("proj",
RowParallelLinear(hidden_size,
hidden_size,
/*bias=*/true,
/*bias=*/has_bias,
/*input_is_parallelized=*/true,
/*if_reduce_results=*/true,
quant_args,
Expand Down
16 changes: 8 additions & 8 deletions xllm/core/layers/common/qwen2_vision_attention.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,18 +31,18 @@ namespace layer {
class Qwen2VisionAttentionImpl : public torch::nn::Module {
public:
Qwen2VisionAttentionImpl() = default;
Qwen2VisionAttentionImpl(const ModelContext& context);
Qwen2VisionAttentionImpl(const ModelContext& context, bool has_bias = true);

torch::Tensor forward(torch::Tensor& hidden_states,
torch::Tensor& m_cos_pos,
torch::Tensor& m_sin_pos,
torch::Tensor& cu_seq_len,
std::vector<int32_t>& cu_seq_len_vec,
ModelInputParams& input_params);
virtual torch::Tensor forward(torch::Tensor& hidden_states,
torch::Tensor& m_cos_pos,
torch::Tensor& m_sin_pos,
torch::Tensor& cu_seq_len,
std::vector<int32_t>& cu_seq_len_vec,
ModelInputParams& input_params);

void load_state_dict(const StateDict& state_dict);

private:
protected:
std::vector<torch::Tensor> split_qkv(const torch::Tensor& qkv);

int64_t hidden_size_per_attention_head_;
Expand Down
72 changes: 72 additions & 0 deletions xllm/core/layers/oxygen_vision_layer.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "oxygen_vision_layer.h"

namespace xllm {
namespace layer {

// Builds one Oxygen vision transformer block: bias-free self-attention and a
// gated, bias-free MLP, each preceded by its own RMSNorm (pre-norm layout).
// Fix: `mlp_intermediate_size` was computed but never used (the DenseMLP call
// re-invoked args.mm_intermediate_size()); the local is now used instead.
OxygenVisionLayerImpl::OxygenVisionLayerImpl(const ModelContext& context) {
  const auto& args = context.get_model_args();
  const auto& quant_config = context.get_quant_args();
  const auto& parallel_args = context.get_parallel_args();
  const auto& options = context.get_tensor_options();
  // Vision tower hidden width and MLP expansion width come from the mm_* args.
  const int64_t dim = args.mm_hidden_size();
  const int64_t mlp_intermediate_size = args.mm_intermediate_size();
  attention_ = register_module("self_attn", OxygenVisionAttention(context));
  norm1_ = register_module("norm1", RMSNorm(dim, args.rms_norm_eps(), options));
  norm2_ = register_module("norm2", RMSNorm(dim, args.rms_norm_eps(), options));

  mlp_ = register_module("mlp",
                         DenseMLP(dim,
                                  mlp_intermediate_size,
                                  /*is_gated=*/true,
                                  /*has_bias=*/false,
                                  args.mm_hidden_act(),
                                  /*enable_result_reduction=*/true,
                                  quant_config,
                                  parallel_args.tp_group_,
                                  options));
}

// Loads this block's weights from `state_dict`, routing each sub-module to
// its checkpoint key prefix.
// NOTE(review): the attention module is registered as "self_attn" but its
// weights are read from the "attn." prefix — presumably matching the
// exported checkpoint's key layout; confirm against the model's state-dict
// key names.
void OxygenVisionLayerImpl::load_state_dict(const StateDict& state_dict) {
  attention_->load_state_dict(state_dict.get_dict_with_prefix("attn."));
  mlp_->load_state_dict(state_dict.get_dict_with_prefix("mlp."));
  norm1_->load_state_dict(state_dict.get_dict_with_prefix("norm1."));
  norm2_->load_state_dict(state_dict.get_dict_with_prefix("norm2."));
}

// Pre-norm transformer block:
//   y = x + Attn(RMSNorm1(x)); out = y + MLP(RMSNorm2(y)).
// `node_id` is accepted for interface parity and not used here.
torch::Tensor OxygenVisionLayerImpl::forward(
    torch::Tensor& hidden_states,
    torch::Tensor& m_cos_pos,
    torch::Tensor& m_sin_pos,
    torch::Tensor& cu_seq_len,
    std::vector<int32_t>& cu_seq_len_vec,
    ModelInputParams& input_params,
    int node_id) {
  // Attention sub-layer with residual connection (RMSNorm returns a tuple;
  // element 0 is the normalized activation).
  auto attn_input = std::get<0>(norm1_(hidden_states));
  auto residual = hidden_states + attention_(attn_input,
                                             m_cos_pos,
                                             m_sin_pos,
                                             cu_seq_len,
                                             cu_seq_len_vec,
                                             input_params);
  // Feed-forward sub-layer with residual connection.
  auto mlp_input = std::get<0>(norm2_(residual));
  residual = residual + mlp_(mlp_input);
  return residual;
}

} // namespace layer
} // namespace xllm
56 changes: 56 additions & 0 deletions xllm/core/layers/oxygen_vision_layer.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
/* Copyright 2026 The xLLM Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

https://github.com/jd-opensource/xllm/blob/main/LICENSE

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#pragma once

#include <torch/torch.h>

#include <functional>

#include "common/dense_mlp.h"
#include "common/oxygen_vision_attention.h"
#include "common/rms_norm.h"
#include "framework/model/model_args.h"
#include "framework/model/model_input_params.h"
#include "framework/model_context.h"
#include "framework/state_dict/state_dict.h"

namespace xllm {
namespace layer {

// One transformer block of the Oxygen VLM vision tower: pre-norm
// self-attention followed by a pre-norm gated MLP, both with residual
// connections (see the .cpp for the exact dataflow).
class OxygenVisionLayerImpl : public torch::nn::Module {
 public:
  // Builds the block from the model/quantization/parallel configuration.
  // `explicit` prevents accidental implicit conversion from ModelContext.
  explicit OxygenVisionLayerImpl(const ModelContext& context);

  // Loads attention/MLP/norm weights from the given state dict.
  void load_state_dict(const StateDict& state_dict);

  // Runs the block over packed vision tokens.
  //   x                    : packed token embeddings (residual stream).
  //   m_cos_pos/m_sin_pos  : rotary embedding tables.
  //   cu_seq_len           : device tensor of sequence boundaries.
  //   cu_seq_len_vec       : host copy of the same values.
  //   input_params         : forwarded to the attention layer.
  //   node_id              : unused in the implementation; kept for
  //                          interface parity with sibling layers.
  torch::Tensor forward(torch::Tensor& x,
                        torch::Tensor& m_cos_pos,
                        torch::Tensor& m_sin_pos,
                        torch::Tensor& cu_seq_len,
                        std::vector<int32_t>& cu_seq_len_vec,
                        ModelInputParams& input_params,
                        int node_id);

 private:
  OxygenVisionAttention attention_{nullptr};
  DenseMLP mlp_{nullptr};
  RMSNorm norm1_{nullptr};
  RMSNorm norm2_{nullptr};
};
TORCH_MODULE(OxygenVisionLayer);

} // namespace layer
} // namespace xllm
Loading
Loading