Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion third_party/xllm_atb_layers
Submodule xllm_atb_layers updated from 65f62d to eacdb1
37 changes: 37 additions & 0 deletions xllm/core/framework/model/model_args.h
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ struct ModelArgs {
PROPERTY(int32_t, rope_scaling) = -1;
PROPERTY(float, router_aux_loss_coef) = 0.001f;

// qwen3 omni
PROPERTY(int64_t, talker_text_vocab_size) = 3072;
PROPERTY(int64_t, thinker_hidden_size) = 2048;
PROPERTY(int64_t, talker_text_intermediate_size) = 2048;
PROPERTY(int64_t, talker_text_hidden_size) = 1024;
PROPERTY(int32_t, talker_num_code_groups) = 16;

// Vision model's dropout
PROPERTY(float, mm_dropout) = 0.0f;

Expand Down Expand Up @@ -408,6 +415,36 @@ struct ModelArgs {
PROPERTY(float, max_shift) = 0;
PROPERTY(int64_t, base_image_seq_len) = 0;
PROPERTY(int64_t, max_image_seq_len) = 0;

// qwen3_omni_code2wavargs
PROPERTY(int64_t, code2wav_config_codebook_size) = 2048;
PROPERTY(int64_t, code2wav_config_hidden_size) = 1024;
PROPERTY(int64_t, code2wav_config_max_position_embeddings) = 8000;
PROPERTY(double, code2wav_config_rope_theta) = 10000.0;
PROPERTY(int64_t, code2wav_config_num_attention_heads) = 16;
PROPERTY(int64_t, code2wav_config_num_key_value_heads) = 16;
PROPERTY(bool, code2wav_config_attention_bias) = false;
PROPERTY(int64_t, code2wav_config_sliding_window) = 72;
PROPERTY(int64_t, code2wav_config_intermediate_size) = 3072;
PROPERTY(std::string, code2wav_config_hidden_act) = std::string("silu");
PROPERTY(double, code2wav_config_layer_scale_initial_scale) = 0.01;
PROPERTY(double, code2wav_config_rms_norm_eps) = 1e-5;
PROPERTY(int64_t, code2wav_config_num_hidden_layers) = 8;
PROPERTY(int64_t, code2wav_config_num_quantizers) = 16;
PROPERTY(int64_t, code2wav_config_decoder_dim) = 1536;
PROPERTY(double, code2wav_config_attention_dropout) = 0.0;
PROPERTY(std::vector<int>,
code2wav_config_upsampling_ratios_vec) = std::vector<int> {
2,
2
};
PROPERTY(std::vector<int>,
code2wav_config_upsample_rates_vec) = std::vector<int> {
8,
5,
4,
3
};
};

inline std::ostream& operator<<(std::ostream& os, const ModelArgs& args) {
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/framework/model/model_input_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -579,6 +579,8 @@ struct ModelInputParams {

// Flag for CUDA graph capture mode
bool enable_cuda_graph = false;
int32_t generation_steps =
0; // TODO(panxuanyu) is there any way to get generation_steps
};

} // namespace xllm
3 changes: 2 additions & 1 deletion xllm/core/framework/sampling/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ cc_test(
:flags
:sampler
glog::glog
:state_dict
)
target_link_libraries(sampler_test PRIVATE brpc OpenSSL::SSL OpenSSL::Crypto leveldb::leveldb protobuf::libprotobuf)
target_link_libraries(sampler_test
Expand All @@ -51,4 +52,4 @@ target_link_libraries(sampler_test
$<$<BOOL:${USE_NPU}>:hccl>
$<$<BOOL:${USE_NPU}>:c_sec>
$<$<BOOL:${USE_NPU}>:nnopbase>)
add_dependencies(sampler_test brpc-static)
add_dependencies(sampler_test brpc-static)
66 changes: 62 additions & 4 deletions xllm/core/framework/state_dict/state_dict.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ void destroy_memory_mapping(MemoryMapping* mapping) {
if (mapping->fd != -1) {
close(mapping->fd);
}
delete mapping;
free(mapping);
}
}
} // namespace
Expand Down Expand Up @@ -145,9 +145,6 @@ torch::Tensor StateDict::get_sharded_tensor(const std::string& tensor_name,
<< "Invalid rank " << rank << " for " << world_size << " shards";

auto tensor = get_tensor(tensor_name);
if (!tensor.defined()) {
return tensor;
}
// chunk tensor along the dim
const int64_t dim_size = tensor.size(dim);
if (dim_size < world_size) {
Expand Down Expand Up @@ -182,6 +179,67 @@ StateDict StateDict::get_dict_with_prefix(
return tensors;
}

StateDict StateDict::get_dict_with_renamed_prefix(
    const std::string& old_prefix,
    const std::string& new_prefix) const {
  // Returns a copy of this StateDict in which every tensor name starting with
  // `old_prefix` has that prefix replaced by `new_prefix`; non-matching names
  // are kept unchanged. The dict-level prefix_ is preserved.
  //
  // Both prefixes should either end with '.' or both omit it; a mismatch
  // usually indicates a caller bug, so warn (but proceed).
  const bool old_end_with_dot = old_prefix.empty() || old_prefix.back() == '.';
  const bool new_end_with_dot = new_prefix.empty() || new_prefix.back() == '.';
  if (old_end_with_dot != new_end_with_dot) {
    LOG(WARNING) << "The input does not comply with the standards: "
                 << "you are changing statedict prefixes from " << old_prefix
                 << " to " << new_prefix;
  }

  std::unordered_map<std::string, torch::Tensor> tensors;
  tensors.reserve(dict_.size());

  for (const auto& [name, tensor] : dict_) {
    if (absl::StartsWith(name, old_prefix)) {
      // NOTE(review): if a renamed key collides with an existing non-matching
      // key, last write wins — confirm callers never mix overlapping
      // namespaces.
      tensors[new_prefix + name.substr(old_prefix.length())] = tensor;
    } else {
      tensors[name] = tensor;
    }
  }

  return {std::move(tensors), prefix_};
}

void StateDict::rename_prefix_inplace(const std::string& old_prefix,
const std::string& new_prefix) {
bool old_end_with_dot = old_prefix.empty() || old_prefix.back() == '.';
bool new_end_with_dot = new_prefix.empty() || new_prefix.back() == '.';
if ((old_end_with_dot && !new_end_with_dot) ||
(!old_end_with_dot && new_end_with_dot)) {
LOG(WARNING) << "The input does not comply the standards"
<< "you are changing statedict prefixs from " << old_prefix
<< " to " << new_prefix;
}

std::unordered_map<std::string, torch::Tensor> tensors;

for (auto it = dict_.begin(); it != dict_.end();) {
if (absl::StartsWith(it->first, old_prefix)) {
std::string new_key = new_prefix + it->first.substr(old_prefix.length());
VLOG(50) << "Replacing " << it->first << " with new name " << new_key
<< ", old_prefix " << old_prefix << " , new_prefix "
<< new_prefix;
tensors[std::move(new_key)] = std::move(it->second);
it = dict_.erase(it);
} else {
++it;
}
}

if (tensors.empty()) {
LOG(FATAL) << "Rename_prefix_inplace failed, because no match prefix "
<< old_prefix << " .";
}
for (auto& [new_key, tensor] : tensors) {
dict_[std::move(new_key)] = std::move(tensor);
}
}

StateDictFromSafeTensor::StateDictFromSafeTensor(
std::unique_ptr<MemoryMapping> mem_map,
std::unordered_map<std::string, torch::Tensor> dict)
Expand Down
16 changes: 16 additions & 0 deletions xllm/core/framework/state_dict/state_dict.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,22 @@ class StateDict {
virtual StateDict get_dict_with_prefix(const std::string& prefix,
TensorTransform transform_func) const;

virtual StateDict get_dict_with_renamed_prefix(
const std::string& old_prefix,
const std::string& new_prefix) const;

virtual void rename_prefix_inplace(const std::string& old_prefix,
const std::string& new_prefix);

std::vector<std::string> get_all_tensor_names() const {
  // Collects the names of all tensors currently held by this StateDict.
  std::vector<std::string> names;
  names.reserve(dict_.size());
  for (const auto& [name, tensor] : dict_) {
    names.push_back(name);
  }
  return names;
}

size_t size() const { return dict_.size(); }

std::string_view prefix() const { return prefix_; }
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/layers/npu/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ cc_library(
npu_glm4_decoder_layer_impl.h
npu_rms_norm_impl.h
npu_siglip_encoder_layer_impl.h
npu_qwen3_omni_code2wav_transformer_layer_impl.h
../common/rotary_embedding_util.h
loader/qwen3_decoder_loader.h
loader/qwen2_decoder_loader.h
Expand Down Expand Up @@ -90,6 +91,7 @@ cc_library(
npu_glm4_decoder_layer_impl.cpp
npu_rms_norm_impl.cpp
npu_siglip_encoder_layer_impl.cpp
npu_qwen3_omni_code2wav_transformer_layer_impl.h
../common/rotary_embedding_util.cpp
loader/qwen3_decoder_loader.cpp
loader/qwen2_decoder_loader.cpp
Expand Down
55 changes: 50 additions & 5 deletions xllm/core/layers/npu/loader/qwen3_moe_decoder_loader.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,11 @@ static const std::unordered_map<std::string, int> WEIGHT_MAPPING = {
// Expert MLP - Down projection
{"down_proj.weight", IN_MLP_DOWN_WEIGHT_EXPERT},

// Shared Expert MLP
{"mlp.shared_expert.gate_proj.weight", IN_MLP_SHARED_GATEUP_WEIGHT},
{"mlp.shared_expert.up_proj.weight", IN_MLP_SHARED_GATEUP_WEIGHT},
{"mlp.shared_expert.down_proj.weight", IN_MLP_SHARED_DOWN_WEIGHT},
{"mlp.shared_expert_gate.weight", IN_MLP_SHARED_EXPERT_GATE},
};

static const std::unordered_map<std::string, int> WEIGHT_MAPPING_W8A8 = {
Expand Down Expand Up @@ -189,6 +194,8 @@ static const std::map<int, int> WEIGHT_SHARD = {
{IN_ATTENTION_OUT_WEIGHT, 1},
{IN_MLP_GATEUP_WEIGHT_EXPERT, 0},
{IN_MLP_DOWN_WEIGHT_EXPERT, 1},
{IN_MLP_SHARED_GATEUP_WEIGHT, 0},
{IN_MLP_SHARED_DOWN_WEIGHT, 1},
};

static const std::map<int, int> WEIGHT_SHARD_W8A8 = {
Expand Down Expand Up @@ -219,6 +226,7 @@ Qwen3MoeDecoderLoader::Qwen3MoeDecoderLoader(uint64_t weight_count,
}

num_experts_ = model_args.num_experts();
num_shared_experts_ = model_args.n_shared_experts();
ep_size_ = parallel_args_.ep_size();
ep_local_tp_size_ = parallel_args_.world_size() / ep_size_;
CHECK_EQ(parallel_args_.world_size(), ep_size_ * ep_local_tp_size_);
Expand All @@ -245,8 +253,9 @@ void Qwen3MoeDecoderLoader::load_state_dict(const StateDict& state_dict) {
continue;
}

if (absl::StartsWith(name, "mlp") && !absl::StrContains(name, "gate.")) {
process_mlp_common_weights(state_dict, name, tensor);
if (absl::StartsWith(name, "mlp.shared_expert") &&
!absl::StrContains(name, "gate.")) {
process_shared_expert_weights(state_dict, name, tensor);
continue;
}

Expand All @@ -258,9 +267,18 @@ void Qwen3MoeDecoderLoader::verify_loaded_weights(
const std::string& prefix) const {
for (const auto& [name, index] : WEIGHT_MAPPING) {
if (name == "down_proj.weight" || name == "gate_proj.weight" ||
name == "up_proj.weight") {
name == "up_proj.weight" ||
name == "mlp.shared_expert.gate_proj.weight" ||
name == "mlp.shared_expert.up_proj.weight") {
continue;
}

if (num_shared_experts_ <= 0 &&
(name == "mlp.shared_expert_gate.weight" ||
name == "mlp.shared_expert.down_proj.weight")) {
continue;
}

CHECK(at_weight_tensors_[index].sizes() != std::vector<int64_t>({1}))
<< "weight is not loaded for " << name;
}
Expand Down Expand Up @@ -325,6 +343,29 @@ void Qwen3MoeDecoderLoader::merge_experts_weights() {
}
}

void Qwen3MoeDecoderLoader::merge_shared_expert_weights() {
  // Concatenates the shared expert's gate and up projections into the single
  // fused gate-up weight expected by the NPU kernel, then releases the large
  // source tensors by replacing them with 1-element placeholders.
  if (num_shared_experts_ <= 0) {
    return;  // model has no shared experts; nothing to merge
  }
  try {
    at_weight_tensors_[IN_MLP_SHARED_GATEUP_WEIGHT] =
        torch::cat(
            {shared_experts_weights_["mlp.shared_expert.gate_proj.weight"],
             shared_experts_weights_["mlp.shared_expert.up_proj.weight"]},
            0)
            .to(device_)
            .contiguous();
    // Create the placeholders directly on the target device: allocating on
    // CPU and then calling .to(device_) would cost an extra host allocation
    // plus a host->device copy.
    const auto placeholder_opts =
        torch::TensorOptions().dtype(torch::kFloat16).device(device_);
    shared_experts_weights_["mlp.shared_expert.gate_proj.weight"] =
        torch::zeros({1}, placeholder_opts);
    shared_experts_weights_["mlp.shared_expert.up_proj.weight"] =
        torch::zeros({1}, placeholder_opts);
  } catch (const std::exception& e) {
    LOG(ERROR) << "[ERROR] Exception in shared gateup weight processing: "
               << e.what();
    throw;  // rethrow so the loader's caller can abort model loading
  }
}

torch::Tensor Qwen3MoeDecoderLoader::merge_experts_weights(
std::vector<torch::Tensor>& experts,
bool transpose) {
Expand Down Expand Up @@ -436,7 +477,7 @@ int Qwen3MoeDecoderLoader::get_mapped_index(
return it->second;
}

void Qwen3MoeDecoderLoader::process_mlp_common_weights(
void Qwen3MoeDecoderLoader::process_shared_expert_weights(
const StateDict& state_dict,
const std::string& name,
const torch::Tensor& tensor) {
Expand All @@ -461,6 +502,8 @@ void Qwen3MoeDecoderLoader::process_mlp_common_weights(
at_weight_tensors_[index] = tmp_tensor;
} else {
shared_experts_weights_[name] = tmp_tensor;
LOG(INFO) << "PXY Load Shared experts: " << name << " , device is "
<< shared_experts_weights_[name].device();
}
}

Expand Down Expand Up @@ -571,6 +614,8 @@ torch::Tensor Qwen3MoeDecoderLoader::merge_experts_weights(

void Qwen3MoeDecoderLoader::merge_loaded_weights() {
merge_experts_weights();
merge_shared_expert_weights();

at_weight_tensors_[IN_QKV_WEIGHT_0] =
torch::cat({at_weight_tensors_[IN_QKV_WEIGHT_0],
at_weight_tensors_[IN_QKV_WEIGHT_1],
Expand Down Expand Up @@ -632,4 +677,4 @@ void Qwen3MoeDecoderLoader::merge_loaded_weights() {
}

} // namespace layer
} // namespace xllm
} // namespace xllm
10 changes: 6 additions & 4 deletions xllm/core/layers/npu/loader/qwen3_moe_decoder_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,15 +50,16 @@ class Qwen3MoeDecoderLoader : public BaseLoader {
int local_tp_rank,
int local_tp_size);

void process_mlp_common_weights(const StateDict& state_dict,
const std::string& name,
const torch::Tensor& tensor);
void process_shared_expert_weights(const StateDict& state_dict,
const std::string& name,
const torch::Tensor& tensor);

void process_general_weights(const StateDict& state_dict,
const std::string& name,
const torch::Tensor& tensor);

void merge_experts_weights();
void merge_shared_expert_weights();

torch::Tensor merge_experts_weights(std::vector<torch::Tensor>& experts_up,
std::vector<torch::Tensor>& experts_gate,
Expand All @@ -74,6 +75,7 @@ class Qwen3MoeDecoderLoader : public BaseLoader {
int32_t ep_size_;
int32_t num_experts_;
int32_t num_experts_per_partition_;
int32_t num_shared_experts_;
int32_t ep_local_tp_size_;
int32_t ep_local_tp_rank_;
int32_t start_expert_id_;
Expand All @@ -91,4 +93,4 @@ class Qwen3MoeDecoderLoader : public BaseLoader {
};

} // namespace layer
} // namespace xllm
} // namespace xllm
9 changes: 6 additions & 3 deletions xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -158,8 +158,12 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_mlp_parameters(
atb_speed::qwen::MoeDecoderLayerParam& param,
const ModelArgs& args,
const ParallelArgs& parallel_args) {
param.hasSharedExpert = (args.n_shared_experts() > 0);
param.hasSharedExpertGate = false;
if (args.model_type() == "qwen3_omni_moe_talker") {
param.hasSharedExpert = true;
param.hasSharedExpertGate = true;
} else {
param.hasSharedExpertGate = false;
}
param.processLogits = "normalization";
param.numOfSelectedExperts = {args.num_experts_per_tok()};

Expand Down Expand Up @@ -312,7 +316,6 @@ torch::Tensor NpuQwen3MoeDecoderLayerImpl::forward(
LOG_IF(FATAL, st != 0) << model_name_
<< "excute decode layer fail, error code: " << st;
}

return tensor_placeholder_;
}

Expand Down
Loading