Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions xllm/core/distributed_runtime/master.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,7 @@ Master::Master(const Options& options, EngineType type)
.dp_size(options.dp_size())
.ep_size(options.ep_size())
.enable_prefill_sp(options_.enable_prefill_sp())
.cp_size(options.cp_size())
.enable_chunked_prefill(options_.enable_chunked_prefill())
.max_seqs_per_batch(options_.max_seqs_per_batch())
.max_tokens_per_chunk_for_prefill(
Expand Down
37 changes: 34 additions & 3 deletions xllm/core/framework/batch/batch_input_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -69,11 +69,13 @@ BatchInputBuilder::BatchInputBuilder(
batch_id_(batch_id),
cp_size_(std::max(1, cp_size)) {
// Reserve space for better performance
state_.flatten_tokens_vec.reserve(1000);
state_.flatten_positions_vec.reserve(1000);
const size_t reserve_size = 1024;
state_.flatten_tokens_vec.reserve(reserve_size);
state_.flatten_positions_vec.reserve(reserve_size);
state_.mrope_positions_vec.reserve(sequences.size());
state_.block_tables_vec.reserve(sequences.size());
state_.acc_logprob_vec.reserve(sequences.size());
state_.mtp_shifted_token_ids.reserve(reserve_size);
if (args_ != nullptr) {
use_mrope_ = (args_->rope_scaling_rope_type() == "mrope");
}
Expand Down Expand Up @@ -241,6 +243,9 @@ void BatchInputBuilder::process_sequences_multithreaded() {
state_.extra_token_ids.insert(state_.extra_token_ids.end(),
state.extra_token_ids.begin(),
state.extra_token_ids.end());
state_.mtp_shifted_token_ids.insert(state_.mtp_shifted_token_ids.end(),
state.mtp_shifted_token_ids.begin(),
state.mtp_shifted_token_ids.end());
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

What's the difference between extra_token_ids and mtp_shifted_token_ids?

Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mtp_shifted_token_ids: the token sequence left-shifted by one position and padded with -1, following the MTP prefill input-preparation rules.

state_.transfer_kv_infos.insert(state_.transfer_kv_infos.end(),
state.transfer_kv_infos.begin(),
state.transfer_kv_infos.end());
Expand Down Expand Up @@ -337,6 +342,7 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
uint32_t padded_seq_len,
BuilderState* state_ptr) {
BuilderState& state = state_ptr ? *state_ptr : state_;
const size_t seq_token_begin = state.flatten_tokens_vec.size();

const auto& token_ids = sequence->tokens();
const uint32_t n_tokens = token_ids.size();
Expand Down Expand Up @@ -394,14 +400,33 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence,
}

// Add extra token id
int32_t extra_token_id = -1;
if (n_tokens == seq_len) {
// last chunk of prefill and decode
// add -1 as extra token id
state.extra_token_ids.emplace_back(-1);
state.embedding_ids.emplace_back(sequence->get_embedding_id());
state.request_ids.emplace_back(sequence->request_id());
} else {
state.extra_token_ids.emplace_back(token_ids[seq_len]);
extra_token_id = token_ids[seq_len];
state.extra_token_ids.emplace_back(extra_token_id);
}

if (cp_size_ > 1 && state.batch_forward_type.is_prefill()) {
const uint32_t q_len = seq_len - n_kv_cache_tokens;
if (q_len > 1) {
state.mtp_shifted_token_ids.insert(
state.mtp_shifted_token_ids.end(),
state.flatten_tokens_vec.begin() + seq_token_begin + 1,
state.flatten_tokens_vec.begin() + seq_token_begin + q_len);
}
state.mtp_shifted_token_ids.emplace_back(extra_token_id);
if (padded_seq_len > seq_len) {
const int32_t pad_token_id = args_ ? args_->pad_token_id() : 0;
state.mtp_shifted_token_ids.insert(state.mtp_shifted_token_ids.end(),
padded_seq_len - seq_len,
pad_token_id);
}
}
}

Expand Down Expand Up @@ -588,6 +613,10 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
input_params.embedding_ids = std::move(state_.embedding_ids);
input_params.request_ids = std::move(state_.request_ids);
input_params.extra_token_ids = std::move(state_.extra_token_ids);
if (!state_.mtp_shifted_token_ids.empty()) {
input_params.mtp_shifted_token_ids =
torch::tensor(state_.mtp_shifted_token_ids, torch::kInt);
}

if (swap_block_transfer_infos_ != nullptr &&
swap_block_transfer_infos_->size() > 0) {
Expand Down Expand Up @@ -668,6 +697,8 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
raw_forward_input.embedding_ids = std::move(state_.embedding_ids);
raw_forward_input.request_ids = std::move(state_.request_ids);
raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids);
raw_forward_input.mtp_shifted_token_ids =
std::move(state_.mtp_shifted_token_ids);
// beam search kernel input
if (state_.acc_logprob_vec.size() > 0) {
raw_forward_input.acc_logprob_vec = std::move(state_.acc_logprob_vec);
Expand Down
1 change: 1 addition & 0 deletions xllm/core/framework/batch/batch_input_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,6 +104,7 @@ class BatchInputBuilder {
std::vector<int32_t> embedding_ids;
std::vector<std::string> request_ids;
std::vector<int32_t> extra_token_ids;
std::vector<int32_t> mtp_shifted_token_ids;
std::vector<TransferKVInfo> transfer_kv_infos;

// for continuous kvcache
Expand Down
21 changes: 21 additions & 0 deletions xllm/core/framework/model/causal_lm.h
Original file line number Diff line number Diff line change
Expand Up @@ -74,6 +74,17 @@ class CausalLM : public torch::nn::Module {
virtual torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes) = 0;

// Variant of logits() that additionally exposes the gathered hidden states.
// hidden_states: [num_tokens, hidden_size]
// seleted_idxes: [num_tokens]
// out_hidden: [num_selected_tokens, hidden_size]
// returns: [num_selected_tokens, vocab_size]
// Default implementation is a stub for models that do not support this
// overload; CausalLMImpl forwards to the model only when the model
// declares a matching 3-argument logits() (see has_logits_with_hidden).
virtual torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes,
torch::Tensor& out_hidden) {
NOT_IMPLEMENTED();
// Unreachable if NOT_IMPLEMENTED() aborts — presumably it does; the
// empty tensor keeps the signature well-formed either way.
return torch::Tensor();
}

virtual void load_model(std::unique_ptr<ModelLoader> loader) = 0;

virtual torch::Device device() const = 0;
Expand Down Expand Up @@ -158,6 +169,16 @@ class CausalLMImpl : public CausalLM {
return model_->logits(hidden_states, seleted_idxes);
}

// Compute logits for the selected token positions, also writing the
// selected hidden states into out_hidden when the wrapped model
// supports that overload.
torch::Tensor logits(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes,
torch::Tensor& out_hidden) override {
// Compile-time dispatch: models lacking the 3-argument logits()
// overload fall back to the base-class stub.
if constexpr (!detail::has_logits_with_hidden<Model>::value) {
return CausalLM::logits(hidden_states, seleted_idxes, out_hidden);
} else {
return model_->logits(hidden_states, seleted_idxes, out_hidden);
}
}

void load_model(std::unique_ptr<ModelLoader> loader) override {
model_->load_model(std::move(loader));
}
Expand Down
3 changes: 3 additions & 0 deletions xllm/core/framework/model/model_input_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,7 @@ struct ModelInputParams {
params.embedding_ids = std::move(embedding_ids);
params.request_ids = std::move(request_ids);
params.extra_token_ids = std::move(extra_token_ids);
params.mtp_shifted_token_ids = safe_to(mtp_shifted_token_ids, device, true);
params.dp_ep_padding_data = dp_ep_padding_data;
params.cp_ep_padding_data
.attn_padding_idx(
Expand Down Expand Up @@ -535,6 +536,8 @@ struct ModelInputParams {
// chunked prefill case of speculative decoding
// extra token ids for each sequence, and -1 for last chunk
std::vector<int32_t> extra_token_ids;
// precomputed shifted token ids for mtp prefill.
torch::Tensor mtp_shifted_token_ids;

// swap
std::vector<BlockTransferInfo> swap_blocks;
Expand Down
11 changes: 11 additions & 0 deletions xllm/core/framework/model/model_traits.h
Original file line number Diff line number Diff line change
Expand Up @@ -67,6 +67,17 @@ struct has_set_word_embedding<
std::void_t<decltype(std::declval<T>()->set_word_embedding(
std::declval<layer::WordEmbedding&>()))>> : std::true_type {};

// Detection trait: true_type when T (a smart-pointer-like model handle)
// provides logits(const torch::Tensor&, const torch::Tensor&,
// torch::Tensor&), i.e. the overload that also outputs the selected
// hidden states. Used by CausalLMImpl to dispatch at compile time.
template <typename T, typename = void>
struct has_logits_with_hidden : std::false_type {};

// Specialization selected via std::void_t when the 3-argument logits()
// call expression is well-formed for T.
template <typename T>
struct has_logits_with_hidden<T,
std::void_t<decltype(std::declval<T>()->logits(
std::declval<const torch::Tensor&>(),
std::declval<const torch::Tensor&>(),
std::declval<torch::Tensor&>()))>>
: std::true_type {};

template <typename T, typename = void>
struct has_lazy_load_model : std::false_type {};

Expand Down
17 changes: 10 additions & 7 deletions xllm/core/framework/parallel_state/collective_communicator.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -119,19 +119,22 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank,
void CollectiveCommunicator::create_process_groups(
const std::string& master_addr,
const torch::Device& device) {
int global_rank = parallel_args_->rank();
int world_size = parallel_args_->world_size();
int dp_size = parallel_args_->dp_size();
int ep_size = parallel_args_->ep_size();
int cp_size = parallel_args_->cp_size();

std::string host;
int port;
net::parse_host_port_from_addr(master_addr, host, port);

#if defined(USE_NPU)
if (FLAGS_npu_kernel_backend == "ATB") {
return;
}
#endif
std::string host;
int port;
net::parse_host_port_from_addr(master_addr, host, port);

int global_rank = parallel_args_->rank();
int world_size = parallel_args_->world_size();
int dp_size = parallel_args_->dp_size();
int ep_size = parallel_args_->ep_size();
process_group_ = create_process_group(global_rank,
world_size,
world_size,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ ColumParallelLinearLoader::ColumParallelLinearLoader(
}

void ColumParallelLinearLoader::load_state_dict(const StateDict& state_dict) {
if (dp_size_ > 1) {
if (dp_size_ > 1 || cp_size_ > 1) {
set_weight(
state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_);
} else {
Expand Down
11 changes: 7 additions & 4 deletions xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,20 +28,23 @@ void NpuColumnParallelLinearImpl::param_from_args(

if (parallel_args.world_size() > 1) {
if (parallel_args.mapping_data().empty()) {
if (dp_size_ > 1) {
const bool use_local_tp = (dp_size_ > 1) || (cp_size_ > 1);
if (use_local_tp) {
param.tensorParallelInfo.rank = dp_local_tp_rank_;
param.tensorParallelInfo.worldSize = dp_local_tp_size_;
} else {
param.tensorParallelInfo.rank = parallel_args.rank();
param.tensorParallelInfo.worldSize = parallel_args.world_size();
}
param.parallelType = atb_speed::common::COLUMN_PARALLEL;
param.tensorParallelInfo.commDomain = std::to_string(dp_rank_);
param.tensorParallelInfo.backend = "lccl";
const int32_t tp_group_id =
use_local_tp ? (parallel_args.rank() / dp_local_tp_size_) : 0;
param.tensorParallelInfo.commDomain = std::to_string(tp_group_id);
param.tensorParallelInfo.backend = FLAGS_communication_backend;
} else {
param.parallelType = atb_speed::common::COLUMN_PARALLEL;
atb_speed::common::ParallelInfo parallelInfo =
parallel_args.mapping().Get(atb_speed::base::WORD_EMBED_TP);
parallel_args.mapping().Get(atb_speed::base::ATTN_TP);
param.tensorParallelInfo.rank = parallelInfo.rank;
param.tensorParallelInfo.worldSize = parallelInfo.rankIds.size();
param.tensorParallelInfo.backend = FLAGS_communication_backend;
Expand Down
37 changes: 27 additions & 10 deletions xllm/core/layers/npu/npu_lm_head_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ void NpuLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param,
const ParallelArgs& parallel_args,
bool isPrefill) {
const bool use_column_parallel = cp_size_ > 1;
param.outputHidden = cp_size_ > 1;
param.unpadInputs = true;
param.gatherAhead = isPrefill;
param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads();
Expand Down Expand Up @@ -101,7 +102,7 @@ NpuLmHeadImpl::NpuLmHeadImpl(const ModelContext& context) : BaseLayer(context) {
false);

atb_weight_tensors_.resize(1);
atOutTensors_.resize(1);
atOutTensors_.resize(2);

auto options = context.get_tensor_options();
dtype_ = c10::typeMetaToScalarType(options.dtype());
Expand Down Expand Up @@ -146,26 +147,38 @@ int64_t NpuLmHeadImpl::init_node(atb_speed::Model::Node& node,
return -1;
}
node.inTensors.resize(node.operation->GetInputNum());
node.outTensors.resize(1);
node.outTensors.resize(node.operation->GetOutputNum());

node.inTensors.at(1) = &atb_weight_tensors_[0];

node.variantPack.inTensors.reserve(node.inTensors.size());
node.variantPack.inTensors.resize(node.inTensors.size());
node.variantPack.outTensors.reserve(1);
node.variantPack.outTensors.resize(1);
node.variantPack.outTensors.reserve(node.outTensors.size());
node.variantPack.outTensors.resize(node.outTensors.size());

return atb::NO_ERROR;
}

// Backward-compatible forward(): returns only the logits, discarding the
// optional hidden-state output produced by forward_with_hidden().
torch::Tensor NpuLmHeadImpl::forward(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes,
int nodeId) {
torch::Tensor discarded_hidden;  // sink for the unused second output
return forward_with_hidden(
hidden_states, seleted_idxes, discarded_hidden, nodeId);
}

// Runs the lm_head prefill node and returns the logits tensor
// (atOutTensors_[0]). When the node emits a second output tensor —
// presumably the gathered hidden states when cp_size_ > 1 sets
// param.outputHidden; confirm against the ATB op definition — it is
// copied into out_hidden; otherwise out_hidden is left untouched.
// Fatal-logs if node execution fails.
torch::Tensor NpuLmHeadImpl::forward_with_hidden(
const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes,
torch::Tensor& out_hidden,
int nodeId) {
atb::Status st;
build_node_variant_pack(lm_head_node_prefill_, hidden_states, seleted_idxes);
st = execute_node(lm_head_node_prefill_, nodeId);
// Bug fix: the message previously ran into model_name_ with no
// separator ("<name>execute lmhead node fail...").
LOG_IF(FATAL, st != 0) << model_name_
<< " execute lmhead node fail, error code: " << st;
if (atOutTensors_.size() > 1) {
out_hidden = atOutTensors_[1];
}
return atOutTensors_[0];
}

Expand Down Expand Up @@ -212,12 +225,16 @@ void NpuLmHeadImpl::build_node_variant_pack(
inTensorDescs.at(8) = placeholder_.desc;

atb::Status st = node.operation->InferShape(inTensorDescs, outTensorDescs);
at::Tensor newTensor =
atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(0));

atOutTensors_.at(0) = newTensor;
node.variantPack.outTensors.at(0) =
atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(0));
LOG_IF(FATAL, st != atb::NO_ERROR)
<< model_name_ << " infer lmhead shape fail, error code: " << st;

atOutTensors_.resize(node.variantPack.outTensors.size());
for (size_t i = 0; i < node.variantPack.outTensors.size(); ++i) {
atOutTensors_.at(i) =
atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(i));
node.variantPack.outTensors.at(i) =
atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(i));
}
}

} // namespace layer
Expand Down
5 changes: 5 additions & 0 deletions xllm/core/layers/npu/npu_lm_head_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,11 @@ class NpuLmHeadImpl : public BaseLayer {
const torch::Tensor& seleted_idxes,
int nodeId);

torch::Tensor forward_with_hidden(const torch::Tensor& hidden_states,
const torch::Tensor& seleted_idxes,
torch::Tensor& out_hidden,
int nodeId);

private:
void param_from_args(atb_speed::common::LmHeadParam& param,
const ModelArgs& args,
Expand Down
7 changes: 7 additions & 0 deletions xllm/core/runtime/forward_params.h
Original file line number Diff line number Diff line change
Expand Up @@ -237,6 +237,9 @@ struct RawForwardInput {
// chunked prefill case of speculative decoding
// extra token ids for each sequence, and -1 for last chunk
std::vector<int32_t> extra_token_ids;
// precomputed shifted token ids for mtp prefill, aligned with
// flatten_tokens_vec at token level.
std::vector<int32_t> mtp_shifted_token_ids;
// embedding ids of each sequence
std::vector<int> embedding_ids;
// request ids of each sequence
Expand Down Expand Up @@ -382,6 +385,10 @@ struct RawForwardInput {
outputs.flatten_positions_vec =
gather_token_level_vector_i32(flatten_positions_vec);
}
if (!mtp_shifted_token_ids.empty()) {
outputs.mtp_shifted_token_ids =
gather_token_level_vector_i32(mtp_shifted_token_ids);
}

auto build_seq_lens = [&](const std::vector<int32_t>& original,
const std::vector<int32_t>& lengths) {
Expand Down
7 changes: 6 additions & 1 deletion xllm/core/runtime/forward_shared_memory_manager.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -280,6 +280,7 @@ size_t calculate_raw_forward_input_size(const RawForwardInput& input) {
total += get_vector_size(input.embedding_ids);
total += get_string_vector_size(input.request_ids);
total += get_vector_size(input.extra_token_ids);
total += get_vector_to_tensor_size(input.mtp_shifted_token_ids);
total += type_size<uint64_t> +
input.swap_blocks.size() * swap_block_info_fixed_size();
total += get_vector_to_tensor_size(input.src_block_indices);
Expand Down Expand Up @@ -1100,6 +1101,7 @@ inline void deserialize_raw_forward_input(const char*& buffer,
read_vector(buffer, input_params.embedding_ids, device_buffer);
read_string_vector(buffer, input_params.request_ids, device_buffer);
read_vector(buffer, input_params.extra_token_ids, device_buffer);
read_tensor(buffer, input_params.mtp_shifted_token_ids, device_buffer);
read_swap_blocks(buffer, input_params.swap_blocks, device_buffer);
read_tensor(buffer, input_params.src_block_indices, device_buffer);
read_tensor(buffer, input_params.dst_block_indices, device_buffer);
Expand Down Expand Up @@ -1189,6 +1191,7 @@ inline void serialize_raw_forward_input(const RawForwardInput& input,
write_vector(buffer, input.embedding_ids);
write_string_vector(buffer, input.request_ids);
write_vector(buffer, input.extra_token_ids);
write_vector_to_tensor(buffer, input.mtp_shifted_token_ids);
write_swap_blocks(buffer, input.swap_blocks);
write_vector_to_tensor(buffer, input.src_block_indices);
write_vector_to_tensor(buffer, input.dst_block_indices);
Expand Down Expand Up @@ -1414,6 +1417,8 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input,
input_params.swap_blocks = std::move(raw_input.swap_blocks);
input_params.batch_id = std::move(raw_input.batch_id);
input_params.extra_token_ids = std::move(raw_input.extra_token_ids);
input_params.mtp_shifted_token_ids =
torch::tensor(std::move(raw_input.mtp_shifted_token_ids), tensor_options);

input_params.new_cache_slot_offsets = torch::tensor(
std::move(raw_input.new_cache_slot_offsets), tensor_options);
Expand Down Expand Up @@ -1710,4 +1715,4 @@ void ForwardSharedMemoryManager::raw_output_read(RawForwardOutput& output) {
void ForwardSharedMemoryManager::clear() {
std::memset(base_address(), 0, size());
}
} // namespace xllm
} // namespace xllm
Loading