Skip to content
2 changes: 1 addition & 1 deletion third_party/xllm_atb_layers
Submodule xllm_atb_layers updated from 918c03 to cff774
5 changes: 5 additions & 0 deletions xllm/core/common/global_flags.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -644,6 +644,11 @@ DEFINE_string(npu_kernel_backend,
DEFINE_bool(enable_intralayer_addnorm,
false,
"enable fused intralayer addnorm ops.");

DEFINE_bool(enable_fia,
false,
"Whether to enable FIA when initializing the NPU qwen2 decoder "
"layer.");
#endif

// --- multi-step decode config ---
Expand Down
2 changes: 2 additions & 0 deletions xllm/core/common/global_flags.h
Original file line number Diff line number Diff line change
Expand Up @@ -310,6 +310,8 @@ DECLARE_uint32(rec_worker_max_concurrency);
DECLARE_string(npu_kernel_backend);

DECLARE_bool(enable_intralayer_addnorm);

DECLARE_bool(enable_fia);
#endif

DECLARE_int32(health_check_interval_ms);
1 change: 1 addition & 0 deletions xllm/core/common/help_formatter.h
Original file line number Diff line number Diff line change
Expand Up @@ -106,6 +106,7 @@ const OptionCategory kOtherOptions = {
"num_request_handling_threads",
"num_response_handling_threads",
"prefill_scheduling_memory_usage_threshold",
"enable_fia",
"use_contiguous_input_buffer",
"use_audio_in_video"}};

Expand Down
121 changes: 103 additions & 18 deletions xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ limitations under the License.
#include <glog/logging.h>
#include <mstx/ms_tools_ext.h>

#include <algorithm>
#include <map>

#include "common/global_flags.h"
Expand All @@ -30,6 +31,10 @@ limitations under the License.
namespace xllm {
namespace layer {

namespace {
constexpr int64_t kFiaMaskSeqLen = 2048;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Hardcoding kFiaMaskSeqLen to 2048 limits the maximum sequence length supported when FIA is enabled. If the input sequence length exceeds 2048, the fia_attn_mask_ (initialized at line 229) will have incorrect dimensions for the attention operation, potentially leading to incorrect results or crashes. This should be made dynamic or set to a sufficiently large value supported by the model's context window.

}

enum DecoderLayerTensorId : int {
IN_NORM_WEIGHT = 0, // weight
IN_NORM_BIAS = 1, // bias
Expand Down Expand Up @@ -117,13 +122,13 @@ void NpuQwen2DecoderLayerImpl::param_from_args(
static_cast<int>(TransposeType::INVALID),
static_cast<int>(TransposeType::NOT_TRANSPOSE)};
param.kvQuant = false;
param.blockSize = static_cast<int64_t>(FLAGS_block_size);
param.quantGroupSize = 0;
param.rmsNormEps = args.rms_norm_eps();
param.worldSize = parallel_args.world_size();
param.numAttentionHeadsPerRank = args.n_heads() / param.worldSize;
param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads();
param.enableIntraLayerAddNorm = false;
param.enableInterLayerAddNorm = false;
// param.numKeyValueHeadsPerRank = args.n_kv_heads();
std::optional<long int> optionalValue = args.n_kv_heads();
param.numKeyValueHeadsPerRank =
Expand All @@ -134,6 +139,7 @@ void NpuQwen2DecoderLayerImpl::param_from_args(
param.rank = parallel_args.rank();
param.backend = "lccl";
param.enableLogN = false;
param.isFIA = isPrefill && FLAGS_enable_fia;
}

NpuQwen2DecoderLayerImpl::NpuQwen2DecoderLayerImpl(const ModelContext& context)
Expand Down Expand Up @@ -215,10 +221,72 @@ int64_t NpuQwen2DecoderLayerImpl::init_attn_mask() {
torch::Dtype dtype =
prefill_param_.isBF16 ? torch::kBFloat16 : torch::kFloat16;
decode_attn_mask_ = torch::zeros({1}).to(device_).to(dtype);
if (FLAGS_enable_fia) {
const auto fia_mask_options =
torch::TensorOptions().dtype(torch::kBool).device(device_);
fia_attn_mask_ = torch::triu(torch::ones({kFiaMaskSeqLen, kFiaMaskSeqLen},
fia_mask_options),
/*diagonal=*/1)
.contiguous();
}

return atb::NO_ERROR;
}

// Builds the int32 index tensors used by the FIA prefill path.
//
// Hidden states arrive packed: each sequence's tokens are laid out
// back-to-back, `total_tokens` in all. This routine computes a mapping
// between that packed layout and a dense [num_sequences, target_seq_len]
// layout, where target_seq_len is the longest q_seq_len in the batch:
//   * fia_padding_idx_   — for every dense slot, the packed-token index to
//     read; slots past a sequence's end point back at that sequence's
//     first token so a gather never reads out of range.
//   * fia_unpadding_idx_ — for every packed token, its slot in the dense
//     layout (in packed order).
// NOTE(review): presumably consumed as gather/scatter indices by the FIA
// attention kernel — confirm against the kernel's expected input layout.
void NpuQwen2DecoderLayerImpl::build_fia_index_tensors(
    const ModelInputParams& input_params,
    int64_t total_tokens) {
  CHECK_GT(input_params.num_sequences, 0)
      << "FIA prefill requires num_sequences > 0";
  const auto& seq_lens = input_params.q_seq_lens_vec;
  CHECK_EQ(seq_lens.size(), static_cast<size_t>(input_params.num_sequences))
      << "q_seq_lens_vec size mismatch for FIA prefill";
  CHECK_GT(total_tokens, 0) << "FIA prefill requires non-empty hidden states";

  // Dense row width: the longest query length in the batch.
  const int32_t target_seq_len =
      *std::max_element(seq_lens.begin(), seq_lens.end());
  CHECK_GT(target_seq_len, 0) << "FIA prefill requires target_seq_len > 0";

  std::vector<int32_t> padding_idx(input_params.num_sequences *
                                   target_seq_len);
  std::vector<int32_t> unpadding_idx;
  unpadding_idx.reserve(total_tokens);

  int64_t token_cursor = 0;  // next unread token in the packed layout
  for (int32_t seq_idx = 0; seq_idx < input_params.num_sequences; ++seq_idx) {
    const int32_t seq_len = seq_lens[seq_idx];
    CHECK_GT(seq_len, 0) << "FIA sequence length must be positive";
    CHECK_LE(seq_len, target_seq_len)
        << "FIA sequence length must be <= target_seq_len";
    CHECK_LE(token_cursor + seq_len, total_tokens)
        << "FIA sequence lengths exceed current hidden-state tokens";

    const int32_t row_base = seq_idx * target_seq_len;
    const int32_t first_token = static_cast<int32_t>(token_cursor);
    // Real tokens: dense slot -> packed token, and remember the slot.
    for (int32_t pos = 0; pos < seq_len; ++pos) {
      padding_idx[row_base + pos] = first_token + pos;
      unpadding_idx.push_back(row_base + pos);
    }
    // Tail padding: repeat the sequence's first token as a safe filler.
    for (int32_t pos = seq_len; pos < target_seq_len; ++pos) {
      padding_idx[row_base + pos] = first_token;
    }
    token_cursor += seq_len;
  }

  CHECK_EQ(token_cursor, total_tokens)
      << "FIA padding indices do not cover all hidden-state tokens";
  CHECK_EQ(unpadding_idx.size(), static_cast<size_t>(total_tokens))
      << "FIA unpadding indices size mismatch";

  // Materialize on host as int32, then move to the NPU device.
  const auto index_options = torch::TensorOptions().dtype(torch::kInt32);
  fia_padding_idx_ =
      torch::tensor(padding_idx, index_options).to(device_).contiguous();
  fia_unpadding_idx_ =
      torch::tensor(unpadding_idx, index_options).to(device_).contiguous();
}

int64_t NpuQwen2DecoderLayerImpl::init_node(
atb_speed::Model::Node& node,
atb_speed::qwen::DecoderLayerParam& param) {
Expand Down Expand Up @@ -300,34 +368,51 @@ void NpuQwen2DecoderLayerImpl::build_node_variant_pack(
KVCache& kv_cache,
ModelInputParams& input_params,
bool is_prefill) {
if (is_prefill) {
*prefill_param_.bs = std::max(1, input_params.num_sequences);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The pointer prefill_param_.bs is dereferenced here without being initialized to point to valid memory. This will cause a null pointer dereference or memory corruption at runtime. Ensure that bs is properly allocated (e.g., in the constructor) before this assignment.

}

auto* effective_attn_mask = &attn_mask;
if (is_prefill && FLAGS_enable_fia) {
effective_attn_mask = &fia_attn_mask_;
// build_fia_index_tensors(input_params, x.size(0)); // bsnd
}

internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_;
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) =
size_t input_offset = WEIGHT_COUNT_PER_LAYER;
node.variantPack.inTensors.at(input_offset++) = internal_tensors_;
node.variantPack.inTensors.at(input_offset++) =
atb_speed::Utils::AtTensor2Tensor(cos_pos);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) =
node.variantPack.inTensors.at(input_offset++) =
atb_speed::Utils::AtTensor2Tensor(sin_pos);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) =
atb_speed::Utils::AtTensor2Tensor(attn_mask);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 4) =
node.variantPack.inTensors.at(input_offset++) =
atb_speed::Utils::AtTensor2Tensor(*effective_attn_mask);
node.variantPack.inTensors.at(input_offset++) =
atb_speed::Utils::AtTensor2Tensor(kv_cache.get_k_cache());
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 5) =
node.variantPack.inTensors.at(input_offset++) =
atb_speed::Utils::AtTensor2Tensor(kv_cache.get_v_cache());
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6) =
node.variantPack.inTensors.at(input_offset) =
atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6).hostData =
node.variantPack.inTensors.at(input_offset++).hostData =
input_params.kv_seq_lens_vec.data();
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7) = placeholder_;
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7).hostData =
node.variantPack.inTensors.at(input_offset) = placeholder_;
node.variantPack.inTensors.at(input_offset++).hostData =
placeholder_vec_.data();
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 8) = placeholder_;
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 9) =
node.variantPack.inTensors.at(input_offset++) = placeholder_;
node.variantPack.inTensors.at(input_offset++) =
atb_speed::Utils::AtTensor2Tensor(input_params.block_tables);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 10) =
node.variantPack.inTensors.at(input_offset++) =
atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots);
if (is_prefill && FLAGS_enable_chunked_prefill) {
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11) =
if (is_prefill && FLAGS_enable_fia) {
node.variantPack.inTensors.at(input_offset++) = // bsnd padding_idx
placeholder_;
node.variantPack.inTensors.at(input_offset++) = // bsnd unpadding_idx
placeholder_;
}
Comment on lines +406 to +411
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

The FIA index tensors are not being correctly passed to the variant pack; placeholder_ is used instead of the actual index tensors fia_padding_idx_ and fia_unpadding_idx_. Additionally, the call to build_fia_index_tensors is commented out at line 381. This implementation is incomplete and will cause the FIA feature to fail. Note that calling build_fia_index_tensors inside build_node_variant_pack may have performance implications due to repeated device allocations and H2D transfers; consider caching these tensors if the sequence lengths haven't changed.

  if (is_prefill && FLAGS_enable_fia) {
    build_fia_index_tensors(input_params, x.size(0));
    node.variantPack.inTensors.at(input_offset++) = 
        atb_speed::Utils::AtTensor2Tensor(fia_padding_idx_);
    node.variantPack.inTensors.at(input_offset++) = 
        atb_speed::Utils::AtTensor2Tensor(fia_unpadding_idx_);
  }

if (is_prefill && (FLAGS_enable_fia || FLAGS_enable_chunked_prefill)) {
node.variantPack.inTensors.at(input_offset) =
atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11).hostData =
node.variantPack.inTensors.at(input_offset).hostData =
input_params.q_seq_lens_vec.data();
}

Expand Down
6 changes: 6 additions & 0 deletions xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,9 @@ class NpuQwen2DecoderLayerImpl : public BaseLayer {
private:
void initialize_quantization_parameters();

void build_fia_index_tensors(const ModelInputParams& input_params,
int64_t total_tokens);

void build_node_variant_pack(atb_speed::Model::Node& node,
torch::Tensor& x,
torch::Tensor& cos_pos,
Expand Down Expand Up @@ -96,6 +99,9 @@ class NpuQwen2DecoderLayerImpl : public BaseLayer {
atb::Tensor placeholder_;

at::Tensor decode_attn_mask_;
at::Tensor fia_attn_mask_;
at::Tensor fia_padding_idx_;
at::Tensor fia_unpadding_idx_;

at::Tensor at_placeholder_;

Expand Down
30 changes: 27 additions & 3 deletions xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ limitations under the License.
#include <glog/logging.h>
#include <mstx/ms_tools_ext.h>

#include <algorithm>
#include <map>

#include "common/global_flags.h"
Expand All @@ -30,6 +31,10 @@ limitations under the License.
namespace xllm {
namespace layer {

namespace {
constexpr int64_t kFiaMaskSeqLen = 2048;
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

high

Hardcoding kFiaMaskSeqLen to 2048 will cause failures for sequences longer than this limit when FIA is enabled, as the attention mask will have incorrect dimensions.

}

const uint64_t WEIGHT_COUNT_PER_LAYER = 56;

void NpuQwen3DecoderLayerImpl::param_from_args(
Expand All @@ -48,6 +53,8 @@ void NpuQwen3DecoderLayerImpl::param_from_args(
param.rmsnormQKNorm = true;
param.isPrefill = isPrefill;
param.isBF16 = args.dtype() == "bfloat16";
param.isFIA = isPrefill && FLAGS_enable_fia;
param.blockSize = static_cast<int64_t>(FLAGS_block_size);
param.enableSplitFuse = FLAGS_enable_chunked_prefill && isPrefill;
param.loraEnableGMM = false;
param.enableXattention = is_rec_multi_round_mode();
Expand Down Expand Up @@ -197,6 +204,14 @@ int64_t NpuQwen3DecoderLayerImpl::init_attn_mask() {
torch::Dtype dtype =
prefill_param_.isBF16 ? torch::kBFloat16 : torch::kFloat16;
decode_attn_mask_ = torch::zeros({1}).to(device_).to(dtype);
if (prefill_param_.isFIA) {
const auto fia_mask_options =
torch::TensorOptions().dtype(torch::kBool).device(device_);
fia_attn_mask_ = torch::triu(torch::ones({kFiaMaskSeqLen, kFiaMaskSeqLen},
fia_mask_options),
/*diagonal=*/1)
.contiguous();
}

return atb::NO_ERROR;
}
Expand Down Expand Up @@ -279,14 +294,23 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack(
ModelInputParams& input_params,
bool is_prefill,
int node_id) {
if (is_prefill) {
*prefill_param_.bs = std::max(1, input_params.num_sequences);
Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

critical

The pointer prefill_param_.bs is dereferenced without initialization. This is a critical issue that will lead to a crash at runtime.

}

auto* effective_attn_mask = &attn_mask;
if (is_prefill && prefill_param_.isFIA) {
effective_attn_mask = &fia_attn_mask_;
}

internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_;
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) =
atb_speed::Utils::AtTensor2Tensor(cos_pos);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) =
atb_speed::Utils::AtTensor2Tensor(sin_pos);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) =
atb_speed::Utils::AtTensor2Tensor(attn_mask);
atb_speed::Utils::AtTensor2Tensor(*effective_attn_mask);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6) =
atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens);
node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6).hostData =
Expand Down Expand Up @@ -334,8 +358,8 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack(
atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots);
}

if (is_prefill &&
(FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) {
if (is_prefill && (prefill_param_.isFIA || FLAGS_enable_chunked_prefill ||
FLAGS_enable_prefix_cache)) {
node.variantPack.inTensors.at(input_idx++) =
atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens);
node.variantPack.inTensors.at(input_idx - 1).hostData =
Expand Down
1 change: 1 addition & 0 deletions xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,6 +98,7 @@ class NpuQwen3DecoderLayerImpl : public BaseLayer {
atb::Tensor placeholder_;

at::Tensor decode_attn_mask_;
at::Tensor fia_attn_mask_;

at::Tensor at_placeholder_;

Expand Down
Loading
Loading