diff --git a/third_party/xllm_atb_layers b/third_party/xllm_atb_layers index 918c03d2a..29ab875c6 160000 --- a/third_party/xllm_atb_layers +++ b/third_party/xllm_atb_layers @@ -1 +1 @@ -Subproject commit 918c03d2abc4c9996196a797aefe743863b7e0ae +Subproject commit 29ab875c61f1e70f925d4472851d57c964667e55 diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp index 55d8e3c9b..b9a35a1ff 100644 --- a/xllm/core/distributed_runtime/master.cpp +++ b/xllm/core/distributed_runtime/master.cpp @@ -262,6 +262,7 @@ Master::Master(const Options& options, EngineType type) .dp_size(options.dp_size()) .ep_size(options.ep_size()) .enable_prefill_sp(options_.enable_prefill_sp()) + .cp_size(options.cp_size()) .enable_chunked_prefill(options_.enable_chunked_prefill()) .max_seqs_per_batch(options_.max_seqs_per_batch()) .max_tokens_per_chunk_for_prefill( diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 8f50424c7..65ace2266 100644 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -69,11 +69,13 @@ BatchInputBuilder::BatchInputBuilder( batch_id_(batch_id), cp_size_(std::max(1, cp_size)) { // Reserve space for better performance - state_.flatten_tokens_vec.reserve(1000); - state_.flatten_positions_vec.reserve(1000); + const size_t reserve_size = 1024; + state_.flatten_tokens_vec.reserve(reserve_size); + state_.flatten_positions_vec.reserve(reserve_size); state_.mrope_positions_vec.reserve(sequences.size()); state_.block_tables_vec.reserve(sequences.size()); state_.acc_logprob_vec.reserve(sequences.size()); + state_.mtp_shifted_token_ids.reserve(reserve_size); if (args_ != nullptr) { use_mrope_ = (args_->rope_scaling_rope_type() == "mrope"); } @@ -241,6 +243,9 @@ void BatchInputBuilder::process_sequences_multithreaded() { state_.extra_token_ids.insert(state_.extra_token_ids.end(), state.extra_token_ids.begin(), state.extra_token_ids.end()); + state_.mtp_shifted_token_ids.insert(state_.mtp_shifted_token_ids.end(), + state.mtp_shifted_token_ids.begin(), + state.mtp_shifted_token_ids.end()); state_.transfer_kv_infos.insert(state_.transfer_kv_infos.end(), state.transfer_kv_infos.begin(), state.transfer_kv_infos.end()); @@ -337,6 +342,7 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence, uint32_t padded_seq_len, BuilderState* state_ptr) { BuilderState& state = state_ptr ? *state_ptr : state_; + const size_t seq_token_begin = state.flatten_tokens_vec.size(); const auto& token_ids = sequence->tokens(); const uint32_t n_tokens = token_ids.size(); @@ -394,6 +400,7 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence, } // Add extra token id + int32_t extra_token_id = -1; if (n_tokens == seq_len) { // last chunk of prefill and decode // add -1 as extra token id @@ -401,7 +408,25 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence, state.embedding_ids.emplace_back(sequence->get_embedding_id()); state.request_ids.emplace_back(sequence->request_id()); } else { - state.extra_token_ids.emplace_back(token_ids[seq_len]); + extra_token_id = token_ids[seq_len]; + state.extra_token_ids.emplace_back(extra_token_id); + } + + if (cp_size_ > 1 && state.batch_forward_type.is_prefill()) { + const uint32_t q_len = seq_len - n_kv_cache_tokens; + if (q_len > 1) { + state.mtp_shifted_token_ids.insert( + state.mtp_shifted_token_ids.end(), + state.flatten_tokens_vec.begin() + seq_token_begin + 1, + state.flatten_tokens_vec.begin() + seq_token_begin + q_len); + } + state.mtp_shifted_token_ids.emplace_back(extra_token_id); + if (padded_seq_len > seq_len) { + const int32_t pad_token_id = args_ ? args_->pad_token_id() : 0; + state.mtp_shifted_token_ids.insert(state.mtp_shifted_token_ids.end(), + padded_seq_len - seq_len, + pad_token_id); + } } } @@ -588,6 +613,10 @@ ForwardInput BatchInputBuilder::state_to_forward_input() { input_params.embedding_ids = std::move(state_.embedding_ids); input_params.request_ids = std::move(state_.request_ids); input_params.extra_token_ids = std::move(state_.extra_token_ids); + if (!state_.mtp_shifted_token_ids.empty()) { + input_params.mtp_shifted_token_ids = + torch::tensor(state_.mtp_shifted_token_ids, torch::kInt); + } if (swap_block_transfer_infos_ != nullptr && swap_block_transfer_infos_->size() > 0) { @@ -668,6 +697,8 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() { raw_forward_input.embedding_ids = std::move(state_.embedding_ids); raw_forward_input.request_ids = std::move(state_.request_ids); raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids); + raw_forward_input.mtp_shifted_token_ids = + std::move(state_.mtp_shifted_token_ids); // beam search kernel input if (state_.acc_logprob_vec.size() > 0) { raw_forward_input.acc_logprob_vec = std::move(state_.acc_logprob_vec); diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h index 86b3eb18a..476dcd98c 100644 --- a/xllm/core/framework/batch/batch_input_builder.h +++ b/xllm/core/framework/batch/batch_input_builder.h @@ -104,6 +104,7 @@ class BatchInputBuilder { std::vector embedding_ids; std::vector request_ids; std::vector extra_token_ids; + std::vector mtp_shifted_token_ids; std::vector transfer_kv_infos; // for continuous kvcache diff --git a/xllm/core/framework/model/causal_lm.h b/xllm/core/framework/model/causal_lm.h index 3aa7a00ca..67aaebf11 100644 --- a/xllm/core/framework/model/causal_lm.h +++ b/xllm/core/framework/model/causal_lm.h @@ -74,6 +74,17 @@ class CausalLM : public torch::nn::Module { virtual torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) = 0; + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // out_hidden: [num_selected_tokens, hidden_size] + // returns: [num_selected_tokens, vocab_size] + virtual torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden) { + NOT_IMPLEMENTED(); + return torch::Tensor(); + } + virtual void load_model(std::unique_ptr loader) = 0; virtual torch::Device device() const = 0; @@ -158,6 +169,16 @@ class CausalLMImpl : public CausalLM { return model_->logits(hidden_states, seleted_idxes); } + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden) override { + if constexpr (detail::has_logits_with_hidden::value) { + return model_->logits(hidden_states, seleted_idxes, out_hidden); + } else { + return CausalLM::logits(hidden_states, seleted_idxes, out_hidden); + } + } + void load_model(std::unique_ptr loader) override { model_->load_model(std::move(loader)); } diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index 85c945df6..55d8b6354 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -351,6 +351,7 @@ struct ModelInputParams { params.embedding_ids = std::move(embedding_ids); params.request_ids = std::move(request_ids); params.extra_token_ids = std::move(extra_token_ids); + params.mtp_shifted_token_ids = safe_to(mtp_shifted_token_ids, device, true); params.dp_ep_padding_data = dp_ep_padding_data; params.cp_ep_padding_data .attn_padding_idx( @@ -535,6 +536,8 @@ struct ModelInputParams { // chunked prefill case of speculative decoding // extra token ids for each sequence, and -1 for last chunk std::vector extra_token_ids; + // precomputed shifted token ids for mtp prefill. + torch::Tensor mtp_shifted_token_ids; // swap std::vector swap_blocks; diff --git a/xllm/core/framework/model/model_traits.h b/xllm/core/framework/model/model_traits.h index 9e351819f..d36b75da2 100644 --- a/xllm/core/framework/model/model_traits.h +++ b/xllm/core/framework/model/model_traits.h @@ -67,6 +67,17 @@ struct has_set_word_embedding< std::void_t()->set_word_embedding( std::declval()))>> : std::true_type {}; +template +struct has_logits_with_hidden : std::false_type {}; + +template +struct has_logits_with_hidden()->logits( + std::declval(), + std::declval(), + std::declval()))>> + : std::true_type {}; + template struct has_lazy_load_model : std::false_type {}; diff --git a/xllm/core/framework/parallel_state/collective_communicator.cpp b/xllm/core/framework/parallel_state/collective_communicator.cpp index ac9faf6b1..0fa4c50fc 100644 --- a/xllm/core/framework/parallel_state/collective_communicator.cpp +++ b/xllm/core/framework/parallel_state/collective_communicator.cpp @@ -119,19 +119,22 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank, void CollectiveCommunicator::create_process_groups( const std::string& master_addr, const torch::Device& device) { + int global_rank = parallel_args_->rank(); + int world_size = parallel_args_->world_size(); + int dp_size = parallel_args_->dp_size(); + int ep_size = parallel_args_->ep_size(); + int cp_size = parallel_args_->cp_size(); + + std::string host; + int port; + net::parse_host_port_from_addr(master_addr, host, port); + #if defined(USE_NPU) if (FLAGS_npu_kernel_backend == "ATB") { return; } #endif - std::string host; - int port; - net::parse_host_port_from_addr(master_addr, host, port); - int global_rank = parallel_args_->rank(); - int world_size = parallel_args_->world_size(); - int dp_size = parallel_args_->dp_size(); - int ep_size = parallel_args_->ep_size(); process_group_ = create_process_group(global_rank, world_size, world_size, diff --git a/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp b/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp index 35836e916..7c9b16c2b 100644 --- a/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp +++ b/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp @@ -28,7 +28,7 @@ ColumParallelLinearLoader::ColumParallelLinearLoader( } void ColumParallelLinearLoader::load_state_dict(const StateDict& state_dict) { - if (dp_size_ > 1) { + if (dp_size_ > 1 || cp_size_ > 1) { set_weight( state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_); } else { diff --git a/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp b/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp index 51292671b..7b7537c83 100644 --- a/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp +++ b/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp @@ -28,7 +28,8 @@ void NpuColumnParallelLinearImpl::param_from_args( if (parallel_args.world_size() > 1) { if (parallel_args.mapping_data().empty()) { - if (dp_size_ > 1) { + const bool use_local_tp = (dp_size_ > 1) || (cp_size_ > 1); + if (use_local_tp) { param.tensorParallelInfo.rank = dp_local_tp_rank_; param.tensorParallelInfo.worldSize = dp_local_tp_size_; } else { @@ -36,12 +37,14 @@ void NpuColumnParallelLinearImpl::param_from_args( param.tensorParallelInfo.worldSize = parallel_args.world_size(); } param.parallelType = atb_speed::common::COLUMN_PARALLEL; - param.tensorParallelInfo.commDomain = std::to_string(dp_rank_); - param.tensorParallelInfo.backend = "lccl"; + const int32_t tp_group_id = + use_local_tp ? (parallel_args.rank() / dp_local_tp_size_) : 0; + param.tensorParallelInfo.commDomain = std::to_string(tp_group_id); + param.tensorParallelInfo.backend = FLAGS_communication_backend; } else { param.parallelType = atb_speed::common::COLUMN_PARALLEL; atb_speed::common::ParallelInfo parallelInfo = - parallel_args.mapping().Get(atb_speed::base::WORD_EMBED_TP); + parallel_args.mapping().Get(atb_speed::base::ATTN_TP); param.tensorParallelInfo.rank = parallelInfo.rank; param.tensorParallelInfo.worldSize = parallelInfo.rankIds.size(); param.tensorParallelInfo.backend = FLAGS_communication_backend; diff --git a/xllm/core/layers/npu/npu_lm_head_impl.cpp b/xllm/core/layers/npu/npu_lm_head_impl.cpp index 61c540e78..2bfd6d4fd 100644 --- a/xllm/core/layers/npu/npu_lm_head_impl.cpp +++ b/xllm/core/layers/npu/npu_lm_head_impl.cpp @@ -27,6 +27,7 @@ void NpuLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param, const ModelArgs& args, const ParallelArgs& parallel_args, bool isPrefill) { + param.outputHidden = cp_size_ > 1; param.unpadInputs = true; param.gatherAhead = isPrefill; param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads(); @@ -100,7 +101,7 @@ NpuLmHeadImpl::NpuLmHeadImpl(const ModelContext& context) : BaseLayer(context) { false); atb_weight_tensors_.resize(1); - atOutTensors_.resize(1); + atOutTensors_.resize(2); auto options = context.get_tensor_options(); dtype_ = c10::typeMetaToScalarType(options.dtype()); @@ -145,14 +146,14 @@ int64_t NpuLmHeadImpl::init_node(atb_speed::Model::Node& node, return -1; } node.inTensors.resize(node.operation->GetInputNum()); - node.outTensors.resize(1); + node.outTensors.resize(node.operation->GetOutputNum()); node.inTensors.at(1) = &atb_weight_tensors_[0]; node.variantPack.inTensors.reserve(node.inTensors.size()); node.variantPack.inTensors.resize(node.inTensors.size()); - node.variantPack.outTensors.reserve(1); - node.variantPack.outTensors.resize(1); + node.variantPack.outTensors.reserve(node.outTensors.size()); + node.variantPack.outTensors.resize(node.outTensors.size()); return atb::NO_ERROR; } @@ -160,6 +161,15 @@ int64_t NpuLmHeadImpl::init_node(atb_speed::Model::Node& node, torch::Tensor NpuLmHeadImpl::forward(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes, int nodeId) { + torch::Tensor out_hidden; + return forward_with_hidden(hidden_states, seleted_idxes, out_hidden, nodeId); +} + +torch::Tensor NpuLmHeadImpl::forward_with_hidden( + const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden, + int nodeId) { atb::Status st; build_node_variant_pack(lm_head_node_prefill_, hidden_states, seleted_idxes); st = execute_node(lm_head_node_prefill_, nodeId); @@ -169,6 +179,9 @@ torch::Tensor NpuLmHeadImpl::forward(const torch::Tensor& hidden_states, if (padded_vocab_size_ > vocab_size_ && vocab_size_ > 0) { output = output.slice(/*dim=*/-1, /*start=*/0, /*end=*/vocab_size_); } + if (atOutTensors_.size() > 1) { + out_hidden = atOutTensors_[1]; + } return output; } @@ -215,12 +228,16 @@ void NpuLmHeadImpl::build_node_variant_pack( inTensorDescs.at(8) = placeholder_.desc; atb::Status st = node.operation->InferShape(inTensorDescs, outTensorDescs); - at::Tensor newTensor = - atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(0)); - - atOutTensors_.at(0) = newTensor; - node.variantPack.outTensors.at(0) = - atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(0)); + LOG_IF(FATAL, st != atb::NO_ERROR) + << model_name_ << " infer lmhead shape fail, error code: " << st; + + atOutTensors_.resize(node.variantPack.outTensors.size()); + for (size_t i = 0; i < node.variantPack.outTensors.size(); ++i) { + atOutTensors_.at(i) = + atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(i)); + node.variantPack.outTensors.at(i) = + atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(i)); + } } } // namespace layer diff --git a/xllm/core/layers/npu/npu_lm_head_impl.h b/xllm/core/layers/npu/npu_lm_head_impl.h index bd3246f67..711f2d7f3 100644 --- a/xllm/core/layers/npu/npu_lm_head_impl.h +++ b/xllm/core/layers/npu/npu_lm_head_impl.h @@ -53,6 +53,11 @@ class NpuLmHeadImpl : public BaseLayer { const torch::Tensor& seleted_idxes, int nodeId); + torch::Tensor forward_with_hidden(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden, + int nodeId); + private: void param_from_args(atb_speed::common::LmHeadParam& param, const ModelArgs& args, diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index 07515e14d..62d88cba1 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -237,6 +237,9 @@ struct RawForwardInput { // chunked prefill case of speculative decoding // extra token ids for each sequence, and -1 for last chunk std::vector extra_token_ids; + // precomputed shifted token ids for mtp prefill, aligned with + // flatten_tokens_vec at token level. + std::vector mtp_shifted_token_ids; // embedding ids of each sequence std::vector embedding_ids; // request ids of each sequence @@ -382,6 +385,10 @@ struct RawForwardInput { outputs.flatten_positions_vec = gather_token_level_vector_i32(flatten_positions_vec); } + if (!mtp_shifted_token_ids.empty()) { + outputs.mtp_shifted_token_ids = + gather_token_level_vector_i32(mtp_shifted_token_ids); + } auto build_seq_lens = [&](const std::vector& original, const std::vector& lengths) { diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index 8be77cff4..4ddb01bfe 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -281,6 +281,7 @@ size_t calculate_raw_forward_input_size(const RawForwardInput& input) { total += get_vector_size(input.embedding_ids); total += get_string_vector_size(input.request_ids); total += get_vector_size(input.extra_token_ids); + total += get_vector_to_tensor_size(input.mtp_shifted_token_ids); total += type_size + input.swap_blocks.size() * swap_block_info_fixed_size(); total += get_vector_to_tensor_size(input.src_block_indices); @@ -1112,6 +1113,7 @@ inline void deserialize_raw_forward_input(const char*& buffer, read_vector(buffer, input_params.embedding_ids, device_buffer); read_string_vector(buffer, input_params.request_ids, device_buffer); read_vector(buffer, input_params.extra_token_ids, device_buffer); + read_tensor(buffer, input_params.mtp_shifted_token_ids, device_buffer); read_swap_blocks(buffer, input_params.swap_blocks, device_buffer); read_tensor(buffer, input_params.src_block_indices, device_buffer); read_tensor(buffer, input_params.dst_block_indices, device_buffer); @@ -1201,6 +1203,7 @@ inline void serialize_raw_forward_input(const RawForwardInput& input, write_vector(buffer, input.embedding_ids); write_string_vector(buffer, input.request_ids); write_vector(buffer, input.extra_token_ids); + write_vector_to_tensor(buffer, input.mtp_shifted_token_ids); write_swap_blocks(buffer, input.swap_blocks); write_vector_to_tensor(buffer, input.src_block_indices); write_vector_to_tensor(buffer, input.dst_block_indices); @@ -1426,6 +1429,8 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input, input_params.swap_blocks = std::move(raw_input.swap_blocks); input_params.batch_id = std::move(raw_input.batch_id); input_params.extra_token_ids = std::move(raw_input.extra_token_ids); + input_params.mtp_shifted_token_ids = + torch::tensor(std::move(raw_input.mtp_shifted_token_ids), tensor_options); input_params.new_cache_slot_offsets = torch::tensor( std::move(raw_input.new_cache_slot_offsets), tensor_options); @@ -1722,4 +1727,4 @@ void ForwardSharedMemoryManager::raw_output_read(RawForwardOutput& output) { void ForwardSharedMemoryManager::clear() { std::memset(base_address(), 0, size()); } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 8033067f4..0a9531366 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -124,9 +124,16 @@ std::optional LLMWorkerImpl::step_internal( } torch::Tensor logits; + torch::Tensor selected_hidden_from_lm_head; if (sampling_params.selected_token_idxes.defined()) { - logits = model_->logits(model_output.hidden_states, - sampling_params.selected_token_idxes); + if (options_.cp_size() > 1) { + logits = model_->logits(model_output.hidden_states, + sampling_params.selected_token_idxes, + selected_hidden_from_lm_head); + } else { + logits = model_->logits(model_output.hidden_states, + sampling_params.selected_token_idxes); + } } ForwardOutput output; @@ -198,8 +205,15 @@ std::optional LLMWorkerImpl::step_internal( if (!input.input_params.batch_forward_type.is_decode() && !is_spec_draft_) { output.sample_output.embeddings = embeddings; } else if (sampling_params.selected_token_idxes.defined()) { - output.sample_output.embeddings = embeddings.index_select( - /*dim=*/0, sampling_params.selected_token_idxes); + if (options_.cp_size() > 1) { + CHECK(selected_hidden_from_lm_head.defined()) + << "selected_hidden_from_lm_head must be defined when " + "selected_token_idxes is defined."; + output.sample_output.embeddings = selected_hidden_from_lm_head; + } else { + output.sample_output.embeddings = embeddings.index_select( + /*dim=*/0, sampling_params.selected_token_idxes); + } } } diff --git a/xllm/core/runtime/mtp_worker_impl.cpp b/xllm/core/runtime/mtp_worker_impl.cpp index 75f2e3af9..d42002d74 100644 --- a/xllm/core/runtime/mtp_worker_impl.cpp +++ b/xllm/core/runtime/mtp_worker_impl.cpp @@ -323,7 +323,6 @@ std::optional MTPWorkerImpl::step_prefill( auto mask = (token_ids == -1); token_ids.masked_scatter_(mask, next_tokens); } - // generate kv cache for draft model timer.reset(); auto draft_future = draft_impl_->step_async(prefill_input); @@ -351,6 +350,14 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input, ForwardInput& prefill_input) { prefill_input = input.to(device_, dtype_); auto& input_params = prefill_input.input_params; + if (options_.cp_size() > 1) { + CHECK(input_params.mtp_shifted_token_ids.defined()); + CHECK_EQ(input_params.mtp_shifted_token_ids.numel(), + prefill_input.token_ids.numel()); + prefill_input.token_ids = input_params.mtp_shifted_token_ids; + return; + } + auto& extra_token_ids = input_params.extra_token_ids; torch::Tensor token_ids = safe_to(input.token_ids, torch::kCPU); diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 6b67acb8e..6c7bfbd97 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -131,6 +131,9 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector extra_token_ids = std::vector(pb_forward_input->extra_token_ids().begin(), pb_forward_input->extra_token_ids().end()); + std::vector mtp_shifted_token_ids = + std::vector(pb_forward_input->mtp_shifted_token_ids().begin(), + pb_forward_input->mtp_shifted_token_ids().end()); std::vector request_ids = std::vector(pb_forward_input->request_ids().begin(), pb_forward_input->request_ids().end()); @@ -227,6 +230,8 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, input_params.embedding_ids = std::move(embedding_ids); input_params.request_ids = std::move(request_ids); input_params.extra_token_ids = std::move(extra_token_ids); + input_params.mtp_shifted_token_ids = + torch::tensor(mtp_shifted_token_ids, tensor_options); input_params.swap_blocks = std::move(swap_blocks); // block copy kernel @@ -505,6 +510,8 @@ void forward_input_to_proto(const RawForwardInput& inputs, inputs.request_ids); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_extra_token_ids(), inputs.extra_token_ids); + ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_mtp_shifted_token_ids(), + inputs.mtp_shifted_token_ids); pb_forward_input->mutable_swap_blocks()->Reserve(inputs.swap_blocks.size()); for (auto t : inputs.swap_blocks) { proto::BlockTransferInfo block_transfer_info; diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 8b9e48abd..b0f3c9603 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -557,13 +557,14 @@ void WorkerImpl::prepare_work_before_execute(const ForwardInput& input, CpPrefillInputs tmp_cp_inputs; if (parallel_args_.cp_size() > 1 && input.input_params.batch_forward_type.is_prefill()) { - tmp_cp_inputs = prepare_cp_prefill_inputs(parallel_args_.cp_size(), - input.token_ids, - input.positions, - input.input_params.q_seq_lens); + tmp_cp_inputs = + prepare_cp_prefill_inputs(parallel_args_.cp_size(), + input.token_ids.cpu(), + input.positions.cpu(), + input.input_params.q_seq_lens.cpu()); processed_input.input_params.cp_prefill_inputs = tmp_cp_inputs.to(device_); CpEpPadding cp_ep_padding( - input.token_ids, + input.token_ids.cpu(), context_.get_model_args().num_experts_per_tok(), context_.get_parallel_args().mapping_data(), /*device=*/device_, diff --git a/xllm/models/llm/npu/llm_model_base.h b/xllm/models/llm/npu/llm_model_base.h index 6cd8fd999..22189c394 100644 --- a/xllm/models/llm/npu/llm_model_base.h +++ b/xllm/models/llm/npu/llm_model_base.h @@ -420,6 +420,17 @@ class LlmForCausalLMImplBase : public torch::nn::Module { return npu_lm_head_(hidden_states, seleted_idxes, 0); } + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // out_hidden: [num_seqs, hidden_size] + // returns: [num_tokens, vocab_size] + virtual torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden) { + return npu_lm_head_->forward_with_hidden( + hidden_states, seleted_idxes, out_hidden, 0); + } + // hidden_states: [num_tokens, hidden_size] // seleted_idxes: [num_tokens] // returns: [num_seqs, hidden_size] diff --git a/xllm/models/llm/npu/mtp_model_base.h b/xllm/models/llm/npu/mtp_model_base.h index 07f9ffccc..d52fe4ba0 100644 --- a/xllm/models/llm/npu/mtp_model_base.h +++ b/xllm/models/llm/npu/mtp_model_base.h @@ -294,6 +294,17 @@ class MtpForCausalLMImplBase : public torch::nn::Module { return lm_head_(hidden_states, seleted_idxes, 0); } + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // out_hidden: [num_seqs, hidden_size] + // returns: [num_tokens, vocab_size] + virtual torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden) { + return lm_head_->forward_with_hidden( + hidden_states, seleted_idxes, out_hidden, 0); + } + // hidden_states: [num_tokens, hidden_size] // seleted_idxes: [num_tokens] // returns: [num_seqs, hidden_size] diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto index 92443e2f8..39075cd73 100644 --- a/xllm/proto/worker.proto +++ b/xllm/proto/worker.proto @@ -241,6 +241,7 @@ message ForwardInput { repeated int32 dp_is_decode = 42; repeated int32 kv_cache_tokens_nums = 43; repeated string request_ids = 44; + repeated int32 mtp_shifted_token_ids = 45; } message BatchedForwardInputs {