From c829ea18565999a666fdb9b97bdaf7d78cdbf304 Mon Sep 17 00:00:00 2001 From: shifengmin Date: Mon, 30 Mar 2026 22:15:05 +0800 Subject: [PATCH 1/6] bugfix: CP compatibility with MTP --- xllm/core/distributed_runtime/master.cpp | 1 + .../framework/batch/batch_input_builder.cpp | 35 +++++++++++++++++- .../framework/batch/batch_input_builder.h | 1 + .../core/framework/model/model_input_params.h | 3 ++ .../collective_communicator.cpp | 37 +++++++++++++++---- .../parallel_state/collective_communicator.h | 1 + .../framework/parallel_state/parallel_args.h | 2 + .../loader/column_parallel_linear_loader.cpp | 2 +- .../npu/npu_column_parallel_linear_impl.cpp | 11 ++++-- xllm/core/runtime/forward_params.h | 9 +++++ .../runtime/forward_shared_memory_manager.cpp | 7 +++- xllm/core/runtime/llm_worker_impl.cpp | 9 +++++ xllm/core/runtime/mtp_worker_impl.cpp | 13 +++++++ xllm/core/runtime/params_utils.cpp | 7 ++++ xllm/core/runtime/worker_impl.cpp | 8 ++-- xllm/proto/worker.proto | 1 + 16 files changed, 129 insertions(+), 18 deletions(-) diff --git a/xllm/core/distributed_runtime/master.cpp b/xllm/core/distributed_runtime/master.cpp index 55d8e3c9b..b9a35a1ff 100644 --- a/xllm/core/distributed_runtime/master.cpp +++ b/xllm/core/distributed_runtime/master.cpp @@ -262,6 +262,7 @@ Master::Master(const Options& options, EngineType type) .dp_size(options.dp_size()) .ep_size(options.ep_size()) .enable_prefill_sp(options_.enable_prefill_sp()) + .cp_size(options.cp_size()) .enable_chunked_prefill(options_.enable_chunked_prefill()) .max_seqs_per_batch(options_.max_seqs_per_batch()) .max_tokens_per_chunk_for_prefill( diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 8f50424c7..68a4878d2 100644 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -74,6 +74,7 @@ BatchInputBuilder::BatchInputBuilder( state_.mrope_positions_vec.reserve(sequences.size()); 
state_.block_tables_vec.reserve(sequences.size()); state_.acc_logprob_vec.reserve(sequences.size()); + state_.mtp_shifted_token_ids.reserve(1000); if (args_ != nullptr) { use_mrope_ = (args_->rope_scaling_rope_type() == "mrope"); } @@ -241,6 +242,9 @@ void BatchInputBuilder::process_sequences_multithreaded() { state_.extra_token_ids.insert(state_.extra_token_ids.end(), state.extra_token_ids.begin(), state.extra_token_ids.end()); + state_.mtp_shifted_token_ids.insert(state_.mtp_shifted_token_ids.end(), + state.mtp_shifted_token_ids.begin(), + state.mtp_shifted_token_ids.end()); state_.transfer_kv_infos.insert(state_.transfer_kv_infos.end(), state.transfer_kv_infos.begin(), state.transfer_kv_infos.end()); @@ -337,6 +341,7 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence, uint32_t padded_seq_len, BuilderState* state_ptr) { BuilderState& state = state_ptr ? *state_ptr : state_; + const size_t seq_token_begin = state.flatten_tokens_vec.size(); const auto& token_ids = sequence->tokens(); const uint32_t n_tokens = token_ids.size(); @@ -394,6 +399,7 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence, } // Add extra token id + int32_t extra_token_id = -1; if (n_tokens == seq_len) { // last chunk of prefill and decode // add -1 as extra token id @@ -401,7 +407,21 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence, state.embedding_ids.emplace_back(sequence->get_embedding_id()); state.request_ids.emplace_back(sequence->request_id()); } else { - state.extra_token_ids.emplace_back(token_ids[seq_len]); + extra_token_id = token_ids[seq_len]; + state.extra_token_ids.emplace_back(extra_token_id); + } + + if (state.batch_forward_type.is_prefill()) { + const size_t seq_token_end = state.flatten_tokens_vec.size(); + const size_t seq_token_len = seq_token_end - seq_token_begin; + CHECK_GT(seq_token_len, 0); + if (seq_token_len > 1) { + state.mtp_shifted_token_ids.insert( + state.mtp_shifted_token_ids.end(), + 
state.flatten_tokens_vec.begin() + seq_token_begin + 1, + state.flatten_tokens_vec.begin() + seq_token_end); + } + state.mtp_shifted_token_ids.emplace_back(extra_token_id); } } @@ -588,6 +608,12 @@ ForwardInput BatchInputBuilder::state_to_forward_input() { input_params.embedding_ids = std::move(state_.embedding_ids); input_params.request_ids = std::move(state_.request_ids); input_params.extra_token_ids = std::move(state_.extra_token_ids); + if (!state_.mtp_shifted_token_ids.empty()) { + CHECK_EQ(state_.mtp_shifted_token_ids.size(), state_.flatten_tokens_vec.size()) + << "mtp_shifted_token_ids size should match flatten token size"; + input_params.mtp_shifted_token_ids = + torch::tensor(state_.mtp_shifted_token_ids, torch::kInt); + } if (swap_block_transfer_infos_ != nullptr && swap_block_transfer_infos_->size() > 0) { @@ -668,6 +694,13 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() { raw_forward_input.embedding_ids = std::move(state_.embedding_ids); raw_forward_input.request_ids = std::move(state_.request_ids); raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids); + if (!state_.mtp_shifted_token_ids.empty()) { + CHECK_EQ(state_.mtp_shifted_token_ids.size(), + raw_forward_input.flatten_tokens_vec.size()) + << "mtp_shifted_token_ids size should match flatten token size"; + } + raw_forward_input.mtp_shifted_token_ids = + std::move(state_.mtp_shifted_token_ids); // beam search kernel input if (state_.acc_logprob_vec.size() > 0) { raw_forward_input.acc_logprob_vec = std::move(state_.acc_logprob_vec); diff --git a/xllm/core/framework/batch/batch_input_builder.h b/xllm/core/framework/batch/batch_input_builder.h index 86b3eb18a..476dcd98c 100644 --- a/xllm/core/framework/batch/batch_input_builder.h +++ b/xllm/core/framework/batch/batch_input_builder.h @@ -104,6 +104,7 @@ class BatchInputBuilder { std::vector embedding_ids; std::vector request_ids; std::vector extra_token_ids; + std::vector mtp_shifted_token_ids; std::vector 
transfer_kv_infos; // for continuous kvcache diff --git a/xllm/core/framework/model/model_input_params.h b/xllm/core/framework/model/model_input_params.h index 85c945df6..55d8b6354 100644 --- a/xllm/core/framework/model/model_input_params.h +++ b/xllm/core/framework/model/model_input_params.h @@ -351,6 +351,7 @@ struct ModelInputParams { params.embedding_ids = std::move(embedding_ids); params.request_ids = std::move(request_ids); params.extra_token_ids = std::move(extra_token_ids); + params.mtp_shifted_token_ids = safe_to(mtp_shifted_token_ids, device, true); params.dp_ep_padding_data = dp_ep_padding_data; params.cp_ep_padding_data .attn_padding_idx( @@ -535,6 +536,8 @@ struct ModelInputParams { // chunked prefill case of speculative decoding // extra token ids for each sequence, and -1 for last chunk std::vector extra_token_ids; + // precomputed shifted token ids for mtp prefill. + torch::Tensor mtp_shifted_token_ids; // swap std::vector swap_blocks; diff --git a/xllm/core/framework/parallel_state/collective_communicator.cpp b/xllm/core/framework/parallel_state/collective_communicator.cpp index ac9faf6b1..a381d2382 100644 --- a/xllm/core/framework/parallel_state/collective_communicator.cpp +++ b/xllm/core/framework/parallel_state/collective_communicator.cpp @@ -119,19 +119,42 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank, void CollectiveCommunicator::create_process_groups( const std::string& master_addr, const torch::Device& device) { + + int global_rank = parallel_args_->rank(); + int world_size = parallel_args_->world_size(); + int dp_size = parallel_args_->dp_size(); + int ep_size = parallel_args_->ep_size(); + int cp_size = parallel_args_->cp_size(); + + std::string host; + int port; + net::parse_host_port_from_addr(master_addr, host, port); + #if defined(USE_NPU) + // ATB backend usually relies on ATB/HCCL communication. However CP+MTP + // needs torch process group collectives in worker runtime. 
+ if (cp_size > 1) { + CHECK_EQ(world_size % cp_size, 0) + << "world_size must be divisible by cp_size."; + const int cp_group_count = world_size / cp_size; + int port_offset = global_rank % cp_group_count + 1; + cp_group_ = create_process_group(global_rank, + world_size, + cp_size, + port + port_offset, + true, + host, + "cp_group", + device); + parallel_args_->cp_group_ = cp_group_.get(); + port += cp_group_count; + } + if (FLAGS_npu_kernel_backend == "ATB") { return; } #endif - std::string host; - int port; - net::parse_host_port_from_addr(master_addr, host, port); - int global_rank = parallel_args_->rank(); - int world_size = parallel_args_->world_size(); - int dp_size = parallel_args_->dp_size(); - int ep_size = parallel_args_->ep_size(); process_group_ = create_process_group(global_rank, world_size, world_size, diff --git a/xllm/core/framework/parallel_state/collective_communicator.h b/xllm/core/framework/parallel_state/collective_communicator.h index 72d2cde13..49fb466bc 100644 --- a/xllm/core/framework/parallel_state/collective_communicator.h +++ b/xllm/core/framework/parallel_state/collective_communicator.h @@ -48,6 +48,7 @@ class CollectiveCommunicator { // aliases ParallelArgs::sp_group_ to tp_group_ instead of constructing a // separate communicator, so this stays empty for now. std::unique_ptr sp_group_; + std::unique_ptr cp_group_; std::unique_ptr moe_tp_group_; std::unique_ptr moe_ep_group_; }; diff --git a/xllm/core/framework/parallel_state/parallel_args.h b/xllm/core/framework/parallel_state/parallel_args.h index 1e3cefaa7..3c323c2aa 100644 --- a/xllm/core/framework/parallel_state/parallel_args.h +++ b/xllm/core/framework/parallel_state/parallel_args.h @@ -128,6 +128,8 @@ struct ParallelArgs { // the same rank set during prefill, but it remains a separate handle so the // SP communication policy can evolve independently from TP. ProcessGroup* sp_group_ = nullptr; + // Context-parallel communication group for CP-specific runtime collectives. 
+ ProcessGroup* cp_group_ = nullptr; ProcessGroup* moe_ep_group_ = nullptr; ProcessGroup* moe_tp_group_ = nullptr; }; diff --git a/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp b/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp index 35836e916..7c9b16c2b 100644 --- a/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp +++ b/xllm/core/layers/npu/loader/column_parallel_linear_loader.cpp @@ -28,7 +28,7 @@ ColumParallelLinearLoader::ColumParallelLinearLoader( } void ColumParallelLinearLoader::load_state_dict(const StateDict& state_dict) { - if (dp_size_ > 1) { + if (dp_size_ > 1 || cp_size_ > 1) { set_weight( state_dict, "weight", 0, 0, dp_local_tp_rank_, dp_local_tp_size_); } else { diff --git a/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp b/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp index 51292671b..7b7537c83 100644 --- a/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp +++ b/xllm/core/layers/npu/npu_column_parallel_linear_impl.cpp @@ -28,7 +28,8 @@ void NpuColumnParallelLinearImpl::param_from_args( if (parallel_args.world_size() > 1) { if (parallel_args.mapping_data().empty()) { - if (dp_size_ > 1) { + const bool use_local_tp = (dp_size_ > 1) || (cp_size_ > 1); + if (use_local_tp) { param.tensorParallelInfo.rank = dp_local_tp_rank_; param.tensorParallelInfo.worldSize = dp_local_tp_size_; } else { @@ -36,12 +37,14 @@ void NpuColumnParallelLinearImpl::param_from_args( param.tensorParallelInfo.worldSize = parallel_args.world_size(); } param.parallelType = atb_speed::common::COLUMN_PARALLEL; - param.tensorParallelInfo.commDomain = std::to_string(dp_rank_); - param.tensorParallelInfo.backend = "lccl"; + const int32_t tp_group_id = + use_local_tp ? 
(parallel_args.rank() / dp_local_tp_size_) : 0; + param.tensorParallelInfo.commDomain = std::to_string(tp_group_id); + param.tensorParallelInfo.backend = FLAGS_communication_backend; } else { param.parallelType = atb_speed::common::COLUMN_PARALLEL; atb_speed::common::ParallelInfo parallelInfo = - parallel_args.mapping().Get(atb_speed::base::WORD_EMBED_TP); + parallel_args.mapping().Get(atb_speed::base::ATTN_TP); param.tensorParallelInfo.rank = parallelInfo.rank; param.tensorParallelInfo.worldSize = parallelInfo.rankIds.size(); param.tensorParallelInfo.backend = FLAGS_communication_backend; diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index 07515e14d..070737e22 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -237,6 +237,9 @@ struct RawForwardInput { // chunked prefill case of speculative decoding // extra token ids for each sequence, and -1 for last chunk std::vector extra_token_ids; + // precomputed shifted token ids for mtp prefill, aligned with + // flatten_tokens_vec at token level. 
+ std::vector mtp_shifted_token_ids; // embedding ids of each sequence std::vector embedding_ids; // request ids of each sequence @@ -382,6 +385,12 @@ struct RawForwardInput { outputs.flatten_positions_vec = gather_token_level_vector_i32(flatten_positions_vec); } + if (!mtp_shifted_token_ids.empty()) { + CHECK_EQ(mtp_shifted_token_ids.size(), static_cast(token_num)) + << "mtp_shifted_token_ids size should match flatten_tokens_vec size"; + outputs.mtp_shifted_token_ids = + gather_token_level_vector_i32(mtp_shifted_token_ids); + } auto build_seq_lens = [&](const std::vector& original, const std::vector& lengths) { diff --git a/xllm/core/runtime/forward_shared_memory_manager.cpp b/xllm/core/runtime/forward_shared_memory_manager.cpp index 8be77cff4..4ddb01bfe 100644 --- a/xllm/core/runtime/forward_shared_memory_manager.cpp +++ b/xllm/core/runtime/forward_shared_memory_manager.cpp @@ -281,6 +281,7 @@ size_t calculate_raw_forward_input_size(const RawForwardInput& input) { total += get_vector_size(input.embedding_ids); total += get_string_vector_size(input.request_ids); total += get_vector_size(input.extra_token_ids); + total += get_vector_to_tensor_size(input.mtp_shifted_token_ids); total += type_size + input.swap_blocks.size() * swap_block_info_fixed_size(); total += get_vector_to_tensor_size(input.src_block_indices); @@ -1112,6 +1113,7 @@ inline void deserialize_raw_forward_input(const char*& buffer, read_vector(buffer, input_params.embedding_ids, device_buffer); read_string_vector(buffer, input_params.request_ids, device_buffer); read_vector(buffer, input_params.extra_token_ids, device_buffer); + read_tensor(buffer, input_params.mtp_shifted_token_ids, device_buffer); read_swap_blocks(buffer, input_params.swap_blocks, device_buffer); read_tensor(buffer, input_params.src_block_indices, device_buffer); read_tensor(buffer, input_params.dst_block_indices, device_buffer); @@ -1201,6 +1203,7 @@ inline void serialize_raw_forward_input(const RawForwardInput& input, 
write_vector(buffer, input.embedding_ids); write_string_vector(buffer, input.request_ids); write_vector(buffer, input.extra_token_ids); + write_vector_to_tensor(buffer, input.mtp_shifted_token_ids); write_swap_blocks(buffer, input.swap_blocks); write_vector_to_tensor(buffer, input.src_block_indices); write_vector_to_tensor(buffer, input.dst_block_indices); @@ -1426,6 +1429,8 @@ void convert_raw_forward_input_to_forward_input(RawForwardInput& raw_input, input_params.swap_blocks = std::move(raw_input.swap_blocks); input_params.batch_id = std::move(raw_input.batch_id); input_params.extra_token_ids = std::move(raw_input.extra_token_ids); + input_params.mtp_shifted_token_ids = + torch::tensor(std::move(raw_input.mtp_shifted_token_ids), tensor_options); input_params.new_cache_slot_offsets = torch::tensor( std::move(raw_input.new_cache_slot_offsets), tensor_options); @@ -1722,4 +1727,4 @@ void ForwardSharedMemoryManager::raw_output_read(RawForwardOutput& output) { void ForwardSharedMemoryManager::clear() { std::memset(base_address(), 0, size()); } -} // namespace xllm \ No newline at end of file +} // namespace xllm diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 8033067f4..899729fd5 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -31,6 +31,7 @@ limitations under the License. 
#include "core/common/global_flags.h" #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" +#include "framework/parallel_state/parallel_state.h" #include "framework/state_dict/state_dict.h" #if defined(USE_CUDA) || defined(USE_ILU) || defined(USE_MUSA) #include "layers/cuda/flashinfer_workspace.h" @@ -198,6 +199,14 @@ std::optional LLMWorkerImpl::step_internal( if (!input.input_params.batch_forward_type.is_decode() && !is_spec_draft_) { output.sample_output.embeddings = embeddings; } else if (sampling_params.selected_token_idxes.defined()) { + if (context_.get_parallel_args().cp_size() > 1) { + auto* cp_group = context_.get_parallel_args().cp_group_; + CHECK(cp_group != nullptr) + << "cp_group_ must be initialized for CP+MTP embeddings gather."; + CHECK_EQ(cp_group->world_size(), context_.get_parallel_args().cp_size()) + << "cp_group_ world size mismatch with cp_size."; + embeddings = parallel_state::gather(embeddings, cp_group, /*dim=*/0); + } output.sample_output.embeddings = embeddings.index_select( /*dim=*/0, sampling_params.selected_token_idxes); } diff --git a/xllm/core/runtime/mtp_worker_impl.cpp b/xllm/core/runtime/mtp_worker_impl.cpp index 75f2e3af9..3aa574347 100644 --- a/xllm/core/runtime/mtp_worker_impl.cpp +++ b/xllm/core/runtime/mtp_worker_impl.cpp @@ -351,7 +351,18 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input, ForwardInput& prefill_input) { prefill_input = input.to(device_, dtype_); auto& input_params = prefill_input.input_params; + if (input_params.mtp_shifted_token_ids.defined() && + input_params.mtp_shifted_token_ids.numel() > 0) { + CHECK_EQ(input_params.mtp_shifted_token_ids.numel(), + prefill_input.token_ids.numel()) + << "mtp shifted tokens numel mismatch with token_ids"; + prefill_input.token_ids = input_params.mtp_shifted_token_ids.clone(); + return; + } + auto& extra_token_ids = input_params.extra_token_ids; + CHECK_EQ(extra_token_ids.size(), 
static_cast(input_params.num_sequences)) + << "extra_token_ids size should match num_sequences"; torch::Tensor token_ids = safe_to(input.token_ids, torch::kCPU); Slice tokens_ids_slice = { @@ -371,6 +382,8 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input, tokens_ids_slice_i.end()); new_token_ids.emplace_back(extra_token_ids[i]); } + CHECK_EQ(static_cast(start_idx), tokens_ids_slice.size()) + << "q_seq_lens sum should equal token_ids numel"; prefill_input.token_ids = torch::tensor(new_token_ids, prefill_input.positions.options()); } diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 6b67acb8e..947d864b4 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -131,6 +131,9 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector extra_token_ids = std::vector(pb_forward_input->extra_token_ids().begin(), pb_forward_input->extra_token_ids().end()); + std::vector mtp_shifted_token_ids = std::vector( + pb_forward_input->mtp_shifted_token_ids().begin(), + pb_forward_input->mtp_shifted_token_ids().end()); std::vector request_ids = std::vector(pb_forward_input->request_ids().begin(), pb_forward_input->request_ids().end()); @@ -227,6 +230,8 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, input_params.embedding_ids = std::move(embedding_ids); input_params.request_ids = std::move(request_ids); input_params.extra_token_ids = std::move(extra_token_ids); + input_params.mtp_shifted_token_ids = + torch::tensor(mtp_shifted_token_ids, tensor_options); input_params.swap_blocks = std::move(swap_blocks); // block copy kernel @@ -505,6 +510,8 @@ void forward_input_to_proto(const RawForwardInput& inputs, inputs.request_ids); ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_extra_token_ids(), inputs.extra_token_ids); + ADD_VECTOR_TO_PROTO(pb_forward_input->mutable_mtp_shifted_token_ids(), + inputs.mtp_shifted_token_ids); 
pb_forward_input->mutable_swap_blocks()->Reserve(inputs.swap_blocks.size()); for (auto t : inputs.swap_blocks) { proto::BlockTransferInfo block_transfer_info; diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 8b9e48abd..5f4d2a6a8 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -558,12 +558,12 @@ void WorkerImpl::prepare_work_before_execute(const ForwardInput& input, if (parallel_args_.cp_size() > 1 && input.input_params.batch_forward_type.is_prefill()) { tmp_cp_inputs = prepare_cp_prefill_inputs(parallel_args_.cp_size(), - input.token_ids, - input.positions, - input.input_params.q_seq_lens); + input.token_ids.cpu(), + input.positions.cpu(), + input.input_params.q_seq_lens.cpu()); processed_input.input_params.cp_prefill_inputs = tmp_cp_inputs.to(device_); CpEpPadding cp_ep_padding( - input.token_ids, + input.token_ids.cpu(), context_.get_model_args().num_experts_per_tok(), context_.get_parallel_args().mapping_data(), /*device=*/device_, diff --git a/xllm/proto/worker.proto b/xllm/proto/worker.proto index 92443e2f8..39075cd73 100644 --- a/xllm/proto/worker.proto +++ b/xllm/proto/worker.proto @@ -241,6 +241,7 @@ message ForwardInput { repeated int32 dp_is_decode = 42; repeated int32 kv_cache_tokens_nums = 43; repeated string request_ids = 44; + repeated int32 mtp_shifted_token_ids = 45; } message BatchedForwardInputs { From 79c16245f9f33bfe57dd7890a22c25272e7f5de3 Mon Sep 17 00:00:00 2001 From: shifengmin Date: Wed, 1 Apr 2026 14:44:13 +0800 Subject: [PATCH 2/6] bugfix: fix mtp token shift with CP padding --- .../framework/batch/batch_input_builder.cpp | 24 +++++++++---------- xllm/core/runtime/forward_params.h | 2 -- xllm/core/runtime/mtp_worker_impl.cpp | 7 +++--- 3 files changed, 14 insertions(+), 19 deletions(-) diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 68a4878d2..8c31fd240 100644 --- 
a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -411,17 +411,22 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence, state.extra_token_ids.emplace_back(extra_token_id); } - if (state.batch_forward_type.is_prefill()) { - const size_t seq_token_end = state.flatten_tokens_vec.size(); - const size_t seq_token_len = seq_token_end - seq_token_begin; - CHECK_GT(seq_token_len, 0); - if (seq_token_len > 1) { + if (cp_size_ > 1 && state.batch_forward_type.is_prefill()) { + const uint32_t q_len = seq_len - n_kv_cache_tokens; + if (q_len > 1) { state.mtp_shifted_token_ids.insert( state.mtp_shifted_token_ids.end(), state.flatten_tokens_vec.begin() + seq_token_begin + 1, - state.flatten_tokens_vec.begin() + seq_token_end); + state.flatten_tokens_vec.begin() + seq_token_begin + q_len); } state.mtp_shifted_token_ids.emplace_back(extra_token_id); + if (padded_seq_len > seq_len) { + const int32_t pad_token_id = args_ ? args_->pad_token_id() : 0; + state.mtp_shifted_token_ids.insert( + state.mtp_shifted_token_ids.end(), + padded_seq_len - seq_len, + pad_token_id); + } } } @@ -609,8 +614,6 @@ ForwardInput BatchInputBuilder::state_to_forward_input() { input_params.request_ids = std::move(state_.request_ids); input_params.extra_token_ids = std::move(state_.extra_token_ids); if (!state_.mtp_shifted_token_ids.empty()) { - CHECK_EQ(state_.mtp_shifted_token_ids.size(), state_.flatten_tokens_vec.size()) - << "mtp_shifted_token_ids size should match flatten token size"; input_params.mtp_shifted_token_ids = torch::tensor(state_.mtp_shifted_token_ids, torch::kInt); } @@ -694,11 +697,6 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() { raw_forward_input.embedding_ids = std::move(state_.embedding_ids); raw_forward_input.request_ids = std::move(state_.request_ids); raw_forward_input.extra_token_ids = std::move(state_.extra_token_ids); - if (!state_.mtp_shifted_token_ids.empty()) { - 
CHECK_EQ(state_.mtp_shifted_token_ids.size(), - raw_forward_input.flatten_tokens_vec.size()) - << "mtp_shifted_token_ids size should match flatten token size"; - } raw_forward_input.mtp_shifted_token_ids = std::move(state_.mtp_shifted_token_ids); // beam search kernel input diff --git a/xllm/core/runtime/forward_params.h b/xllm/core/runtime/forward_params.h index 070737e22..62d88cba1 100644 --- a/xllm/core/runtime/forward_params.h +++ b/xllm/core/runtime/forward_params.h @@ -386,8 +386,6 @@ struct RawForwardInput { gather_token_level_vector_i32(flatten_positions_vec); } if (!mtp_shifted_token_ids.empty()) { - CHECK_EQ(mtp_shifted_token_ids.size(), static_cast(token_num)) - << "mtp_shifted_token_ids size should match flatten_tokens_vec size"; outputs.mtp_shifted_token_ids = gather_token_level_vector_i32(mtp_shifted_token_ids); } diff --git a/xllm/core/runtime/mtp_worker_impl.cpp b/xllm/core/runtime/mtp_worker_impl.cpp index 3aa574347..a7c477bfb 100644 --- a/xllm/core/runtime/mtp_worker_impl.cpp +++ b/xllm/core/runtime/mtp_worker_impl.cpp @@ -351,11 +351,10 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input, ForwardInput& prefill_input) { prefill_input = input.to(device_, dtype_); auto& input_params = prefill_input.input_params; - if (input_params.mtp_shifted_token_ids.defined() && - input_params.mtp_shifted_token_ids.numel() > 0) { + if (options_.cp_size() > 1) { + CHECK(input_params.mtp_shifted_token_ids.defined()); CHECK_EQ(input_params.mtp_shifted_token_ids.numel(), - prefill_input.token_ids.numel()) - << "mtp shifted tokens numel mismatch with token_ids"; + prefill_input.token_ids.numel()); prefill_input.token_ids = input_params.mtp_shifted_token_ids.clone(); return; } From a1b21d986b713ef3a4a57324cd5418d18016f0d6 Mon Sep 17 00:00:00 2001 From: shifengmin Date: Wed, 1 Apr 2026 20:03:18 +0800 Subject: [PATCH 3/6] optimize: remove unnecessary clone --- xllm/core/framework/batch/batch_input_builder.cpp | 7 ++++--- 
xllm/core/runtime/mtp_worker_impl.cpp | 2 +- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index 8c31fd240..de228bb33 100644 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -69,12 +69,13 @@ BatchInputBuilder::BatchInputBuilder( batch_id_(batch_id), cp_size_(std::max(1, cp_size)) { // Reserve space for better performance - state_.flatten_tokens_vec.reserve(1000); - state_.flatten_positions_vec.reserve(1000); + const size_t reserve_size = 1024; + state_.flatten_tokens_vec.reserve(reserve_size); + state_.flatten_positions_vec.reserve(reserve_size); state_.mrope_positions_vec.reserve(sequences.size()); state_.block_tables_vec.reserve(sequences.size()); state_.acc_logprob_vec.reserve(sequences.size()); - state_.mtp_shifted_token_ids.reserve(1000); + state_.mtp_shifted_token_ids.reserve(reserve_size); if (args_ != nullptr) { use_mrope_ = (args_->rope_scaling_rope_type() == "mrope"); } diff --git a/xllm/core/runtime/mtp_worker_impl.cpp b/xllm/core/runtime/mtp_worker_impl.cpp index a7c477bfb..fa2558052 100644 --- a/xllm/core/runtime/mtp_worker_impl.cpp +++ b/xllm/core/runtime/mtp_worker_impl.cpp @@ -355,7 +355,7 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input, CHECK(input_params.mtp_shifted_token_ids.defined()); CHECK_EQ(input_params.mtp_shifted_token_ids.numel(), prefill_input.token_ids.numel()); - prefill_input.token_ids = input_params.mtp_shifted_token_ids.clone(); + prefill_input.token_ids = input_params.mtp_shifted_token_ids; return; } From 329adb0c4d36c659a7968af3921cc05151547d5b Mon Sep 17 00:00:00 2001 From: shifengmin Date: Fri, 3 Apr 2026 13:33:47 +0800 Subject: [PATCH 4/6] bugfix: cp+mtp get hidden from lmhead --- xllm/core/framework/model/causal_lm.h | 21 +++++++++++ xllm/core/framework/model/model_traits.h | 11 ++++++ .../collective_communicator.cpp | 19 
---------- .../parallel_state/collective_communicator.h | 1 - .../framework/parallel_state/parallel_args.h | 2 - xllm/core/layers/npu/npu_lm_head_impl.cpp | 37 ++++++++++++++----- xllm/core/layers/npu/npu_lm_head_impl.h | 5 +++ xllm/core/runtime/llm_worker_impl.cpp | 29 +++++++++------ xllm/core/runtime/mtp_worker_impl.cpp | 4 +- xllm/models/llm/npu/llm_model_base.h | 11 ++++++ xllm/models/llm/npu/mtp_model_base.h | 11 ++++++ 11 files changed, 105 insertions(+), 46 deletions(-) diff --git a/xllm/core/framework/model/causal_lm.h b/xllm/core/framework/model/causal_lm.h index 3aa7a00ca..67aaebf11 100644 --- a/xllm/core/framework/model/causal_lm.h +++ b/xllm/core/framework/model/causal_lm.h @@ -74,6 +74,17 @@ class CausalLM : public torch::nn::Module { virtual torch::Tensor logits(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes) = 0; + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // out_hidden: [num_selected_tokens, hidden_size] + // returns: [num_selected_tokens, vocab_size] + virtual torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden) { + NOT_IMPLEMENTED(); + return torch::Tensor(); + } + virtual void load_model(std::unique_ptr loader) = 0; virtual torch::Device device() const = 0; @@ -158,6 +169,16 @@ class CausalLMImpl : public CausalLM { return model_->logits(hidden_states, seleted_idxes); } + torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden) override { + if constexpr (detail::has_logits_with_hidden::value) { + return model_->logits(hidden_states, seleted_idxes, out_hidden); + } else { + return CausalLM::logits(hidden_states, seleted_idxes, out_hidden); + } + } + void load_model(std::unique_ptr loader) override { model_->load_model(std::move(loader)); } diff --git a/xllm/core/framework/model/model_traits.h b/xllm/core/framework/model/model_traits.h index 
9e351819f..f0a2571ff 100644 --- a/xllm/core/framework/model/model_traits.h +++ b/xllm/core/framework/model/model_traits.h @@ -67,6 +67,17 @@ struct has_set_word_embedding< std::void_t()->set_word_embedding( std::declval()))>> : std::true_type {}; +template +struct has_logits_with_hidden : std::false_type {}; + +template +struct has_logits_with_hidden< + T, + std::void_t()->logits( + std::declval(), + std::declval(), + std::declval()))>> : std::true_type {}; + template struct has_lazy_load_model : std::false_type {}; diff --git a/xllm/core/framework/parallel_state/collective_communicator.cpp b/xllm/core/framework/parallel_state/collective_communicator.cpp index a381d2382..aa97aa7ce 100644 --- a/xllm/core/framework/parallel_state/collective_communicator.cpp +++ b/xllm/core/framework/parallel_state/collective_communicator.cpp @@ -131,25 +131,6 @@ void CollectiveCommunicator::create_process_groups( net::parse_host_port_from_addr(master_addr, host, port); #if defined(USE_NPU) - // ATB backend usually relies on ATB/HCCL communication. However CP+MTP - // needs torch process group collectives in worker runtime. 
- if (cp_size > 1) { - CHECK_EQ(world_size % cp_size, 0) - << "world_size must be divisible by cp_size."; - const int cp_group_count = world_size / cp_size; - int port_offset = global_rank % cp_group_count + 1; - cp_group_ = create_process_group(global_rank, - world_size, - cp_size, - port + port_offset, - true, - host, - "cp_group", - device); - parallel_args_->cp_group_ = cp_group_.get(); - port += cp_group_count; - } - if (FLAGS_npu_kernel_backend == "ATB") { return; } diff --git a/xllm/core/framework/parallel_state/collective_communicator.h b/xllm/core/framework/parallel_state/collective_communicator.h index 49fb466bc..72d2cde13 100644 --- a/xllm/core/framework/parallel_state/collective_communicator.h +++ b/xllm/core/framework/parallel_state/collective_communicator.h @@ -48,7 +48,6 @@ class CollectiveCommunicator { // aliases ParallelArgs::sp_group_ to tp_group_ instead of constructing a // separate communicator, so this stays empty for now. std::unique_ptr sp_group_; - std::unique_ptr cp_group_; std::unique_ptr moe_tp_group_; std::unique_ptr moe_ep_group_; }; diff --git a/xllm/core/framework/parallel_state/parallel_args.h b/xllm/core/framework/parallel_state/parallel_args.h index 3c323c2aa..1e3cefaa7 100644 --- a/xllm/core/framework/parallel_state/parallel_args.h +++ b/xllm/core/framework/parallel_state/parallel_args.h @@ -128,8 +128,6 @@ struct ParallelArgs { // the same rank set during prefill, but it remains a separate handle so the // SP communication policy can evolve independently from TP. ProcessGroup* sp_group_ = nullptr; - // Context-parallel communication group for CP-specific runtime collectives. 
- ProcessGroup* cp_group_ = nullptr; ProcessGroup* moe_ep_group_ = nullptr; ProcessGroup* moe_tp_group_ = nullptr; }; diff --git a/xllm/core/layers/npu/npu_lm_head_impl.cpp b/xllm/core/layers/npu/npu_lm_head_impl.cpp index 61c540e78..2bfd6d4fd 100644 --- a/xllm/core/layers/npu/npu_lm_head_impl.cpp +++ b/xllm/core/layers/npu/npu_lm_head_impl.cpp @@ -27,6 +27,7 @@ void NpuLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param, const ModelArgs& args, const ParallelArgs& parallel_args, bool isPrefill) { + param.outputHidden = cp_size_ > 1; param.unpadInputs = true; param.gatherAhead = isPrefill; param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads(); @@ -100,7 +101,7 @@ NpuLmHeadImpl::NpuLmHeadImpl(const ModelContext& context) : BaseLayer(context) { false); atb_weight_tensors_.resize(1); - atOutTensors_.resize(1); + atOutTensors_.resize(2); auto options = context.get_tensor_options(); dtype_ = c10::typeMetaToScalarType(options.dtype()); @@ -145,14 +146,14 @@ int64_t NpuLmHeadImpl::init_node(atb_speed::Model::Node& node, return -1; } node.inTensors.resize(node.operation->GetInputNum()); - node.outTensors.resize(1); + node.outTensors.resize(node.operation->GetOutputNum()); node.inTensors.at(1) = &atb_weight_tensors_[0]; node.variantPack.inTensors.reserve(node.inTensors.size()); node.variantPack.inTensors.resize(node.inTensors.size()); - node.variantPack.outTensors.reserve(1); - node.variantPack.outTensors.resize(1); + node.variantPack.outTensors.reserve(node.outTensors.size()); + node.variantPack.outTensors.resize(node.outTensors.size()); return atb::NO_ERROR; } @@ -160,6 +161,15 @@ int64_t NpuLmHeadImpl::init_node(atb_speed::Model::Node& node, torch::Tensor NpuLmHeadImpl::forward(const torch::Tensor& hidden_states, const torch::Tensor& seleted_idxes, int nodeId) { + torch::Tensor out_hidden; + return forward_with_hidden(hidden_states, seleted_idxes, out_hidden, nodeId); +} + +torch::Tensor NpuLmHeadImpl::forward_with_hidden( + const 
torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden, + int nodeId) { atb::Status st; build_node_variant_pack(lm_head_node_prefill_, hidden_states, seleted_idxes); st = execute_node(lm_head_node_prefill_, nodeId); @@ -169,6 +179,9 @@ torch::Tensor NpuLmHeadImpl::forward(const torch::Tensor& hidden_states, if (padded_vocab_size_ > vocab_size_ && vocab_size_ > 0) { output = output.slice(/*dim=*/-1, /*start=*/0, /*end=*/vocab_size_); } + if (atOutTensors_.size() > 1) { + out_hidden = atOutTensors_[1]; + } return output; } @@ -215,12 +228,16 @@ void NpuLmHeadImpl::build_node_variant_pack( inTensorDescs.at(8) = placeholder_.desc; atb::Status st = node.operation->InferShape(inTensorDescs, outTensorDescs); - at::Tensor newTensor = - atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(0)); - - atOutTensors_.at(0) = newTensor; - node.variantPack.outTensors.at(0) = - atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(0)); + LOG_IF(FATAL, st != atb::NO_ERROR) + << model_name_ << " infer lmhead shape fail, error code: " << st; + + atOutTensors_.resize(node.variantPack.outTensors.size()); + for (size_t i = 0; i < node.variantPack.outTensors.size(); ++i) { + atOutTensors_.at(i) = + atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(i)); + node.variantPack.outTensors.at(i) = + atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(i)); + } } } // namespace layer diff --git a/xllm/core/layers/npu/npu_lm_head_impl.h b/xllm/core/layers/npu/npu_lm_head_impl.h index bd3246f67..711f2d7f3 100644 --- a/xllm/core/layers/npu/npu_lm_head_impl.h +++ b/xllm/core/layers/npu/npu_lm_head_impl.h @@ -53,6 +53,11 @@ class NpuLmHeadImpl : public BaseLayer { const torch::Tensor& seleted_idxes, int nodeId); + torch::Tensor forward_with_hidden(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden, + int nodeId); + private: void param_from_args(atb_speed::common::LmHeadParam& param, const 
ModelArgs& args, diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 899729fd5..01e3b31a3 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -31,7 +31,6 @@ limitations under the License. #include "core/common/global_flags.h" #include "framework/kv_cache/kv_cache.h" #include "framework/model/model_input_params.h" -#include "framework/parallel_state/parallel_state.h" #include "framework/state_dict/state_dict.h" #if defined(USE_CUDA) || defined(USE_ILU) || defined(USE_MUSA) #include "layers/cuda/flashinfer_workspace.h" @@ -125,9 +124,16 @@ std::optional LLMWorkerImpl::step_internal( } torch::Tensor logits; + torch::Tensor selected_hidden_from_lm_head; if (sampling_params.selected_token_idxes.defined()) { - logits = model_->logits(model_output.hidden_states, - sampling_params.selected_token_idxes); + if (options_.cp_size() > 1) { + logits = model_->logits(model_output.hidden_states, + sampling_params.selected_token_idxes, + selected_hidden_from_lm_head); + } else { + logits = model_->logits(model_output.hidden_states, + sampling_params.selected_token_idxes); + } } ForwardOutput output; @@ -199,16 +205,15 @@ std::optional LLMWorkerImpl::step_internal( if (!input.input_params.batch_forward_type.is_decode() && !is_spec_draft_) { output.sample_output.embeddings = embeddings; } else if (sampling_params.selected_token_idxes.defined()) { - if (context_.get_parallel_args().cp_size() > 1) { - auto* cp_group = context_.get_parallel_args().cp_group_; - CHECK(cp_group != nullptr) - << "cp_group_ must be initialized for CP+MTP embeddings gather."; - CHECK_EQ(cp_group->world_size(), context_.get_parallel_args().cp_size()) - << "cp_group_ world size mismatch with cp_size."; - embeddings = parallel_state::gather(embeddings, cp_group, /*dim=*/0); - } - output.sample_output.embeddings = embeddings.index_select( + if (options_.cp_size() > 1) { + CHECK(selected_hidden_from_lm_head.defined()) + << 
"selected_hidden_from_lm_head must be defined when " + "selected_token_idxes is defined."; + output.sample_output.embeddings = selected_hidden_from_lm_head; + } else { + output.sample_output.embeddings = embeddings.index_select( /*dim=*/0, sampling_params.selected_token_idxes); + } } } diff --git a/xllm/core/runtime/mtp_worker_impl.cpp b/xllm/core/runtime/mtp_worker_impl.cpp index fa2558052..6d88b3c45 100644 --- a/xllm/core/runtime/mtp_worker_impl.cpp +++ b/xllm/core/runtime/mtp_worker_impl.cpp @@ -323,7 +323,6 @@ std::optional MTPWorkerImpl::step_prefill( auto mask = (token_ids == -1); token_ids.masked_scatter_(mask, next_tokens); } - // generate kv cache for draft model timer.reset(); auto draft_future = draft_impl_->step_async(prefill_input); @@ -360,7 +359,8 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input, } auto& extra_token_ids = input_params.extra_token_ids; - CHECK_EQ(extra_token_ids.size(), static_cast(input_params.num_sequences)) + CHECK_EQ(extra_token_ids.size(), + static_cast(input_params.num_sequences)) << "extra_token_ids size should match num_sequences"; torch::Tensor token_ids = safe_to(input.token_ids, torch::kCPU); diff --git a/xllm/models/llm/npu/llm_model_base.h b/xllm/models/llm/npu/llm_model_base.h index 6cd8fd999..22189c394 100644 --- a/xllm/models/llm/npu/llm_model_base.h +++ b/xllm/models/llm/npu/llm_model_base.h @@ -420,6 +420,17 @@ class LlmForCausalLMImplBase : public torch::nn::Module { return npu_lm_head_(hidden_states, seleted_idxes, 0); } + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // out_hidden: [num_seqs, hidden_size] + // returns: [num_tokens, vocab_size] + virtual torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden) { + return npu_lm_head_->forward_with_hidden( + hidden_states, seleted_idxes, out_hidden, 0); + } + // hidden_states: [num_tokens, hidden_size] // seleted_idxes: [num_tokens] // returns: 
[num_seqs, hidden_size] diff --git a/xllm/models/llm/npu/mtp_model_base.h b/xllm/models/llm/npu/mtp_model_base.h index 07f9ffccc..d52fe4ba0 100644 --- a/xllm/models/llm/npu/mtp_model_base.h +++ b/xllm/models/llm/npu/mtp_model_base.h @@ -294,6 +294,17 @@ class MtpForCausalLMImplBase : public torch::nn::Module { return lm_head_(hidden_states, seleted_idxes, 0); } + // hidden_states: [num_tokens, hidden_size] + // seleted_idxes: [num_tokens] + // out_hidden: [num_seqs, hidden_size] + // returns: [num_tokens, vocab_size] + virtual torch::Tensor logits(const torch::Tensor& hidden_states, + const torch::Tensor& seleted_idxes, + torch::Tensor& out_hidden) { + return lm_head_->forward_with_hidden( + hidden_states, seleted_idxes, out_hidden, 0); + } + // hidden_states: [num_tokens, hidden_size] // seleted_idxes: [num_tokens] // returns: [num_seqs, hidden_size] From 05613219b8ed52af3ec6e6ee1bc17b281b3bfc9a Mon Sep 17 00:00:00 2001 From: shifengmin Date: Fri, 3 Apr 2026 14:01:13 +0800 Subject: [PATCH 5/6] apply clang format --- xllm/core/framework/batch/batch_input_builder.cpp | 7 +++---- xllm/core/framework/model/model_traits.h | 12 ++++++------ .../parallel_state/collective_communicator.cpp | 3 +-- xllm/core/runtime/llm_worker_impl.cpp | 4 ++-- xllm/core/runtime/mtp_worker_impl.cpp | 5 ----- xllm/core/runtime/params_utils.cpp | 6 +++--- xllm/core/runtime/worker_impl.cpp | 9 +++++---- 7 files changed, 20 insertions(+), 26 deletions(-) diff --git a/xllm/core/framework/batch/batch_input_builder.cpp b/xllm/core/framework/batch/batch_input_builder.cpp index de228bb33..65ace2266 100644 --- a/xllm/core/framework/batch/batch_input_builder.cpp +++ b/xllm/core/framework/batch/batch_input_builder.cpp @@ -423,10 +423,9 @@ void BatchInputBuilder::extract_tokens_and_positions(Sequence* sequence, state.mtp_shifted_token_ids.emplace_back(extra_token_id); if (padded_seq_len > seq_len) { const int32_t pad_token_id = args_ ? 
args_->pad_token_id() : 0; - state.mtp_shifted_token_ids.insert( - state.mtp_shifted_token_ids.end(), - padded_seq_len - seq_len, - pad_token_id); + state.mtp_shifted_token_ids.insert(state.mtp_shifted_token_ids.end(), + padded_seq_len - seq_len, + pad_token_id); } } } diff --git a/xllm/core/framework/model/model_traits.h b/xllm/core/framework/model/model_traits.h index f0a2571ff..d36b75da2 100644 --- a/xllm/core/framework/model/model_traits.h +++ b/xllm/core/framework/model/model_traits.h @@ -71,12 +71,12 @@ template struct has_logits_with_hidden : std::false_type {}; template -struct has_logits_with_hidden< - T, - std::void_t()->logits( - std::declval(), - std::declval(), - std::declval()))>> : std::true_type {}; +struct has_logits_with_hidden()->logits( + std::declval(), + std::declval(), + std::declval()))>> + : std::true_type {}; template struct has_lazy_load_model : std::false_type {}; diff --git a/xllm/core/framework/parallel_state/collective_communicator.cpp b/xllm/core/framework/parallel_state/collective_communicator.cpp index aa97aa7ce..0fa4c50fc 100644 --- a/xllm/core/framework/parallel_state/collective_communicator.cpp +++ b/xllm/core/framework/parallel_state/collective_communicator.cpp @@ -119,13 +119,12 @@ CollectiveCommunicator::CollectiveCommunicator(int global_rank, void CollectiveCommunicator::create_process_groups( const std::string& master_addr, const torch::Device& device) { - int global_rank = parallel_args_->rank(); int world_size = parallel_args_->world_size(); int dp_size = parallel_args_->dp_size(); int ep_size = parallel_args_->ep_size(); int cp_size = parallel_args_->cp_size(); - + std::string host; int port; net::parse_host_port_from_addr(master_addr, host, port); diff --git a/xllm/core/runtime/llm_worker_impl.cpp b/xllm/core/runtime/llm_worker_impl.cpp index 01e3b31a3..0a9531366 100644 --- a/xllm/core/runtime/llm_worker_impl.cpp +++ b/xllm/core/runtime/llm_worker_impl.cpp @@ -208,11 +208,11 @@ std::optional 
LLMWorkerImpl::step_internal( if (options_.cp_size() > 1) { CHECK(selected_hidden_from_lm_head.defined()) << "selected_hidden_from_lm_head must be defined when " - "selected_token_idxes is defined."; + "selected_token_idxes is defined."; output.sample_output.embeddings = selected_hidden_from_lm_head; } else { output.sample_output.embeddings = embeddings.index_select( - /*dim=*/0, sampling_params.selected_token_idxes); + /*dim=*/0, sampling_params.selected_token_idxes); } } } diff --git a/xllm/core/runtime/mtp_worker_impl.cpp b/xllm/core/runtime/mtp_worker_impl.cpp index 6d88b3c45..d42002d74 100644 --- a/xllm/core/runtime/mtp_worker_impl.cpp +++ b/xllm/core/runtime/mtp_worker_impl.cpp @@ -359,9 +359,6 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input, } auto& extra_token_ids = input_params.extra_token_ids; - CHECK_EQ(extra_token_ids.size(), - static_cast(input_params.num_sequences)) - << "extra_token_ids size should match num_sequences"; torch::Tensor token_ids = safe_to(input.token_ids, torch::kCPU); Slice tokens_ids_slice = { @@ -381,8 +378,6 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input, tokens_ids_slice_i.end()); new_token_ids.emplace_back(extra_token_ids[i]); } - CHECK_EQ(static_cast(start_idx), tokens_ids_slice.size()) - << "q_seq_lens sum should equal token_ids numel"; prefill_input.token_ids = torch::tensor(new_token_ids, prefill_input.positions.options()); } diff --git a/xllm/core/runtime/params_utils.cpp b/xllm/core/runtime/params_utils.cpp index 947d864b4..6c7bfbd97 100644 --- a/xllm/core/runtime/params_utils.cpp +++ b/xllm/core/runtime/params_utils.cpp @@ -131,9 +131,9 @@ void proto_to_forward_input(const proto::ForwardInput* pb_forward_input, std::vector extra_token_ids = std::vector(pb_forward_input->extra_token_ids().begin(), pb_forward_input->extra_token_ids().end()); - std::vector mtp_shifted_token_ids = std::vector( - pb_forward_input->mtp_shifted_token_ids().begin(), - 
pb_forward_input->mtp_shifted_token_ids().end()); + std::vector mtp_shifted_token_ids = + std::vector(pb_forward_input->mtp_shifted_token_ids().begin(), + pb_forward_input->mtp_shifted_token_ids().end()); std::vector request_ids = std::vector(pb_forward_input->request_ids().begin(), pb_forward_input->request_ids().end()); diff --git a/xllm/core/runtime/worker_impl.cpp b/xllm/core/runtime/worker_impl.cpp index 5f4d2a6a8..b0f3c9603 100644 --- a/xllm/core/runtime/worker_impl.cpp +++ b/xllm/core/runtime/worker_impl.cpp @@ -557,10 +557,11 @@ void WorkerImpl::prepare_work_before_execute(const ForwardInput& input, CpPrefillInputs tmp_cp_inputs; if (parallel_args_.cp_size() > 1 && input.input_params.batch_forward_type.is_prefill()) { - tmp_cp_inputs = prepare_cp_prefill_inputs(parallel_args_.cp_size(), - input.token_ids.cpu(), - input.positions.cpu(), - input.input_params.q_seq_lens.cpu()); + tmp_cp_inputs = + prepare_cp_prefill_inputs(parallel_args_.cp_size(), + input.token_ids.cpu(), + input.positions.cpu(), + input.input_params.q_seq_lens.cpu()); processed_input.input_params.cp_prefill_inputs = tmp_cp_inputs.to(device_); CpEpPadding cp_ep_padding( input.token_ids.cpu(), From eeaea88fd54169f5692b889800960f1cf3d07dbd Mon Sep 17 00:00:00 2001 From: shifengmin Date: Fri, 3 Apr 2026 18:28:24 +0800 Subject: [PATCH 6/6] update xllm_atb_layers commit id --- third_party/xllm_atb_layers | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/third_party/xllm_atb_layers b/third_party/xllm_atb_layers index 918c03d2a..29ab875c6 160000 --- a/third_party/xllm_atb_layers +++ b/third_party/xllm_atb_layers @@ -1 +1 @@ -Subproject commit 918c03d2abc4c9996196a797aefe743863b7e0ae +Subproject commit 29ab875c61f1e70f925d4472851d57c964667e55