Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 16 additions & 2 deletions xllm/core/distributed_runtime/llm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -519,12 +519,26 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {
}
#endif

int64_t full_attention_interval = (args_.full_attention_interval() < 1)
? 1
: args_.full_attention_interval();
int64_t num_full_attention_layers =
kv_cache_cap.n_layers / full_attention_interval;
int64_t num_linear_attention_layers =
kv_cache_cap.n_layers - num_full_attention_layers;
// compute kv cache n_blocks
const int32_t block_size = options_.block_size();
const int64_t block_size_in_bytes =
block_size * (slot_size + index_slot_size + scale_slot_size);
kv_cache_cap.n_blocks = kv_cache_cap.cache_size_in_bytes /
(kv_cache_cap.n_layers * block_size_in_bytes);
const int64_t full_cache_block_size_in_bytes =
block_size * (slot_size + index_slot_size + scale_slot_size);
const int64_t total_cache_block_size_in_bytes =
num_full_attention_layers * full_cache_block_size_in_bytes +
num_linear_attention_layers * linear_slot_size;
CHECK_GT(total_cache_block_size_in_bytes, 0)
<< "invalid cache block size estimate";
kv_cache_cap.n_blocks =
kv_cache_cap.cache_size_in_bytes / total_cache_block_size_in_bytes;
CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache";
return kv_cache_cap;
}
Expand Down
35 changes: 24 additions & 11 deletions xllm/core/runtime/acl_graph_executor_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,19 @@ limitations under the License.
namespace xllm::npu {

namespace {
std::pair<torch::Tensor, torch::Tensor> find_attention_plan_kv_cache(
    const std::vector<KVCache>& kv_caches) {
  // Scan the per-layer caches and hand back the first k/v pair whose tensors
  // are both defined and non-empty. In hybrid models some layers (e.g. linear
  // attention) may expose defined-but-empty cache tensors, so checking
  // defined() alone is not sufficient; numel() must be positive as well.
  for (const auto& layer_cache : kv_caches) {
    auto key = layer_cache.get_k_cache();
    auto value = layer_cache.get_v_cache();
    const bool usable = key.defined() && value.defined() && key.numel() > 0 &&
                        value.numel() > 0;
    if (usable) {
      return {std::move(key), std::move(value)};
    }
  }
  // No usable full-attention cache found: return undefined tensors so the
  // caller can detect the absence via defined().
  return {torch::Tensor(), torch::Tensor()};
}

int64_t get_decode_graph_capacity(const runtime::Options& options) {
CHECK_GT(options.num_decoding_tokens(), 0)
<< "num_decoding_tokens must be > 0 for graph capacity";
Expand Down Expand Up @@ -265,14 +278,8 @@ std::optional<ModelInputParams> GraphPersistentParam::update(
}

if (tiling_data_.numel() > 0) {
// Get current stream for tiling tensor update
aclrtStream stream = c10_npu::getCurrentNPUStream().stream();

// Update tiling tensor based on model type
// For models with mixed attention types (e.g., qwen3_next), only update if
// k/v cache is defined
// NOTE: linear attention may pass "defined but empty" k/v cache tensors.
// Only treat k/v cache as valid when they are defined and non-empty.
if (need_update_attention_plan_ && k_cache.defined() && v_cache.defined() &&
k_cache.numel() > 0 && v_cache.numel() > 0) {
plan_paged_attention_tiling(
Expand Down Expand Up @@ -798,9 +805,12 @@ bool AclGraph::capture(CausalLM* model,
aclrtStream stream =
c10_npu::getCurrentNPUStream(tensor_options.device().index()).stream();

// Update persistent parameters with input data before capture
const torch::Tensor& k_cache = kv_cache[0].get_k_cache();
const torch::Tensor& v_cache = kv_cache[0].get_v_cache();
// For hybrid models (e.g., qwen3_next with mixed GDN/full_attention layers),
// we need to find the first Full Attention layer to get the correct kv_cache.
// GDN layers have empty key_cache_/value_cache_ while Full Attention layers
// have valid kv caches. Using layer 0's cache directly would be incorrect
// if layer 0 is a GDN layer.
auto [k_cache, v_cache] = find_attention_plan_kv_cache(kv_cache);
const uint32_t actual_num_tokens = tokens.size(0);
CHECK_GE(num_tokens_, actual_num_tokens)
<< "num_tokens_ >= actual_num_tokens";
Expand Down Expand Up @@ -899,8 +909,11 @@ ModelOutput AclGraph::replay(const torch::Tensor& tokens,
<< actual_num_tokens;

// Update persistent parameters with new input data
const torch::Tensor& k_cache = kv_cache[0].get_k_cache();
const torch::Tensor& v_cache = kv_cache[0].get_v_cache();
// Note: tiling_data is updated in update() if needed - for hybrid models
// (e.g., qwen3_next with mixed GDN/attention layers), tiling should only
// be updated when Full Attention layers are involved, which is determined
// by k_cache being valid and non-empty
auto [k_cache, v_cache] = find_attention_plan_kv_cache(kv_cache);
persistent_param_.update(tokens,
k_cache,
v_cache,
Expand Down
190 changes: 116 additions & 74 deletions xllm/core/runtime/worker_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -172,14 +172,15 @@ bool WorkerImpl::allocate_kv_cache(
const std::vector<std::vector<int64_t>>& kv_cache_shape) {
CHECK(model_ != nullptr) << "Model is not initialized.";
CHECK(kv_caches_.empty()) << "KV caches are already initialized.";
const bool enable_linear_attention =
has_linear_attention_layers(context_.get_model_args());
const bool enable_lighting_indexer =
context_.get_model_args().index_n_heads() > 0;
const auto& args = context_.get_model_args();
const bool enable_linear_attention = has_linear_attention_layers(args);
const bool enable_lighting_indexer = args.index_n_heads() > 0;
CHECK(!(enable_linear_attention && enable_lighting_indexer))
<< "KVCache does not support linear attention and lighting indexer "
<< "simultaneously.";

const int64_t num_layers = get_num_layers();

// Check if KV cache quantization is enabled
// "auto" (default): cache dtype aligns with model dtype (no quantization)
// "int8": enables INT8 quantization
Expand All @@ -202,11 +203,12 @@ bool WorkerImpl::allocate_kv_cache(
}

// create a KVCache for each layer
const int64_t num_layers = get_num_layers();
kv_caches_.reserve(num_layers);

if (FLAGS_enable_xtensor) {
// XTensor mode: create xtensor-backed KV cache tensors.
// For hybrid models, we still create full KV cache for all layers
// since xtensor has its own memory management
auto& allocator = XTensorAllocator::get_instance();
const std::string& model_id = options_.model_id();
// Create K tensors for all layers
Expand All @@ -223,92 +225,132 @@ bool WorkerImpl::allocate_kv_cache(
k_tensor = at_npu::native::npu_format_cast(k_tensor, ACL_FORMAT_ND);
v_tensor = at_npu::native::npu_format_cast(v_tensor, ACL_FORMAT_ND);
#endif

// For xtensor mode, we still use the full KV cache approach
kv_caches_.emplace_back(k_tensor, v_tensor);
}
} else {
// Original mode: create torch tensors with optional int8 kv quantization.
torch::ScalarType cache_dtype =
enable_kv_cache_quant ? torch::kInt8 : dtype_;

// Helper function to check if a layer is linear attention
auto is_linear_attention_layer = [&](int64_t layer_idx) {
if (args.full_attention_interval() > 1) {
return (layer_idx + 1) % args.full_attention_interval() != 0;
}
return false;
};

for (int64_t i = 0; i < num_layers; ++i) {
bool is_linear_layer = is_linear_attention_layer(i);
torch::Tensor key_cache, value_cache, index_cache, conv_cache, ssm_cache;
torch::Tensor key_cache_scale, value_cache_scale;

if (is_linear_layer) {
// Linear attention layer: only allocate conv_cache and ssm_cache
#if defined(USE_NPU)
aclFormat npu_format_type =
context_.get_model_args().model_type() == "deepseek_v3" &&
FLAGS_enable_prefix_cache
? ACL_FORMAT_FRACTAL_NZ
: ACL_FORMAT_ND;
key_cache = at_npu::native::npu_format_cast(
torch::empty(kv_cache_shape[0],
torch::dtype(cache_dtype).device(device_)),
npu_format_type);
value_cache = at_npu::native::npu_format_cast(
torch::empty(kv_cache_shape[1],
torch::dtype(cache_dtype).device(device_)),
npu_format_type);
if (enable_lighting_indexer) {
index_cache = at_npu::native::npu_format_cast(
torch::empty(kv_cache_shape[2],
torch::dtype(dtype_).device(device_)),
npu_format_type);
}
if (enable_linear_attention) {
conv_cache = at_npu::native::npu_format_cast(
torch::zeros(kv_cache_shape[2],
torch::dtype(dtype_).device(device_)),
2);
ssm_cache = at_npu::native::npu_format_cast(
torch::zeros(kv_cache_shape[3],
torch::dtype(dtype_).device(device_)),
2);
}
aclFormat npu_format_type = ACL_FORMAT_ND;
if (enable_linear_attention) {
conv_cache = at_npu::native::npu_format_cast(
torch::zeros(kv_cache_shape[2],
torch::dtype(dtype_).device(device_)),
2);
ssm_cache = at_npu::native::npu_format_cast(
torch::zeros(kv_cache_shape[3],
torch::dtype(dtype_).device(device_)),
2);
}
#elif defined(USE_ILU) || defined(USE_MLU) || defined(USE_MUSA)
Copy link
Copy Markdown
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The #elif defined(USE_ILU) || defined(USE_MLU) || defined(USE_MUSA) and #else branches appear to have identical behavior — what's the reason for splitting them?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This follows the original code style below.

key_cache = torch::zeros(kv_cache_shape[0],
torch::dtype(cache_dtype).device(device_));
if (!kv_cache_shape[1].empty()) {
value_cache = torch::zeros(kv_cache_shape[1],
torch::dtype(cache_dtype).device(device_));
}
if (enable_lighting_indexer) {
index_cache = torch::zeros(kv_cache_shape[2],
if (enable_linear_attention) {
conv_cache = torch::zeros(kv_cache_shape[2],
torch::dtype(dtype_).device(device_));
ssm_cache = torch::zeros(kv_cache_shape[3],
torch::dtype(dtype_).device(device_));
}
if (enable_kv_cache_quant) {
std::vector<int64_t> key_scale_shape(kv_cache_shape[0].begin(),
kv_cache_shape[0].end() - 1);
key_cache_scale = torch::zeros(
key_scale_shape, torch::dtype(torch::kFloat32).device(device_));
if (!kv_cache_shape[1].empty()) {
std::vector<int64_t> value_scale_shape(kv_cache_shape[1].begin(),
kv_cache_shape[1].end() - 1);
value_cache_scale = torch::zeros(
value_scale_shape, torch::dtype(torch::kFloat32).device(device_));
}
}
#else
key_cache = torch::empty(kv_cache_shape[0],
torch::dtype(cache_dtype).device(device_));
if (!kv_cache_shape[1].empty()) {
value_cache = torch::empty(kv_cache_shape[1],
torch::dtype(cache_dtype).device(device_));
}
if (enable_lighting_indexer) {
index_cache = torch::empty(kv_cache_shape[2],
if (enable_linear_attention) {
conv_cache = torch::empty(kv_cache_shape[2],
torch::dtype(dtype_).device(device_));
ssm_cache = torch::empty(kv_cache_shape[3],
torch::dtype(dtype_).device(device_));
}
}
#endif
if (enable_kv_cache_quant) {
kv_caches_.emplace_back(key_cache,
value_cache,
index_cache,
key_cache_scale,
value_cache_scale);
} else if (enable_linear_attention) {
kv_caches_.emplace_back(key_cache, value_cache, conv_cache, ssm_cache);
} else if (enable_lighting_indexer) {
kv_caches_.emplace_back(key_cache, value_cache, index_cache);
// Create empty KVCache with only conv and ssm
kv_caches_.emplace_back(
torch::empty({0}, torch::dtype(dtype_).device(device_)),
torch::empty({0}, torch::dtype(dtype_).device(device_)),
conv_cache,
ssm_cache);
} else {
kv_caches_.emplace_back(key_cache, value_cache);
// Full attention layer: allocate key_cache and value_cache only
#if defined(USE_NPU)
aclFormat npu_format_type =
context_.get_model_args().model_type() == "deepseek_v3" &&
FLAGS_enable_prefix_cache
? ACL_FORMAT_FRACTAL_NZ
: ACL_FORMAT_ND;
key_cache = at_npu::native::npu_format_cast(
torch::empty(kv_cache_shape[0],
torch::dtype(cache_dtype).device(device_)),
npu_format_type);
value_cache = at_npu::native::npu_format_cast(
torch::empty(kv_cache_shape[1],
torch::dtype(cache_dtype).device(device_)),
npu_format_type);
if (enable_lighting_indexer) {
index_cache = at_npu::native::npu_format_cast(
torch::empty(kv_cache_shape[2],
torch::dtype(dtype_).device(device_)),
npu_format_type);
}
#elif defined(USE_ILU) || defined(USE_MLU) || defined(USE_MUSA)
key_cache = torch::zeros(kv_cache_shape[0],
torch::dtype(cache_dtype).device(device_));
if (!kv_cache_shape[1].empty()) {
value_cache = torch::zeros(kv_cache_shape[1],
torch::dtype(cache_dtype).device(device_));
}
if (enable_lighting_indexer) {
index_cache = torch::zeros(kv_cache_shape[2],
torch::dtype(dtype_).device(device_));
}
if (enable_kv_cache_quant) {
std::vector<int64_t> key_scale_shape(kv_cache_shape[0].begin(),
kv_cache_shape[0].end() - 1);
key_cache_scale = torch::zeros(
key_scale_shape, torch::dtype(torch::kFloat32).device(device_));
if (!kv_cache_shape[1].empty()) {
std::vector<int64_t> value_scale_shape(kv_cache_shape[1].begin(),
kv_cache_shape[1].end() - 1);
value_cache_scale =
torch::zeros(value_scale_shape,
torch::dtype(torch::kFloat32).device(device_));
}
}
#else
key_cache = torch::empty(kv_cache_shape[0],
torch::dtype(cache_dtype).device(device_));
if (!kv_cache_shape[1].empty()) {
value_cache = torch::empty(kv_cache_shape[1],
torch::dtype(cache_dtype).device(device_));
}
if (enable_lighting_indexer) {
index_cache = torch::empty(kv_cache_shape[2],
torch::dtype(dtype_).device(device_));
}
#endif
if (enable_kv_cache_quant) {
kv_caches_.emplace_back(key_cache,
value_cache,
index_cache,
key_cache_scale,
value_cache_scale);
} else if (enable_lighting_indexer) {
kv_caches_.emplace_back(key_cache, value_cache, index_cache);
} else {
kv_caches_.emplace_back(key_cache, value_cache);
}
}
}
}
Expand Down
Loading