Skip to content

Commit bedc5f5

Browse files
committed
feat: initialize flashinfer planinfo at layer-0 forward stage.
Signed-off-by: pengtao.156 <[email protected]>
1 parent a3bf8b3 commit bedc5f5

46 files changed

Lines changed: 322 additions & 158 deletions

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

xllm/core/framework/model_context.h

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -55,6 +55,15 @@ class ModelContext {
5555
return tensor_options_;
5656
}
5757

58+
// Stores the given layer index in layer_id_ (the member is initialized to -1).
void set_layer_id(int32_t layer_id) { layer_id_ = layer_id; }
59+
const int32_t layer_id() const {
60+
if (layer_id_ == -1) {
61+
LOG(ERROR) << "layer_id is not set in ModelContext, layer_id_ = "
62+
<< layer_id_;
63+
}
64+
return layer_id_;
65+
}
66+
5867
#if defined(USE_NPU)
5968
const atb::Context* get_atb_context() const { return context_; }
6069
#endif
@@ -64,6 +73,7 @@ class ModelContext {
6473
}
6574

6675
private:
76+
int32_t layer_id_ = -1;
6777
ModelArgs model_args_;
6878
QuantArgs quant_args_;
6979
ParallelArgs parallel_args_;

xllm/core/kernels/cuda/batch_decode.cpp

Lines changed: 3 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ limitations under the License.
1818

1919
namespace xllm::kernel::cuda {
2020

21-
void batch_decode(torch::Tensor float_workspace_buffer,
21+
void batch_decode(const std::string& uri,
22+
torch::Tensor plan_info,
23+
torch::Tensor float_workspace_buffer,
2224
torch::Tensor int_workspace_buffer,
2325
torch::Tensor page_locked_int_workspace_buffer,
2426
torch::Tensor query,
@@ -32,41 +34,6 @@ void batch_decode(torch::Tensor float_workspace_buffer,
3234
torch::Tensor output,
3335
std::optional<torch::Tensor>& output_lse,
3436
bool enable_cuda_graph) {
35-
std::string uri = get_batch_decode_uri(query.scalar_type(),
36-
k_cache.scalar_type(),
37-
output.scalar_type(),
38-
paged_kv_indptr.scalar_type(),
39-
query.size(-1),
40-
v_cache.size(-1),
41-
/*pos_encoding_mode=*/0,
42-
/*use_sliding_window=*/false,
43-
/*use_logits_soft_cap=*/false);
44-
45-
torch::Tensor paged_kv_indptr_host = paged_kv_indptr.to(torch::kCPU);
46-
const int64_t batch_size = paged_kv_last_page_len.size(0);
47-
48-
torch::Tensor empty_q_data =
49-
torch::empty({0}, torch::TensorOptions().dtype(query.scalar_type()));
50-
torch::Tensor empty_kv_data =
51-
torch::empty({0}, torch::TensorOptions().dtype(k_cache.scalar_type()));
52-
53-
auto plan_info = FunctionFactory::get_instance().decode_plan_func(uri).call(
54-
float_workspace_buffer,
55-
int_workspace_buffer,
56-
page_locked_int_workspace_buffer,
57-
paged_kv_indptr_host,
58-
batch_size,
59-
query.size(1), // num_qo_heads
60-
k_cache.size(2), // num_kv_heads
61-
k_cache.size(1), // block_size
62-
enable_cuda_graph,
63-
window_left,
64-
/*logits_soft_cap=*/0.0,
65-
query.size(-1), // head_dim_qk
66-
v_cache.size(-1), // head_dim_vo
67-
empty_q_data,
68-
empty_kv_data);
69-
7037
FunctionFactory::get_instance().decode_run_func(uri).call(
7138
float_workspace_buffer,
7239
int_workspace_buffer,

xllm/core/kernels/cuda/batch_prefill.cpp

Lines changed: 3 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,9 @@ limitations under the License.
1818

1919
namespace xllm::kernel::cuda {
2020

21-
void batch_prefill(torch::Tensor float_workspace_buffer,
21+
void batch_prefill(const std::string& uri,
22+
torch::Tensor plan_info,
23+
torch::Tensor float_workspace_buffer,
2224
torch::Tensor int_workspace_buffer,
2325
torch::Tensor page_locked_int_workspace_buffer,
2426
torch::Tensor query,
@@ -35,43 +37,6 @@ void batch_prefill(torch::Tensor float_workspace_buffer,
3537
determine_attention_backend(/*pos_encoding_mode=*/0,
3638
/*use_fp16_qk_reduction=*/false,
3739
/*use_custom_mask=*/false);
38-
39-
std::string uri = get_batch_prefill_uri(backend,
40-
query.scalar_type(),
41-
key.scalar_type(),
42-
output.scalar_type(),
43-
q_cu_seq_lens.scalar_type(),
44-
query.size(-1),
45-
value.size(-1),
46-
/*pos_encoding_mode=*/0,
47-
/*use_sliding_window=*/false,
48-
/*use_logits_soft_cap=*/false,
49-
/*use_fp16_qk_reduction=*/false);
50-
51-
torch::Tensor qo_indptr_host = q_cu_seq_lens.to(torch::kCPU);
52-
torch::Tensor kv_cu_seq_lens_host = kv_cu_seq_lens.to(torch::kCPU);
53-
torch::Tensor kv_len_arr_host =
54-
kv_cu_seq_lens_host.slice(0, 1) - kv_cu_seq_lens_host.slice(0, 0, -1);
55-
const int64_t total_num_rows = qo_indptr_host[-1].item<int64_t>();
56-
const int64_t batch_size = qo_indptr_host.size(0) - 1;
57-
58-
auto plan_info = FunctionFactory::get_instance().prefill_plan_func(uri).call(
59-
float_workspace_buffer,
60-
int_workspace_buffer,
61-
page_locked_int_workspace_buffer,
62-
qo_indptr_host,
63-
kv_cu_seq_lens_host,
64-
kv_len_arr_host,
65-
total_num_rows,
66-
batch_size,
67-
query.size(1), // num_qo_heads
68-
key.size(1), // num_kv_heads
69-
/*page_size=*/1,
70-
enable_cuda_graph,
71-
query.size(-1), // head_dim_qk
72-
value.size(-1), // head_dim_vo
73-
/*causal=*/true);
74-
7540
if (backend == "fa2") {
7641
FunctionFactory::get_instance().fa2_prefill_ragged_run_func(uri).call(
7742
float_workspace_buffer,

xllm/core/kernels/cuda/cuda_ops_api.h

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,9 @@ void reshape_paged_cache(
4545
torch::Tensor key_cache, // [n_blocks, block_size, n_heads, head_dim]
4646
torch::Tensor value_cache);
4747

48-
void batch_prefill(torch::Tensor float_workspace_buffer,
48+
void batch_prefill(const std::string& uri,
49+
torch::Tensor plan_info,
50+
torch::Tensor float_workspace_buffer,
4951
torch::Tensor int_workspace_buffer,
5052
torch::Tensor page_locked_int_workspace_buffer,
5153
torch::Tensor query,
@@ -59,7 +61,9 @@ void batch_prefill(torch::Tensor float_workspace_buffer,
5961
std::optional<torch::Tensor>& output_lse,
6062
bool enable_cuda_graph);
6163

62-
void batch_decode(torch::Tensor float_workspace_buffer,
64+
void batch_decode(const std::string& uri,
65+
torch::Tensor plan_info,
66+
torch::Tensor float_workspace_buffer,
6367
torch::Tensor int_workspace_buffer,
6468
torch::Tensor page_locked_int_workspace_buffer,
6569
torch::Tensor query,

xllm/core/kernels/cuda/utils.h

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@ limitations under the License.
1515

1616
#pragma once
1717

18+
#include <ATen/DynamicLibrary.h>
1819
#include <torch/torch.h>
1920

2021
#include <string>
@@ -60,4 +61,4 @@ std::string get_batch_decode_uri(torch::ScalarType dtype_q,
6061
bool use_sliding_window,
6162
bool use_logits_soft_cap);
6263

63-
} // namespace xllm::kernel::cuda
64+
} // namespace xllm::kernel::cuda

xllm/core/kernels/ops_api.cpp

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -153,7 +153,9 @@ void batch_prefill(AttentionParams& params) {
153153
params.scale,
154154
params.output);
155155
#elif defined(USE_CUDA)
156-
cuda::batch_prefill(params.float_workspace_buffer,
156+
cuda::batch_prefill(params.uri,
157+
params.plan_info,
158+
params.float_workspace_buffer,
157159
params.int_workspace_buffer,
158160
params.page_locked_int_workspace_buffer,
159161
params.query,
@@ -225,7 +227,9 @@ void batch_decode(AttentionParams& params) {
225227
#elif defined(USE_CUDA)
226228
params.query = params.query.squeeze(1);
227229
params.output = params.output.squeeze(1);
228-
cuda::batch_decode(params.float_workspace_buffer,
230+
cuda::batch_decode(params.uri,
231+
params.plan_info,
232+
params.float_workspace_buffer,
229233
params.int_workspace_buffer,
230234
params.page_locked_int_workspace_buffer,
231235
params.query,

xllm/core/kernels/param.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,8 @@ struct AttentionParams {
208208
torch::Tensor page_locked_int_workspace_buffer;
209209

210210
bool enable_cuda_graph = false;
211+
std::string uri;
212+
torch::Tensor plan_info;
211213

212214
// ========== Prefill-specific parameters ==========
213215
// Key tensor. Shape: [num_tokens, num_kv_heads, head_dim_qk] (packed) or

xllm/core/layers/common/attention_metadata.cpp

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,9 @@ limitations under the License.
1616
#include "attention_metadata.h"
1717

1818
#include "core/common/global_flags.h"
19+
#include "core/layers/cuda/flashinfer_workspace.h"
20+
#include "kernels/cuda/function_factory.h"
21+
#include "kernels/cuda/utils.h"
1922

2023
namespace xllm {
2124
namespace layer {
@@ -53,4 +56,4 @@ AttentionMetadata AttentionMetadata::build(const ModelInputParams& params,
5356
}
5457

5558
} // namespace layer
56-
} // namespace xllm
59+
} // namespace xllm

xllm/core/layers/common/attention_metadata.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -53,4 +53,4 @@ struct AttentionMetadata {
5353
};
5454

5555
} // namespace layer
56-
} // namespace xllm
56+
} // namespace xllm

xllm/core/layers/common/qwen2_attention.cpp

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,8 @@ Qwen2AttentionImpl::Qwen2AttentionImpl(const ModelContext& context) {
103103

104104
// 5. Attention
105105
attn_ = register_module("attn",
106-
Attention(num_heads_,
106+
Attention(context.layer_id(),
107+
num_heads_,
107108
head_dim_,
108109
scaling_,
109110
num_kv_heads_,

0 commit comments

Comments
 (0)