Skip to content

Commit 41d4bc3

Browse files
committed
bugfix: CP+MTP gets hidden states from lm_head
1 parent 4343993 commit 41d4bc3

File tree

11 files changed

+105
-46
lines changed

11 files changed

+105
-46
lines changed

xllm/core/framework/model/causal_lm.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,17 @@ class CausalLM : public torch::nn::Module {
7474
virtual torch::Tensor logits(const torch::Tensor& hidden_states,
7575
const torch::Tensor& seleted_idxes) = 0;
7676

77+
// hidden_states: [num_tokens, hidden_size]
78+
// seleted_idxes: [num_tokens]
79+
// out_hidden: [num_selected_tokens, hidden_size]
80+
// returns: [num_selected_tokens, vocab_size]
81+
virtual torch::Tensor logits(const torch::Tensor& hidden_states,
82+
const torch::Tensor& seleted_idxes,
83+
torch::Tensor& out_hidden) {
84+
NOT_IMPLEMENTED();
85+
return torch::Tensor();
86+
}
87+
7788
virtual void load_model(std::unique_ptr<ModelLoader> loader) = 0;
7889

7990
virtual torch::Device device() const = 0;
@@ -158,6 +169,16 @@ class CausalLMImpl : public CausalLM {
158169
return model_->logits(hidden_states, seleted_idxes);
159170
}
160171

172+
torch::Tensor logits(const torch::Tensor& hidden_states,
173+
const torch::Tensor& seleted_idxes,
174+
torch::Tensor& out_hidden) override {
175+
if constexpr (detail::has_logits_with_hidden<Model>::value) {
176+
return model_->logits(hidden_states, seleted_idxes, out_hidden);
177+
} else {
178+
return CausalLM::logits(hidden_states, seleted_idxes, out_hidden);
179+
}
180+
}
181+
161182
void load_model(std::unique_ptr<ModelLoader> loader) override {
162183
model_->load_model(std::move(loader));
163184
}

xllm/core/framework/model/model_traits.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -67,6 +67,17 @@ struct has_set_word_embedding<
6767
std::void_t<decltype(std::declval<T>()->set_word_embedding(
6868
std::declval<layer::WordEmbedding&>()))>> : std::true_type {};
6969

70+
template <typename T, typename = void>
71+
struct has_logits_with_hidden : std::false_type {};
72+
73+
template <typename T>
74+
struct has_logits_with_hidden<
75+
T,
76+
std::void_t<decltype(std::declval<T>()->logits(
77+
std::declval<const torch::Tensor&>(),
78+
std::declval<const torch::Tensor&>(),
79+
std::declval<torch::Tensor&>()))>> : std::true_type {};
80+
7081
template <typename T, typename = void>
7182
struct has_lazy_load_model : std::false_type {};
7283

xllm/core/framework/parallel_state/collective_communicator.cpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -131,25 +131,6 @@ void CollectiveCommunicator::create_process_groups(
131131
net::parse_host_port_from_addr(master_addr, host, port);
132132

133133
#if defined(USE_NPU)
134-
// ATB backend usually relies on ATB/HCCL communication. However CP+MTP
135-
// needs torch process group collectives in worker runtime.
136-
if (cp_size > 1) {
137-
CHECK_EQ(world_size % cp_size, 0)
138-
<< "world_size must be divisible by cp_size.";
139-
const int cp_group_count = world_size / cp_size;
140-
int port_offset = global_rank % cp_group_count + 1;
141-
cp_group_ = create_process_group(global_rank,
142-
world_size,
143-
cp_size,
144-
port + port_offset,
145-
true,
146-
host,
147-
"cp_group",
148-
device);
149-
parallel_args_->cp_group_ = cp_group_.get();
150-
port += cp_group_count;
151-
}
152-
153134
if (FLAGS_npu_kernel_backend == "ATB") {
154135
return;
155136
}

xllm/core/framework/parallel_state/collective_communicator.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,7 +48,6 @@ class CollectiveCommunicator {
4848
// aliases ParallelArgs::sp_group_ to tp_group_ instead of constructing a
4949
// separate communicator, so this stays empty for now.
5050
std::unique_ptr<ProcessGroup> sp_group_;
51-
std::unique_ptr<ProcessGroup> cp_group_;
5251
std::unique_ptr<ProcessGroup> moe_tp_group_;
5352
std::unique_ptr<ProcessGroup> moe_ep_group_;
5453
};

xllm/core/framework/parallel_state/parallel_args.h

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -128,8 +128,6 @@ struct ParallelArgs {
128128
// the same rank set during prefill, but it remains a separate handle so the
129129
// SP communication policy can evolve independently from TP.
130130
ProcessGroup* sp_group_ = nullptr;
131-
// Context-parallel communication group for CP-specific runtime collectives.
132-
ProcessGroup* cp_group_ = nullptr;
133131
ProcessGroup* moe_ep_group_ = nullptr;
134132
ProcessGroup* moe_tp_group_ = nullptr;
135133
};

xllm/core/layers/npu/npu_lm_head_impl.cpp

Lines changed: 27 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@ void NpuLmHeadImpl::param_from_args(atb_speed::common::LmHeadParam& param,
2828
const ParallelArgs& parallel_args,
2929
bool isPrefill) {
3030
const bool use_column_parallel = cp_size_ > 1;
31+
param.outputHidden = cp_size_ > 1;
3132
param.unpadInputs = true;
3233
param.gatherAhead = isPrefill;
3334
param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads();
@@ -101,7 +102,7 @@ NpuLmHeadImpl::NpuLmHeadImpl(const ModelContext& context) : BaseLayer(context) {
101102
false);
102103

103104
atb_weight_tensors_.resize(1);
104-
atOutTensors_.resize(1);
105+
atOutTensors_.resize(2);
105106

106107
auto options = context.get_tensor_options();
107108
dtype_ = c10::typeMetaToScalarType(options.dtype());
@@ -146,26 +147,38 @@ int64_t NpuLmHeadImpl::init_node(atb_speed::Model::Node& node,
146147
return -1;
147148
}
148149
node.inTensors.resize(node.operation->GetInputNum());
149-
node.outTensors.resize(1);
150+
node.outTensors.resize(node.operation->GetOutputNum());
150151

151152
node.inTensors.at(1) = &atb_weight_tensors_[0];
152153

153154
node.variantPack.inTensors.reserve(node.inTensors.size());
154155
node.variantPack.inTensors.resize(node.inTensors.size());
155-
node.variantPack.outTensors.reserve(1);
156-
node.variantPack.outTensors.resize(1);
156+
node.variantPack.outTensors.reserve(node.outTensors.size());
157+
node.variantPack.outTensors.resize(node.outTensors.size());
157158

158159
return atb::NO_ERROR;
159160
}
160161

161162
torch::Tensor NpuLmHeadImpl::forward(const torch::Tensor& hidden_states,
162163
const torch::Tensor& seleted_idxes,
163164
int nodeId) {
165+
torch::Tensor out_hidden;
166+
return forward_with_hidden(hidden_states, seleted_idxes, out_hidden, nodeId);
167+
}
168+
169+
torch::Tensor NpuLmHeadImpl::forward_with_hidden(
170+
const torch::Tensor& hidden_states,
171+
const torch::Tensor& seleted_idxes,
172+
torch::Tensor& out_hidden,
173+
int nodeId) {
164174
atb::Status st;
165175
build_node_variant_pack(lm_head_node_prefill_, hidden_states, seleted_idxes);
166176
st = execute_node(lm_head_node_prefill_, nodeId);
167177
LOG_IF(FATAL, st != 0) << model_name_
168178
<< "execute lmhead node fail, error code: " << st;
179+
if (atOutTensors_.size() > 1) {
180+
out_hidden = atOutTensors_[1];
181+
}
169182
return atOutTensors_[0];
170183
}
171184

@@ -212,12 +225,16 @@ void NpuLmHeadImpl::build_node_variant_pack(
212225
inTensorDescs.at(8) = placeholder_.desc;
213226

214227
atb::Status st = node.operation->InferShape(inTensorDescs, outTensorDescs);
215-
at::Tensor newTensor =
216-
atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(0));
217-
218-
atOutTensors_.at(0) = newTensor;
219-
node.variantPack.outTensors.at(0) =
220-
atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(0));
228+
LOG_IF(FATAL, st != atb::NO_ERROR)
229+
<< model_name_ << " infer lmhead shape fail, error code: " << st;
230+
231+
atOutTensors_.resize(node.variantPack.outTensors.size());
232+
for (size_t i = 0; i < node.variantPack.outTensors.size(); ++i) {
233+
atOutTensors_.at(i) =
234+
atb_speed::Utils::CreateAtTensorFromTensorDesc(outTensorDescs.at(i));
235+
node.variantPack.outTensors.at(i) =
236+
atb_speed::Utils::AtTensor2Tensor(atOutTensors_.at(i));
237+
}
221238
}
222239

223240
} // namespace layer

xllm/core/layers/npu/npu_lm_head_impl.h

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,11 @@ class NpuLmHeadImpl : public BaseLayer {
5353
const torch::Tensor& seleted_idxes,
5454
int nodeId);
5555

56+
torch::Tensor forward_with_hidden(const torch::Tensor& hidden_states,
57+
const torch::Tensor& seleted_idxes,
58+
torch::Tensor& out_hidden,
59+
int nodeId);
60+
5661
private:
5762
void param_from_args(atb_speed::common::LmHeadParam& param,
5863
const ModelArgs& args,

xllm/core/runtime/llm_worker_impl.cpp

Lines changed: 17 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,6 @@ limitations under the License.
3131
#include "core/common/global_flags.h"
3232
#include "framework/kv_cache/kv_cache.h"
3333
#include "framework/model/model_input_params.h"
34-
#include "framework/parallel_state/parallel_state.h"
3534
#include "framework/state_dict/state_dict.h"
3635
#if defined(USE_CUDA) || defined(USE_ILU) || defined(USE_MUSA)
3736
#include "layers/cuda/flashinfer_workspace.h"
@@ -125,9 +124,16 @@ std::optional<ForwardOutput> LLMWorkerImpl::step_internal(
125124
}
126125

127126
torch::Tensor logits;
127+
torch::Tensor selected_hidden_from_lm_head;
128128
if (sampling_params.selected_token_idxes.defined()) {
129-
logits = model_->logits(model_output.hidden_states,
130-
sampling_params.selected_token_idxes);
129+
if (options_.cp_size() > 1) {
130+
logits = model_->logits(model_output.hidden_states,
131+
sampling_params.selected_token_idxes,
132+
selected_hidden_from_lm_head);
133+
} else {
134+
logits = model_->logits(model_output.hidden_states,
135+
sampling_params.selected_token_idxes);
136+
}
131137
}
132138

133139
ForwardOutput output;
@@ -199,16 +205,15 @@ std::optional<ForwardOutput> LLMWorkerImpl::step_internal(
199205
if (!input.input_params.batch_forward_type.is_decode() && !is_spec_draft_) {
200206
output.sample_output.embeddings = embeddings;
201207
} else if (sampling_params.selected_token_idxes.defined()) {
202-
if (context_.get_parallel_args().cp_size() > 1) {
203-
auto* cp_group = context_.get_parallel_args().cp_group_;
204-
CHECK(cp_group != nullptr)
205-
<< "cp_group_ must be initialized for CP+MTP embeddings gather.";
206-
CHECK_EQ(cp_group->world_size(), context_.get_parallel_args().cp_size())
207-
<< "cp_group_ world size mismatch with cp_size.";
208-
embeddings = parallel_state::gather(embeddings, cp_group, /*dim=*/0);
209-
}
210-
output.sample_output.embeddings = embeddings.index_select(
208+
if (options_.cp_size() > 1) {
209+
CHECK(selected_hidden_from_lm_head.defined())
210+
<< "selected_hidden_from_lm_head must be defined when "
211+
"selected_token_idxes is defined.";
212+
output.sample_output.embeddings = selected_hidden_from_lm_head;
213+
} else {
214+
output.sample_output.embeddings = embeddings.index_select(
211215
/*dim=*/0, sampling_params.selected_token_idxes);
216+
}
212217
}
213218
}
214219

xllm/core/runtime/mtp_worker_impl.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -323,7 +323,6 @@ std::optional<ForwardOutput> MTPWorkerImpl::step_prefill(
323323
auto mask = (token_ids == -1);
324324
token_ids.masked_scatter_(mask, next_tokens);
325325
}
326-
327326
// generate kv cache for draft model
328327
timer.reset();
329328
auto draft_future = draft_impl_->step_async(prefill_input);
@@ -360,7 +359,8 @@ void MTPWorkerImpl::prepare_prefill_inputs(const ForwardInput& input,
360359
}
361360

362361
auto& extra_token_ids = input_params.extra_token_ids;
363-
CHECK_EQ(extra_token_ids.size(), static_cast<size_t>(input_params.num_sequences))
362+
CHECK_EQ(extra_token_ids.size(),
363+
static_cast<size_t>(input_params.num_sequences))
364364
<< "extra_token_ids size should match num_sequences";
365365

366366
torch::Tensor token_ids = safe_to(input.token_ids, torch::kCPU);

xllm/models/llm/npu/llm_model_base.h

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -420,6 +420,17 @@ class LlmForCausalLMImplBase : public torch::nn::Module {
420420
return npu_lm_head_(hidden_states, seleted_idxes, 0);
421421
}
422422

423+
// hidden_states: [num_tokens, hidden_size]
424+
// seleted_idxes: [num_tokens]
425+
// out_hidden: [num_seqs, hidden_size]
426+
// returns: [num_tokens, vocab_size]
427+
virtual torch::Tensor logits(const torch::Tensor& hidden_states,
428+
const torch::Tensor& seleted_idxes,
429+
torch::Tensor& out_hidden) {
430+
return npu_lm_head_->forward_with_hidden(
431+
hidden_states, seleted_idxes, out_hidden, 0);
432+
}
433+
423434
// hidden_states: [num_tokens, hidden_size]
424435
// seleted_idxes: [num_tokens]
425436
// returns: [num_seqs, hidden_size]

0 commit comments

Comments
 (0)