Skip to content

Commit 912e8cb

Browse files
committed
feat: support executor_impl_factory, dp_is_decode and remove mrope interface.
1 parent 6a9ac5b commit 912e8cb

29 files changed

Lines changed: 238 additions & 267 deletions

xllm/core/common/global_flags.cpp

Lines changed: 0 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -84,21 +84,11 @@ DEFINE_bool(enable_customize_mla_kernel, false, "enable customize mla kernel");
8484

8585
// --- graph mode execution config ---
8686

87-
DEFINE_bool(enable_acl_graph,
88-
false,
89-
"Whether to enable ACL graph execution for decode phase.");
90-
9187
DEFINE_int32(max_seq_len_for_graph_mode,
9288
0,
9389
"Maximum number of tokens per sequence for graph execution. "
9490
"If 0, use model max_position_embeddings.");
9591

96-
DEFINE_bool(enable_acl_graph_no_padding,
97-
false,
98-
"Whether to enable ACL graph execution for decode phase without "
99-
"padding. If true, graph will be captured with every actual num "
100-
"tokens, as stride is 1.");
101-
10292
DEFINE_bool(enable_graph,
10393
false,
10494
"Whether to enable graph execution for decode phase.");

xllm/core/common/global_flags.h

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -83,12 +83,8 @@ DECLARE_string(rank_tablefile);
8383

8484
DECLARE_bool(enable_mla);
8585

86-
DECLARE_bool(enable_acl_graph);
87-
8886
DECLARE_int32(max_seq_len_for_graph_mode);
8987

90-
DECLARE_bool(enable_acl_graph_no_padding);
91-
9288
DECLARE_bool(enable_graph);
9389

9490
DECLARE_bool(enable_graph_no_padding);

xllm/core/distributed_runtime/llm_engine.cpp

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -858,8 +858,8 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(
858858
std::vector<RawForwardInput> batched_inputs;
859859
batched_inputs.reserve(dp_size_);
860860
// some dp related variables
861-
std::vector<int32_t> dp_global_token_nums;
862-
dp_global_token_nums.resize(dp_size_);
861+
std::vector<int32_t> dp_global_token_nums(dp_size_);
862+
std::vector<int32_t> dp_is_decode(dp_size_, 0);
863863
bool global_empty_kv_cache = true;
864864
// when enable dp, we need to check the forward type of each batch
865865
// and set the empty forward type of each batch to the same value as the first
@@ -878,6 +878,8 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(
878878
!batched_inputs[dp_rank].batch_forward_type.is_empty()) {
879879
batch_forward_type = batched_inputs[dp_rank].batch_forward_type;
880880
}
881+
dp_is_decode[dp_rank] = batch_forward_type.is_decode() &&
882+
batched_inputs[dp_rank].q_max_seq_len == 1;
881883
}
882884

883885
// eplb related
@@ -889,6 +891,7 @@ std::vector<RawForwardInput> LLMEngine::prepare_inputs(
889891
// update dp_global_token_nums and global_empty_kv_cache
890892
for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) {
891893
batched_inputs[dp_rank].dp_global_token_nums = dp_global_token_nums;
894+
batched_inputs[dp_rank].dp_is_decode = dp_is_decode;
892895
batched_inputs[dp_rank].global_empty_kv_cache = global_empty_kv_cache;
893896
if (FLAGS_enable_eplb) {
894897
batched_inputs[dp_rank].eplb_info = eplb_info;

xllm/core/distributed_runtime/vlm_engine.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -434,8 +434,8 @@ std::vector<RawForwardInput> VLMEngine::prepare_inputs(
434434
std::vector<RawForwardInput> batched_inputs;
435435
batched_inputs.reserve(dp_size_);
436436
// some dp related variables
437-
std::vector<int32_t> dp_global_token_nums;
438-
dp_global_token_nums.resize(dp_size_);
437+
std::vector<int32_t> dp_global_token_nums(dp_size_);
438+
std::vector<int32_t> dp_is_decode(dp_size_, 0);
439439
bool global_empty_kv_cache = true;
440440

441441
for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) {
@@ -445,12 +445,16 @@ std::vector<RawForwardInput> VLMEngine::prepare_inputs(
445445
batched_inputs[dp_rank].flatten_tokens_vec.size();
446446
global_empty_kv_cache =
447447
batched_inputs[dp_rank].empty_kv_cache && global_empty_kv_cache;
448+
dp_is_decode[dp_rank] =
449+
batched_inputs[dp_rank].batch_forward_type.is_decode() &&
450+
batched_inputs[dp_rank].q_max_seq_len == 1;
448451
}
449452

450453
// update dp_global_token_nums and global_empty_kv_cache
451454
for (auto dp_rank = 0; dp_rank < dp_size_; ++dp_rank) {
452455
batched_inputs[dp_rank].dp_global_token_nums = dp_global_token_nums;
453456
batched_inputs[dp_rank].global_empty_kv_cache = global_empty_kv_cache;
457+
batched_inputs[dp_rank].dp_is_decode = std::move(dp_is_decode);
454458
}
455459

456460
return batched_inputs;

xllm/core/framework/model/causal_lm.h

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -119,10 +119,6 @@ class CausalLM : public torch::nn::Module {
119119
LOG(FATAL) << "Method 'set_word_embedding' is not implemented/supported by "
120120
"this model.";
121121
}
122-
virtual void skip_mrope() {}
123-
virtual void apply_mrope(const torch::Tensor positions,
124-
torch::Tensor& cos_pos,
125-
torch::Tensor& sin_pos) {}
126122
};
127123

128124
template <typename Model>
@@ -192,20 +188,6 @@ class CausalLMImpl : public CausalLM {
192188

193189
const torch::TensorOptions& options() const override { return options_; }
194190

195-
void skip_mrope() override {
196-
#if defined(USE_MLU)
197-
model_->skip_mrope();
198-
#endif
199-
}
200-
201-
void apply_mrope(const torch::Tensor positions,
202-
torch::Tensor& cos_pos,
203-
torch::Tensor& sin_pos) override {
204-
#if defined(USE_MLU)
205-
model_->apply_mrope(positions, cos_pos, sin_pos);
206-
#endif
207-
}
208-
209191
private:
210192
Model model_;
211193

xllm/core/framework/model/model_input_params.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -127,6 +127,7 @@ struct ModelInputParams {
127127

128128
params.mm_data = mm_data.to(device);
129129
params.dp_global_token_nums = dp_global_token_nums;
130+
params.dp_is_decode = dp_is_decode;
130131
params.embedding_ids = std::move(embedding_ids);
131132
params.extra_token_ids = std::move(extra_token_ids);
132133
params.dp_ep_padding_data = dp_ep_padding_data;
@@ -178,7 +179,7 @@ struct ModelInputParams {
178179
print_tensor(new_cache_slots, "ModelInputParams: new_cache_slots", 4);
179180
print_tensor(block_tables, "ModelInputParams: block_tables", 4);
180181
LOG(INFO) << "ModelInputParams: dp_global_token_nums is "
181-
<< dp_global_token_nums;
182+
<< dp_global_token_nums << ", dp_is_decode: " << dp_is_decode;
182183
}
183184

184185
int32_t get_q_seq_len(int32_t seq_idx) const {
@@ -240,6 +241,7 @@ struct ModelInputParams {
240241

241242
// num tokens of all workers,mainly used for dp case
242243
std::vector<int32_t> dp_global_token_nums;
244+
std::vector<int32_t> dp_is_decode;
243245
// whether the kv-cache is empty for all sequences,mainly used for dp case
244246
bool global_empty_kv_cache = true;
245247

xllm/core/layers/npu/npu_base_layer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,11 +78,11 @@ atb::Status BaseLayer::execute_node(atb_speed::Model::Node& node,
7878
// However, libtorch_npu current stream is set to default stream after
7979
// capture ends, causing inconsistency between ATB context and the actual
8080
// execution stream
81-
if (FLAGS_enable_acl_graph) {
81+
if (FLAGS_enable_graph) {
8282
void* stream = c10_npu::getCurrentNPUStream(device_.index()).stream();
8383
context_->SetExecuteStream(stream);
8484
}
85-
// if (FLAGS_enable_acl_graph && !graph_captured_) {
85+
// if (FLAGS_enable_graph && !graph_captured_) {
8686
// void* stream = c10_npu::getCurrentNPUStream(device_.index()).stream();
8787
// aclmdlRICaptureStatus status;
8888
// aclmdlRI modelRI;

xllm/core/layers/npu/npu_glm4_moe_decoder_layer.cpp

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,7 +121,7 @@ void Glm4MoeDecoderImpl::initialize_basic_parameters(
121121

122122
// not support MTP model yet
123123
param.enableAclGraph =
124-
FLAGS_enable_acl_graph && !is_prefill && args.n_layers() > 1;
124+
FLAGS_enable_graph && !is_prefill && args.n_layers() > 1;
125125

126126
param.moeLinearTransposeType = (layer_id_ < args.first_k_dense_replace())
127127
? std::vector<int>{-1, -1, -1, -1}
@@ -460,7 +460,7 @@ void Glm4MoeDecoderImpl::build_node_variant_pack(
460460
node.variantPack.inTensors.at(input_idx++) =
461461
atb_speed::Utils::AtTensor2Tensor(tensor_placeholder_);
462462

463-
if (FLAGS_enable_acl_graph && !is_prefill &&
463+
if (FLAGS_enable_graph && !is_prefill &&
464464
input_params.graph_buffer.tiling_data.defined()) {
465465
node.variantPack.inTensors.at(input_idx++) =
466466
atb_speed::Utils::AtTensor2Tensor(

xllm/core/runtime/CMakeLists.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ include(cc_library)
22
include(cc_binary)
33
include(cc_test)
44

5+
56
cc_library(
67
NAME
78
runtime
@@ -12,6 +13,7 @@ cc_library(
1213
params_utils.h
1314
executor.h
1415
executor_impl.h
16+
executor_impl_factory.h
1517
base_executor_impl.h
1618
dit_executor.h
1719
$<$<BOOL:${USE_NPU}>:acl_graph_executor_impl.h>
@@ -29,6 +31,7 @@ cc_library(
2931
forward_shared_memory_manager.h
3032
SRCS
3133
executor.cpp
34+
executor_impl_factory.cpp
3235
base_executor_impl.cpp
3336
dit_executor.cpp
3437
$<$<BOOL:${USE_NPU}>:acl_graph_executor_impl.cpp>

xllm/core/runtime/acl_graph_executor_impl.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -849,7 +849,7 @@ void AclGraph::print_graph_tensors() const {
849849
// bucket will be [1, 2, 4, 8, 16, 32, 48, 64, ..., max_seqs_per_batch]
850850
uint32_t AclGraphExecutorImpl::get_bucket_num_tokens(
851851
uint32_t num_tokens) const {
852-
if (FLAGS_enable_acl_graph_no_padding) {
852+
if (FLAGS_enable_graph_no_padding) {
853853
return num_tokens;
854854
}
855855
if (num_tokens <= 1) {

0 commit comments

Comments (0)