@@ -18,7 +18,9 @@ limitations under the License.
1818
1919namespace xllm ::kernel::cuda {
2020
21- void batch_prefill (torch::Tensor float_workspace_buffer,
21+ void batch_prefill (const std::string& uri,
22+ torch::Tensor plan_info,
23+ torch::Tensor float_workspace_buffer,
2224 torch::Tensor int_workspace_buffer,
2325 torch::Tensor page_locked_int_workspace_buffer,
2426 torch::Tensor query,
@@ -35,43 +37,6 @@ void batch_prefill(torch::Tensor float_workspace_buffer,
3537 determine_attention_backend (/* pos_encoding_mode=*/ 0 ,
3638 /* use_fp16_qk_reduction=*/ false ,
3739 /* use_custom_mask=*/ false );
38-
39- std::string uri = get_batch_prefill_uri (backend,
40- query.scalar_type (),
41- key.scalar_type (),
42- output.scalar_type (),
43- q_cu_seq_lens.scalar_type (),
44- query.size (-1 ),
45- value.size (-1 ),
46- /* pos_encoding_mode=*/ 0 ,
47- /* use_sliding_window=*/ false ,
48- /* use_logits_soft_cap=*/ false ,
49- /* use_fp16_qk_reduction=*/ false );
50-
51- torch::Tensor qo_indptr_host = q_cu_seq_lens.to (torch::kCPU );
52- torch::Tensor kv_cu_seq_lens_host = kv_cu_seq_lens.to (torch::kCPU );
53- torch::Tensor kv_len_arr_host =
54- kv_cu_seq_lens_host.slice (0 , 1 ) - kv_cu_seq_lens_host.slice (0 , 0 , -1 );
55- const int64_t total_num_rows = qo_indptr_host[-1 ].item <int64_t >();
56- const int64_t batch_size = qo_indptr_host.size (0 ) - 1 ;
57-
58- auto plan_info = FunctionFactory::get_instance ().prefill_plan_func (uri).call (
59- float_workspace_buffer,
60- int_workspace_buffer,
61- page_locked_int_workspace_buffer,
62- qo_indptr_host,
63- kv_cu_seq_lens_host,
64- kv_len_arr_host,
65- total_num_rows,
66- batch_size,
67- query.size (1 ), // num_qo_heads
68- key.size (1 ), // num_kv_heads
69- /* page_size=*/ 1 ,
70- enable_cuda_graph,
71- query.size (-1 ), // head_dim_qk
72- value.size (-1 ), // head_dim_vo
73- /* causal=*/ true );
74-
7540 if (backend == " fa2" ) {
7641 FunctionFactory::get_instance ().fa2_prefill_ragged_run_func (uri).call (
7742 float_workspace_buffer,
0 commit comments