Skip to content

Commit ccb4145

Browse files
committed
feat: support graph on mlu device and fix fused norm bug.
1 parent b38a15b commit ccb4145

21 files changed

+691
-289
lines changed

xllm/core/common/global_flags.cpp

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,14 +90,24 @@ DEFINE_bool(enable_acl_graph,
9090

9191
DEFINE_int32(max_seq_len_for_graph_mode,
9292
0,
93-
"Maximum number of tokens per sequence for ACL graph execution. "
93+
"Maximum number of tokens per sequence for graph execution. "
9494
"If 0, use model max_position_embeddings.");
9595

9696
DEFINE_bool(enable_acl_graph_no_padding,
9797
false,
9898
"Whether to enable ACL graph execution for decode phase without "
9999
"padding. If true, graph will be captured with every actual num "
100100
"tokens, as stride is 1.");
101+
102+
DEFINE_bool(enable_graph,
103+
false,
104+
"Whether to enable graph execution for decode phase.");
105+
106+
DEFINE_bool(enable_graph_no_padding,
107+
false,
108+
"Whether to enable graph execution for decode phase without "
109+
"padding. If true, graph will be captured with every actual num "
110+
"tokens, as stride is 1.");
101111
// --- vlm config ---
102112

103113
DEFINE_int32(limit_image_per_prompt,

xllm/core/common/global_flags.h

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,10 @@ DECLARE_int32(max_seq_len_for_graph_mode);
8989

9090
DECLARE_bool(enable_acl_graph_no_padding);
9191

92+
DECLARE_bool(enable_graph);
93+
94+
DECLARE_bool(enable_graph_no_padding);
95+
9296
DECLARE_bool(enable_chunked_prefill);
9397

9498
DECLARE_string(master_node_addr);

xllm/core/framework/model/causal_lm.h

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -69,6 +69,10 @@ class CausalLM : public torch::nn::Module {
6969
virtual void set_lm_head(layer::LmHead& head) = 0;
7070
virtual layer::WordEmbedding get_word_embedding() = 0;
7171
virtual void set_word_embedding(layer::WordEmbedding& embedding) = 0;
72+
virtual void skip_mrope() {}
73+
virtual void apply_mrope(const torch::Tensor positions,
74+
torch::Tensor& cos_pos,
75+
torch::Tensor& sin_pos) {}
7276
};
7377

7478
template <typename Model>
@@ -118,6 +122,20 @@ class CausalLMImpl : public CausalLM {
118122

119123
const torch::TensorOptions& options() const override { return options_; }
120124

125+
void skip_mrope() override {
126+
#if defined(USE_MLU)
127+
model_->skip_mrope();
128+
#endif
129+
}
130+
131+
void apply_mrope(const torch::Tensor positions,
132+
torch::Tensor& cos_pos,
133+
torch::Tensor& sin_pos) override {
134+
#if defined(USE_MLU)
135+
model_->apply_mrope(positions, cos_pos, sin_pos);
136+
#endif
137+
}
138+
121139
private:
122140
Model model_;
123141

xllm/core/layers/common/CMakeLists.txt

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,6 @@ cc_library(
1616
qwen3_moe_decoder_layer.h
1717
linear.h
1818
word_embedding_impl.h
19-
layer_utils.h
2019
indexer.h
2120
deep_ep.h
2221
SRCS
@@ -32,7 +31,6 @@ cc_library(
3231
qwen3_moe_decoder_layer.cpp
3332
linear.cpp
3433
word_embedding_impl.cpp
35-
layer_utils.cpp
3634
indexer.cpp
3735
deep_ep.cpp
3836
DEPS

xllm/core/layers/common/layer_utils.cpp

Lines changed: 0 additions & 39 deletions
This file was deleted.

xllm/core/layers/common/layer_utils.h

Lines changed: 0 additions & 28 deletions
This file was deleted.

xllm/core/layers/common/qwen2_decoder_layer.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,6 @@ limitations under the License.
1717

1818
#include <glog/logging.h>
1919

20-
#include "layer_utils.h"
21-
2220
namespace xllm {
2321
namespace layer {
2422

xllm/core/layers/common/rms_norm.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> RMSNormImpl::forward(
5555
if (residual.has_value()) {
5656
residual.value() = residual.value().reshape({-1, norm_dim_});
5757
if (Device::type_str() == "mlu") {
58-
residual_out = torch::empty_like(residual.value());
58+
residual_out = residual.value();
5959
}
6060
}
6161

@@ -67,6 +67,7 @@ std::tuple<torch::Tensor, std::optional<torch::Tensor>> RMSNormImpl::forward(
6767
fused_layernorm_params.weight = weight_;
6868
fused_layernorm_params.eps = eps_;
6969
fused_layernorm_params.mode = mode_;
70+
fused_layernorm_params.store_output_before_norm = residual_out.has_value();
7071
if (bias_.defined()) {
7172
fused_layernorm_params.beta = bias_;
7273
}

xllm/core/runtime/CMakeLists.txt

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ include(cc_binary)
33
include(cc_test)
44

55
cc_library(
6-
NAME
6+
NAME
77
runtime
88
HDRS
99
options.h
@@ -15,6 +15,7 @@ cc_library(
1515
base_executor_impl.h
1616
dit_executor.h
1717
$<$<BOOL:${USE_NPU}>:acl_graph_executor_impl.h>
18+
$<$<BOOL:${USE_MLU}>:mlu_graph_executor_impl.h>
1819
worker.h
1920
worker_impl.h
2021
llm_worker_impl.h
@@ -36,6 +37,7 @@ cc_library(
3637
base_executor_impl.cpp
3738
dit_executor.cpp
3839
$<$<BOOL:${USE_NPU}>:acl_graph_executor_impl.cpp>
40+
$<$<BOOL:${USE_MLU}>:mlu_graph_executor_impl.cpp>
3941
worker.cpp
4042
worker_impl.cpp
4143
llm_worker_impl.cpp
@@ -82,7 +84,7 @@ cc_library(
8284
)
8385

8486
cc_library(
85-
NAME
87+
NAME
8688
master
8789
HDRS
8890
llm_master.h
@@ -137,8 +139,9 @@ target_link_libraries(acl_graph_executor_test
137139
)
138140
# Use --whole-archive for spdlog to ensure all symbols are available for xllm_kernels
139141
# This resolves the undefined reference issues in xllm_kernels
140-
target_link_options(acl_graph_executor_test PRIVATE
141-
"-Wl,--whole-archive"
142+
target_link_options(acl_graph_executor_test PRIVATE
143+
"-Wl,--whole-archive"
142144
"${CMAKE_BINARY_DIR}/third_party/spdlog/libspdlog.a"
143145
"-Wl,--no-whole-archive")
144146
endif()
147+

xllm/core/runtime/base_executor_impl.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@ torch::Tensor BaseExecutorImpl::run(const torch::Tensor& tokens,
3535
const torch::Tensor& positions,
3636
std::vector<KVCache>& kv_caches,
3737
const ModelInputParams& params) {
38+
COUNTER_INC(num_model_execution_total_eager);
3839
return model_->forward(tokens, positions, kv_caches, params);
3940
}
4041

0 commit comments

Comments (0)