-
Notifications
You must be signed in to change notification settings - Fork 166
perf: optimize qwen3.5 hybrid linear cache flow [4/N]. #1160
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -172,14 +172,15 @@ bool WorkerImpl::allocate_kv_cache( | |
| const std::vector<std::vector<int64_t>>& kv_cache_shape) { | ||
| CHECK(model_ != nullptr) << "Model is not initialized."; | ||
| CHECK(kv_caches_.empty()) << "KV caches are already initialized."; | ||
| const bool enable_linear_attention = | ||
| has_linear_attention_layers(context_.get_model_args()); | ||
| const bool enable_lighting_indexer = | ||
| context_.get_model_args().index_n_heads() > 0; | ||
| const auto& args = context_.get_model_args(); | ||
| const bool enable_linear_attention = has_linear_attention_layers(args); | ||
| const bool enable_lighting_indexer = args.index_n_heads() > 0; | ||
| CHECK(!(enable_linear_attention && enable_lighting_indexer)) | ||
| << "KVCache does not support linear attention and lighting indexer " | ||
| << "simultaneously."; | ||
|
|
||
| const int64_t num_layers = get_num_layers(); | ||
|
|
||
| // Check if KV cache quantization is enabled | ||
| // "auto" (default): cache dtype aligns with model dtype (no quantization) | ||
| // "int8": enables INT8 quantization | ||
|
|
@@ -202,11 +203,12 @@ bool WorkerImpl::allocate_kv_cache( | |
| } | ||
|
|
||
| // create a KVCache for each layer | ||
| const int64_t num_layers = get_num_layers(); | ||
| kv_caches_.reserve(num_layers); | ||
|
|
||
| if (FLAGS_enable_xtensor) { | ||
| // XTensor mode: create xtensor-backed KV cache tensors. | ||
| // For hybrid models, we still create full KV cache for all layers | ||
| // since xtensor has its own memory management | ||
| auto& allocator = XTensorAllocator::get_instance(); | ||
| const std::string& model_id = options_.model_id(); | ||
| // Create K tensors for all layers | ||
|
|
@@ -223,92 +225,132 @@ bool WorkerImpl::allocate_kv_cache( | |
| k_tensor = at_npu::native::npu_format_cast(k_tensor, ACL_FORMAT_ND); | ||
| v_tensor = at_npu::native::npu_format_cast(v_tensor, ACL_FORMAT_ND); | ||
| #endif | ||
|
|
||
| // For xtensor mode, we still use the full KV cache approach | ||
| kv_caches_.emplace_back(k_tensor, v_tensor); | ||
| } | ||
| } else { | ||
| // Original mode: create torch tensors with optional int8 kv quantization. | ||
| torch::ScalarType cache_dtype = | ||
| enable_kv_cache_quant ? torch::kInt8 : dtype_; | ||
|
|
||
| // Helper function to check if a layer is linear attention | ||
| auto is_linear_attention_layer = [&](int64_t layer_idx) { | ||
| if (args.full_attention_interval() > 1) { | ||
| return (layer_idx + 1) % args.full_attention_interval() != 0; | ||
| } | ||
| return false; | ||
| }; | ||
JC-ut0 marked this conversation as resolved.
Show resolved
Hide resolved
|
||
|
|
||
| for (int64_t i = 0; i < num_layers; ++i) { | ||
| bool is_linear_layer = is_linear_attention_layer(i); | ||
| torch::Tensor key_cache, value_cache, index_cache, conv_cache, ssm_cache; | ||
| torch::Tensor key_cache_scale, value_cache_scale; | ||
|
|
||
| if (is_linear_layer) { | ||
| // Linear attention layer: only allocate conv_cache and ssm_cache | ||
| #if defined(USE_NPU) | ||
| aclFormat npu_format_type = | ||
| context_.get_model_args().model_type() == "deepseek_v3" && | ||
| FLAGS_enable_prefix_cache | ||
| ? ACL_FORMAT_FRACTAL_NZ | ||
| : ACL_FORMAT_ND; | ||
| key_cache = at_npu::native::npu_format_cast( | ||
| torch::empty(kv_cache_shape[0], | ||
| torch::dtype(cache_dtype).device(device_)), | ||
| npu_format_type); | ||
| value_cache = at_npu::native::npu_format_cast( | ||
| torch::empty(kv_cache_shape[1], | ||
| torch::dtype(cache_dtype).device(device_)), | ||
| npu_format_type); | ||
| if (enable_lighting_indexer) { | ||
| index_cache = at_npu::native::npu_format_cast( | ||
| torch::empty(kv_cache_shape[2], | ||
| torch::dtype(dtype_).device(device_)), | ||
| npu_format_type); | ||
| } | ||
| if (enable_linear_attention) { | ||
| conv_cache = at_npu::native::npu_format_cast( | ||
| torch::zeros(kv_cache_shape[2], | ||
| torch::dtype(dtype_).device(device_)), | ||
| 2); | ||
| ssm_cache = at_npu::native::npu_format_cast( | ||
| torch::zeros(kv_cache_shape[3], | ||
| torch::dtype(dtype_).device(device_)), | ||
| 2); | ||
| } | ||
| aclFormat npu_format_type = ACL_FORMAT_ND; | ||
| if (enable_linear_attention) { | ||
| conv_cache = at_npu::native::npu_format_cast( | ||
| torch::zeros(kv_cache_shape[2], | ||
| torch::dtype(dtype_).device(device_)), | ||
| 2); | ||
| ssm_cache = at_npu::native::npu_format_cast( | ||
| torch::zeros(kv_cache_shape[3], | ||
| torch::dtype(dtype_).device(device_)), | ||
| 2); | ||
| } | ||
| #elif defined(USE_ILU) || defined(USE_MLU) || defined(USE_MUSA) | ||
|
Collaborator
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The
Contributor
Author
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. This follows the original code style below. |
||
| key_cache = torch::zeros(kv_cache_shape[0], | ||
| torch::dtype(cache_dtype).device(device_)); | ||
| if (!kv_cache_shape[1].empty()) { | ||
| value_cache = torch::zeros(kv_cache_shape[1], | ||
| torch::dtype(cache_dtype).device(device_)); | ||
| } | ||
| if (enable_lighting_indexer) { | ||
| index_cache = torch::zeros(kv_cache_shape[2], | ||
| if (enable_linear_attention) { | ||
| conv_cache = torch::zeros(kv_cache_shape[2], | ||
| torch::dtype(dtype_).device(device_)); | ||
| ssm_cache = torch::zeros(kv_cache_shape[3], | ||
| torch::dtype(dtype_).device(device_)); | ||
| } | ||
| if (enable_kv_cache_quant) { | ||
| std::vector<int64_t> key_scale_shape(kv_cache_shape[0].begin(), | ||
| kv_cache_shape[0].end() - 1); | ||
| key_cache_scale = torch::zeros( | ||
| key_scale_shape, torch::dtype(torch::kFloat32).device(device_)); | ||
| if (!kv_cache_shape[1].empty()) { | ||
| std::vector<int64_t> value_scale_shape(kv_cache_shape[1].begin(), | ||
| kv_cache_shape[1].end() - 1); | ||
| value_cache_scale = torch::zeros( | ||
| value_scale_shape, torch::dtype(torch::kFloat32).device(device_)); | ||
| } | ||
| } | ||
| #else | ||
| key_cache = torch::empty(kv_cache_shape[0], | ||
| torch::dtype(cache_dtype).device(device_)); | ||
| if (!kv_cache_shape[1].empty()) { | ||
| value_cache = torch::empty(kv_cache_shape[1], | ||
| torch::dtype(cache_dtype).device(device_)); | ||
| } | ||
| if (enable_lighting_indexer) { | ||
| index_cache = torch::empty(kv_cache_shape[2], | ||
| if (enable_linear_attention) { | ||
| conv_cache = torch::empty(kv_cache_shape[2], | ||
| torch::dtype(dtype_).device(device_)); | ||
| ssm_cache = torch::empty(kv_cache_shape[3], | ||
| torch::dtype(dtype_).device(device_)); | ||
| } | ||
| } | ||
| #endif | ||
| if (enable_kv_cache_quant) { | ||
| kv_caches_.emplace_back(key_cache, | ||
| value_cache, | ||
| index_cache, | ||
| key_cache_scale, | ||
| value_cache_scale); | ||
| } else if (enable_linear_attention) { | ||
| kv_caches_.emplace_back(key_cache, value_cache, conv_cache, ssm_cache); | ||
| } else if (enable_lighting_indexer) { | ||
| kv_caches_.emplace_back(key_cache, value_cache, index_cache); | ||
| // Create empty KVCache with only conv and ssm | ||
| kv_caches_.emplace_back( | ||
| torch::empty({0}, torch::dtype(dtype_).device(device_)), | ||
| torch::empty({0}, torch::dtype(dtype_).device(device_)), | ||
| conv_cache, | ||
| ssm_cache); | ||
| } else { | ||
| kv_caches_.emplace_back(key_cache, value_cache); | ||
| // Full attention layer: allocate key_cache and value_cache only | ||
| #if defined(USE_NPU) | ||
| aclFormat npu_format_type = | ||
| context_.get_model_args().model_type() == "deepseek_v3" && | ||
| FLAGS_enable_prefix_cache | ||
| ? ACL_FORMAT_FRACTAL_NZ | ||
| : ACL_FORMAT_ND; | ||
| key_cache = at_npu::native::npu_format_cast( | ||
| torch::empty(kv_cache_shape[0], | ||
| torch::dtype(cache_dtype).device(device_)), | ||
| npu_format_type); | ||
| value_cache = at_npu::native::npu_format_cast( | ||
| torch::empty(kv_cache_shape[1], | ||
| torch::dtype(cache_dtype).device(device_)), | ||
| npu_format_type); | ||
| if (enable_lighting_indexer) { | ||
| index_cache = at_npu::native::npu_format_cast( | ||
| torch::empty(kv_cache_shape[2], | ||
| torch::dtype(dtype_).device(device_)), | ||
| npu_format_type); | ||
| } | ||
| #elif defined(USE_ILU) || defined(USE_MLU) || defined(USE_MUSA) | ||
| key_cache = torch::zeros(kv_cache_shape[0], | ||
| torch::dtype(cache_dtype).device(device_)); | ||
| if (!kv_cache_shape[1].empty()) { | ||
| value_cache = torch::zeros(kv_cache_shape[1], | ||
| torch::dtype(cache_dtype).device(device_)); | ||
| } | ||
| if (enable_lighting_indexer) { | ||
| index_cache = torch::zeros(kv_cache_shape[2], | ||
| torch::dtype(dtype_).device(device_)); | ||
| } | ||
| if (enable_kv_cache_quant) { | ||
| std::vector<int64_t> key_scale_shape(kv_cache_shape[0].begin(), | ||
| kv_cache_shape[0].end() - 1); | ||
| key_cache_scale = torch::zeros( | ||
| key_scale_shape, torch::dtype(torch::kFloat32).device(device_)); | ||
| if (!kv_cache_shape[1].empty()) { | ||
| std::vector<int64_t> value_scale_shape(kv_cache_shape[1].begin(), | ||
| kv_cache_shape[1].end() - 1); | ||
| value_cache_scale = | ||
| torch::zeros(value_scale_shape, | ||
| torch::dtype(torch::kFloat32).device(device_)); | ||
| } | ||
| } | ||
| #else | ||
| key_cache = torch::empty(kv_cache_shape[0], | ||
| torch::dtype(cache_dtype).device(device_)); | ||
| if (!kv_cache_shape[1].empty()) { | ||
| value_cache = torch::empty(kv_cache_shape[1], | ||
| torch::dtype(cache_dtype).device(device_)); | ||
| } | ||
| if (enable_lighting_indexer) { | ||
| index_cache = torch::empty(kv_cache_shape[2], | ||
| torch::dtype(dtype_).device(device_)); | ||
| } | ||
| #endif | ||
| if (enable_kv_cache_quant) { | ||
| kv_caches_.emplace_back(key_cache, | ||
| value_cache, | ||
| index_cache, | ||
| key_cache_scale, | ||
| value_cache_scale); | ||
| } else if (enable_lighting_indexer) { | ||
| kv_caches_.emplace_back(key_cache, value_cache, index_cache); | ||
| } else { | ||
| kv_caches_.emplace_back(key_cache, value_cache); | ||
| } | ||
| } | ||
| } | ||
| } | ||
|
|
||
Uh oh!
There was an error while loading. Please reload this page.