Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
66 changes: 58 additions & 8 deletions xllm/core/distributed_runtime/llm_engine.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@ limitations under the License.
#include "common/metrics.h"
#include "common/options.h"
#include "framework/block/hierarchy_block_manager_pool.h"
#if defined(USE_NPU)
#include "framework/kv_cache/hybrid_cache_utils.h"
#endif
#include "framework/model/model_args.h"
#include "framework/model_loader.h"
#include "framework/xtensor/page_allocator.h"
Expand Down Expand Up @@ -520,10 +523,39 @@ Engine::KVCacheCapacity LLMEngine::estimate_kv_cache_capacity() {

// compute kv cache n_blocks
const int32_t block_size = options_.block_size();
#if defined(USE_NPU)
const bool enable_lighting_indexer = args_.index_n_heads() > 1;
const bool enable_gdn_attention = has_linear_attention_layers(args_);
CHECK(!(enable_lighting_indexer && enable_gdn_attention))
<< "KVCache does not support linear attention and lighting indexer "
"enabled at the same time.";
const bool optimize_hybrid_linear_cache =
should_enable_hybrid_linear_cache(args_,
options_.enable_disagg_pd(),
options_.enable_kvcache_store(),
options_.host_blocks_factor());
int64_t num_full_attention_layers =
count_hybrid_full_attention_layers(args_, optimize_hybrid_linear_cache);
int64_t num_linear_attention_layers = 0;
if (linear_slot_size > 0) {
num_linear_attention_layers = count_hybrid_linear_attention_layers(
args_, optimize_hybrid_linear_cache);
}
const int64_t full_cache_block_size_in_bytes =
block_size * (slot_size + index_slot_size + scale_slot_size);
const int64_t total_cache_block_size_in_bytes =
num_full_attention_layers * full_cache_block_size_in_bytes +
num_linear_attention_layers * linear_slot_size;
CHECK_GT(total_cache_block_size_in_bytes, 0)
<< "invalid cache block size estimate";
kv_cache_cap.n_blocks =
kv_cache_cap.cache_size_in_bytes / total_cache_block_size_in_bytes;
#else
const int64_t block_size_in_bytes =
block_size * (slot_size + index_slot_size + scale_slot_size);
kv_cache_cap.n_blocks = kv_cache_cap.cache_size_in_bytes /
(kv_cache_cap.n_layers * block_size_in_bytes);
#endif
CHECK_GT(kv_cache_cap.n_blocks, 0) << "no n_blocks for kv cache";
return kv_cache_cap;
}
Expand All @@ -540,6 +572,11 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
const int32_t block_size = options_.block_size();
bool enable_lighting_indexer = args_.index_n_heads() > 1;
bool enable_gdn_attention = has_linear_attention_layers(args_);
#if defined(USE_NPU)
CHECK(!(enable_lighting_indexer && enable_gdn_attention))
<< "KVCache does not support linear attention and lighting indexer "
"enabled at the same time.";
#endif
Comment on lines 573 to +579
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The conflict check between enable_lighting_indexer and enable_gdn_attention is duplicated: the same CHECK also appears in estimate_kv_cache_capacity (lines 527-531). Keeping two copies is a maintenance risk, because a future change can update one instance and miss the other.

Consider refactoring this logic to avoid duplication. For example, you could calculate these boolean flags once in init_model() and store them as member variables of LLMEngine. The check could then be performed once, and both estimate_kv_cache_capacity and allocate_kv_cache can use these member variables.


// init kv cache for each worker
std::vector<std::vector<int64_t>> kv_cache_shape;
Expand Down Expand Up @@ -609,18 +646,31 @@ bool LLMEngine::allocate_kv_cache(const Engine::KVCacheCapacity& kv_cache_cap) {
std::swap(shape[1], shape[2]);
}
#endif
LOG(INFO) << "Initializing k cache with shape: [" << kv_cache_shape[0] << "]";
LOG(INFO) << "Initializing v cache with shape: [" << kv_cache_shape[1] << "]";
// kv_cache_shape layout: [k, v, optional indexer, optional conv, optional
// ssm]
constexpr size_t kKeyCacheShapeIdx = 0;
constexpr size_t kValueCacheShapeIdx = 1;
constexpr size_t kFirstOptionalCacheShapeIdx = 2;
const size_t index_cache_shape_idx = kFirstOptionalCacheShapeIdx;
const size_t conv_cache_shape_idx =
kFirstOptionalCacheShapeIdx +
static_cast<size_t>(enable_lighting_indexer);
const size_t ssm_cache_shape_idx = conv_cache_shape_idx + 1;

LOG(INFO) << "Initializing k cache with shape: ["
<< kv_cache_shape[kKeyCacheShapeIdx] << "]";
LOG(INFO) << "Initializing v cache with shape: ["
<< kv_cache_shape[kValueCacheShapeIdx] << "]";
if (enable_lighting_indexer) {
LOG(INFO) << "Initializing indexer cache with shape: [" << kv_cache_shape[2]
<< "]";
LOG(INFO) << "Initializing indexer cache with shape: ["
<< kv_cache_shape[index_cache_shape_idx] << "]";
}
if (enable_gdn_attention) {
LOG(INFO) << "GND Attention is enabled";
LOG(INFO) << "Initializing conv cache with shape: [" << kv_cache_shape[2]
<< "]";
LOG(INFO) << "Initializing ssm cache with shape: [" << kv_cache_shape[3]
<< "]";
LOG(INFO) << "Initializing conv cache with shape: ["
<< kv_cache_shape[conv_cache_shape_idx] << "]";
LOG(INFO) << "Initializing ssm cache with shape: ["
<< kv_cache_shape[ssm_cache_shape_idx] << "]";
}

// initialize block manager
Expand Down
26 changes: 26 additions & 0 deletions xllm/core/framework/batch/batch_input_builder.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@ limitations under the License.
#include "common/global_flags.h"
#include "common/metrics.h"
#include "framework/batch/mposition.h"
#include "framework/kv_cache/hybrid_cache_utils.h"
#include "framework/model/model_args.h"
#include "framework/model/model_input_params.h"
#include "framework/request/sequence.h"
Expand Down Expand Up @@ -73,9 +74,15 @@ BatchInputBuilder::BatchInputBuilder(
state_.flatten_positions_vec.reserve(1000);
state_.mrope_positions_vec.reserve(sequences.size());
state_.block_tables_vec.reserve(sequences.size());
state_.linear_block_tables_vec.reserve(sequences.size());
state_.acc_logprob_vec.reserve(sequences.size());
if (args_ != nullptr) {
use_mrope_ = (args_->rope_scaling_rope_type() == "mrope");
enable_hybrid_linear_cache_ =
should_enable_hybrid_linear_cache(*args_,
FLAGS_enable_disagg_pd,
FLAGS_enable_kvcache_store,
FLAGS_host_blocks_factor);
}
write_block_ids_.clear();
state_.batch_forward_type = batch_forward_type;
Expand Down Expand Up @@ -179,6 +186,9 @@ void BatchInputBuilder::process_sequences_multithreaded() {
state_.block_tables_vec.insert(state_.block_tables_vec.end(),
state.block_tables_vec.begin(),
state.block_tables_vec.end());
state_.linear_block_tables_vec.insert(state_.linear_block_tables_vec.end(),
state.linear_block_tables_vec.begin(),
state.linear_block_tables_vec.end());
state_.acc_logprob_vec.insert(state_.acc_logprob_vec.end(),
state.acc_logprob_vec.begin(),
state.acc_logprob_vec.end());
Expand Down Expand Up @@ -491,6 +501,15 @@ void BatchInputBuilder::setup_kv_cache_info(
}

state.block_tables_vec.emplace_back(std::move(block_ids));
const auto& full_block_table = state.block_tables_vec.back();
if (args_ != nullptr && should_use_compact_linear_block_table(
*args_, enable_hybrid_linear_cache_)) {
state.linear_block_tables_vec.emplace_back(
build_hybrid_linear_state_block_table(
full_block_table, n_kv_cache_tokens, seq_len, block_size));
} else {
state.linear_block_tables_vec.push_back(full_block_table);
}
}

void BatchInputBuilder::padding_decode_batch_size(
Expand Down Expand Up @@ -525,6 +544,7 @@ void BatchInputBuilder::padding_decode_batch_size(
num_decoding_tokens);
#endif
state_.block_tables_vec.emplace_back();
state_.linear_block_tables_vec.emplace_back();
state_.paged_kv_indices.push_back(0);
state_.paged_kv_indptr.push_back(state_.paged_kv_indptr.back() + 1);
state_.paged_kv_last_page_len.push_back(1);
Expand Down Expand Up @@ -580,6 +600,9 @@ ForwardInput BatchInputBuilder::state_to_forward_input() {
util::pad_2d_vector(state_.block_tables_vec, /*pad_value=*/0);
input_params.block_tables =
create_2d_tensor(state_.block_tables_vec, torch::kInt);
util::pad_2d_vector(state_.linear_block_tables_vec, /*pad_value=*/0);
input_params.linear_block_tables =
create_2d_tensor(state_.linear_block_tables_vec, torch::kInt);

if (input_embeddings_vec_.size() != 0) {
input_params.input_embedding = torch::cat(input_embeddings_vec_);
Expand Down Expand Up @@ -655,6 +678,9 @@ RawForwardInput BatchInputBuilder::state_to_raw_forward_input() {
raw_forward_input.new_token_slot_ids = std::move(state_.new_token_slot_ids);
util::pad_2d_vector(state_.block_tables_vec, /*pad_value=*/0);
raw_forward_input.block_tables_vec = std::move(state_.block_tables_vec);
util::pad_2d_vector(state_.linear_block_tables_vec, /*pad_value=*/0);
raw_forward_input.linear_block_tables_vec =
std::move(state_.linear_block_tables_vec);
raw_forward_input.num_sequences = num_sequences_;
// raw_forward_input.dp_global_token_nums = ;
raw_forward_input.transfer_kv_infos = std::move(state_.transfer_kv_infos);
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/framework/batch/batch_input_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -96,6 +96,7 @@ class BatchInputBuilder {
// Cache and block data
std::vector<int32_t> new_token_slot_ids;
std::vector<std::vector<int32_t>> block_tables_vec;
std::vector<std::vector<int32_t>> linear_block_tables_vec;

// beam search kernel input
std::vector<float> acc_logprob_vec;
Expand Down Expand Up @@ -148,6 +149,7 @@ class BatchInputBuilder {

// Configuration
bool use_mrope_ = false;
bool enable_hybrid_linear_cache_ = false;
uint32_t num_sequences_ = 0;
bool need_unique_tokens_ = true;
int32_t cp_size_ = 1;
Expand Down
Loading
Loading