diff --git a/third_party/xllm_atb_layers b/third_party/xllm_atb_layers index 918c03d2a..cff7744d2 160000 --- a/third_party/xllm_atb_layers +++ b/third_party/xllm_atb_layers @@ -1 +1 @@ -Subproject commit 918c03d2abc4c9996196a797aefe743863b7e0ae +Subproject commit cff7744d28d3375ff9cda35330a12caac0b2af3c diff --git a/xllm/core/common/global_flags.cpp b/xllm/core/common/global_flags.cpp index 1681eb142..5564e0425 100644 --- a/xllm/core/common/global_flags.cpp +++ b/xllm/core/common/global_flags.cpp @@ -644,6 +644,11 @@ DEFINE_string(npu_kernel_backend, DEFINE_bool(enable_intralayer_addnorm, false, "enable fused intralayer addnorm ops."); + +DEFINE_bool(enable_fia, + false, + "Whether to enable FIA when initializing the NPU qwen2 decoder " + "layer."); #endif // --- multi-step decode config --- diff --git a/xllm/core/common/global_flags.h b/xllm/core/common/global_flags.h index 67a27afa9..605f18af2 100644 --- a/xllm/core/common/global_flags.h +++ b/xllm/core/common/global_flags.h @@ -310,6 +310,8 @@ DECLARE_uint32(rec_worker_max_concurrency); DECLARE_string(npu_kernel_backend); DECLARE_bool(enable_intralayer_addnorm); + +DECLARE_bool(enable_fia); #endif DECLARE_int32(health_check_interval_ms); diff --git a/xllm/core/common/help_formatter.h b/xllm/core/common/help_formatter.h index dd5364a16..5acabb076 100644 --- a/xllm/core/common/help_formatter.h +++ b/xllm/core/common/help_formatter.h @@ -106,6 +106,7 @@ const OptionCategory kOtherOptions = { "num_request_handling_threads", "num_response_handling_threads", "prefill_scheduling_memory_usage_threshold", + "enable_fia", "use_contiguous_input_buffer", "use_audio_in_video"}}; diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp index 9caf2d676..d6036206b 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.cpp @@ -19,6 +19,7 @@ limitations under the License. #include #include +#include #include #include "common/global_flags.h" @@ -30,6 +31,10 @@ limitations under the License. namespace xllm { namespace layer { +namespace { +constexpr int64_t kFiaMaskSeqLen = 2048; +} + enum DecoderLayerTensorId : int { IN_NORM_WEIGHT = 0, // weight IN_NORM_BIAS = 1, // bias @@ -117,13 +122,13 @@ void NpuQwen2DecoderLayerImpl::param_from_args( static_cast(TransposeType::INVALID), static_cast(TransposeType::NOT_TRANSPOSE)}; param.kvQuant = false; + param.blockSize = static_cast(FLAGS_block_size); param.quantGroupSize = 0; param.rmsNormEps = args.rms_norm_eps(); param.worldSize = parallel_args.world_size(); param.numAttentionHeadsPerRank = args.n_heads() / param.worldSize; param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads(); param.enableIntraLayerAddNorm = false; - param.enableInterLayerAddNorm = false; // param.numKeyValueHeadsPerRank = args.n_kv_heads(); std::optional optionalValue = args.n_kv_heads(); param.numKeyValueHeadsPerRank = @@ -134,6 +139,7 @@ void NpuQwen2DecoderLayerImpl::param_from_args( param.rank = parallel_args.rank(); param.backend = "lccl"; param.enableLogN = false; + param.isFIA = isPrefill && FLAGS_enable_fia; } NpuQwen2DecoderLayerImpl::NpuQwen2DecoderLayerImpl(const ModelContext& context) @@ -215,10 +221,72 @@ int64_t NpuQwen2DecoderLayerImpl::init_attn_mask() { torch::Dtype dtype = prefill_param_.isBF16 ? torch::kBFloat16 : torch::kFloat16; decode_attn_mask_ = torch::zeros({1}).to(device_).to(dtype); + if (FLAGS_enable_fia) { + const auto fia_mask_options = + torch::TensorOptions().dtype(torch::kBool).device(device_); + fia_attn_mask_ = torch::triu(torch::ones({kFiaMaskSeqLen, kFiaMaskSeqLen}, + fia_mask_options), + /*diagonal=*/1) + .contiguous(); + } return atb::NO_ERROR; } +void NpuQwen2DecoderLayerImpl::build_fia_index_tensors( + const ModelInputParams& input_params, + int64_t total_tokens) { + CHECK_GT(input_params.num_sequences, 0) + << "FIA prefill requires num_sequences > 0"; + CHECK_EQ(input_params.q_seq_lens_vec.size(), + static_cast(input_params.num_sequences)) + << "q_seq_lens_vec size mismatch for FIA prefill"; + CHECK_GT(total_tokens, 0) << "FIA prefill requires non-empty hidden states"; + + const int32_t target_seq_len = *std::max_element( + input_params.q_seq_lens_vec.begin(), input_params.q_seq_lens_vec.end()); + CHECK_GT(target_seq_len, 0) << "FIA prefill requires target_seq_len > 0"; + + std::vector padding_idx(input_params.num_sequences * target_seq_len); + std::vector unpadding_idx; + unpadding_idx.reserve(total_tokens); + + int64_t source_token_offset = 0; + for (int32_t seq_idx = 0; seq_idx < input_params.num_sequences; ++seq_idx) { + const int32_t seq_len = input_params.q_seq_lens_vec[seq_idx]; + CHECK_GT(seq_len, 0) << "FIA sequence length must be positive"; + CHECK_LE(seq_len, target_seq_len) + << "FIA sequence length must be <= target_seq_len"; + CHECK_LE(source_token_offset + seq_len, total_tokens) + << "FIA sequence lengths exceed current hidden-state tokens"; + + const int32_t padded_offset = seq_idx * target_seq_len; + const int32_t fallback_idx = static_cast(source_token_offset); + for (int32_t pos = 0; pos < target_seq_len; ++pos) { + const int32_t padded_idx = padded_offset + pos; + if (pos < seq_len) { + padding_idx[padded_idx] = + static_cast(source_token_offset + pos); + unpadding_idx.push_back(padded_idx); + } else { + padding_idx[padded_idx] = fallback_idx; + } + } + source_token_offset += seq_len; + } + + CHECK_EQ(source_token_offset, total_tokens) + << "FIA padding indices do not cover all hidden-state tokens"; + CHECK_EQ(unpadding_idx.size(), static_cast(total_tokens)) + << "FIA unpadding indices size mismatch"; + + const auto index_options = torch::TensorOptions().dtype(torch::kInt32); + fia_padding_idx_ = + torch::tensor(padding_idx, index_options).to(device_).contiguous(); + fia_unpadding_idx_ = + torch::tensor(unpadding_idx, index_options).to(device_).contiguous(); +} + int64_t NpuQwen2DecoderLayerImpl::init_node( atb_speed::Model::Node& node, atb_speed::qwen::DecoderLayerParam& param) { @@ -300,34 +368,51 @@ void NpuQwen2DecoderLayerImpl::build_node_variant_pack( KVCache& kv_cache, ModelInputParams& input_params, bool is_prefill) { + if (is_prefill) { + *prefill_param_.bs = std::max(1, input_params.num_sequences); + } + + auto* effective_attn_mask = &attn_mask; + if (is_prefill && FLAGS_enable_fia) { + effective_attn_mask = &fia_attn_mask_; + // build_fia_index_tensors(input_params, x.size(0)); // bsnd + } + internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_; - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) = + size_t input_offset = WEIGHT_COUNT_PER_LAYER; + node.variantPack.inTensors.at(input_offset++) = internal_tensors_; + node.variantPack.inTensors.at(input_offset++) = atb_speed::Utils::AtTensor2Tensor(cos_pos); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) = + node.variantPack.inTensors.at(input_offset++) = atb_speed::Utils::AtTensor2Tensor(sin_pos); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) = - atb_speed::Utils::AtTensor2Tensor(attn_mask); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 4) = + node.variantPack.inTensors.at(input_offset++) = + atb_speed::Utils::AtTensor2Tensor(*effective_attn_mask); + node.variantPack.inTensors.at(input_offset++) = atb_speed::Utils::AtTensor2Tensor(kv_cache.get_k_cache()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 5) = + node.variantPack.inTensors.at(input_offset++) = atb_speed::Utils::AtTensor2Tensor(kv_cache.get_v_cache()); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6) = + node.variantPack.inTensors.at(input_offset) = atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6).hostData = + node.variantPack.inTensors.at(input_offset++).hostData = input_params.kv_seq_lens_vec.data(); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7) = placeholder_; - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7).hostData = + node.variantPack.inTensors.at(input_offset) = placeholder_; + node.variantPack.inTensors.at(input_offset++).hostData = placeholder_vec_.data(); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 8) = placeholder_; - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 9) = + node.variantPack.inTensors.at(input_offset++) = placeholder_; + node.variantPack.inTensors.at(input_offset++) = atb_speed::Utils::AtTensor2Tensor(input_params.block_tables); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 10) = + node.variantPack.inTensors.at(input_offset++) = atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots); - if (is_prefill && FLAGS_enable_chunked_prefill) { - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11) = + if (is_prefill && FLAGS_enable_fia) { + node.variantPack.inTensors.at(input_offset++) = // bsnd padding_idx + placeholder_; + node.variantPack.inTensors.at(input_offset++) = // bsnd unpadding_idx + placeholder_; + } + if (is_prefill && (FLAGS_enable_fia || FLAGS_enable_chunked_prefill)) { + node.variantPack.inTensors.at(input_offset) = atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens); - node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11).hostData = + node.variantPack.inTensors.at(input_offset).hostData = input_params.q_seq_lens_vec.data(); } diff --git a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h index 052293407..0f2cdbb58 100644 --- a/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen2_decoder_layer_impl.h @@ -68,6 +68,9 @@ class NpuQwen2DecoderLayerImpl : public BaseLayer { private: void initialize_quantization_parameters(); + void build_fia_index_tensors(const ModelInputParams& input_params, + int64_t total_tokens); + void build_node_variant_pack(atb_speed::Model::Node& node, torch::Tensor& x, torch::Tensor& cos_pos, @@ -96,6 +99,9 @@ class NpuQwen2DecoderLayerImpl : public BaseLayer { atb::Tensor placeholder_; at::Tensor decode_attn_mask_; + at::Tensor fia_attn_mask_; + at::Tensor fia_padding_idx_; + at::Tensor fia_unpadding_idx_; at::Tensor at_placeholder_; diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp index adc146b08..6f8c8c26a 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.cpp @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include #include "common/global_flags.h" @@ -30,6 +31,10 @@ limitations under the License. namespace xllm { namespace layer { +namespace { +constexpr int64_t kFiaMaskSeqLen = 2048; +} + const uint64_t WEIGHT_COUNT_PER_LAYER = 56; void NpuQwen3DecoderLayerImpl::param_from_args( @@ -48,6 +53,8 @@ void NpuQwen3DecoderLayerImpl::param_from_args( param.rmsnormQKNorm = true; param.isPrefill = isPrefill; param.isBF16 = args.dtype() == "bfloat16"; + param.isFIA = isPrefill && FLAGS_enable_fia; + param.blockSize = static_cast(FLAGS_block_size); param.enableSplitFuse = FLAGS_enable_chunked_prefill && isPrefill; param.loraEnableGMM = false; param.enableXattention = is_rec_multi_round_mode(); @@ -197,6 +204,14 @@ int64_t NpuQwen3DecoderLayerImpl::init_attn_mask() { torch::Dtype dtype = prefill_param_.isBF16 ? torch::kBFloat16 : torch::kFloat16; decode_attn_mask_ = torch::zeros({1}).to(device_).to(dtype); + if (prefill_param_.isFIA) { + const auto fia_mask_options = + torch::TensorOptions().dtype(torch::kBool).device(device_); + fia_attn_mask_ = torch::triu(torch::ones({kFiaMaskSeqLen, kFiaMaskSeqLen}, + fia_mask_options), + /*diagonal=*/1) + .contiguous(); + } return atb::NO_ERROR; } @@ -279,6 +294,15 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack( ModelInputParams& input_params, bool is_prefill, int node_id) { + if (is_prefill) { + *prefill_param_.bs = std::max(1, input_params.num_sequences); + } + + auto* effective_attn_mask = &attn_mask; + if (is_prefill && prefill_param_.isFIA) { + effective_attn_mask = &fia_attn_mask_; + } + internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_; node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) = @@ -286,7 +310,7 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack( node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) = atb_speed::Utils::AtTensor2Tensor(sin_pos); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) = - atb_speed::Utils::AtTensor2Tensor(attn_mask); + atb_speed::Utils::AtTensor2Tensor(*effective_attn_mask); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6) = atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6).hostData = @@ -334,8 +358,8 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack( atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots); } - if (is_prefill && - (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) { + if (is_prefill && (prefill_param_.isFIA || FLAGS_enable_chunked_prefill || + FLAGS_enable_prefix_cache)) { node.variantPack.inTensors.at(input_idx++) = atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens); node.variantPack.inTensors.at(input_idx - 1).hostData = diff --git a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h index 98609cff0..fb8c1be8d 100644 --- a/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_decoder_layer_impl.h @@ -98,6 +98,7 @@ class NpuQwen3DecoderLayerImpl : public BaseLayer { atb::Tensor placeholder_; at::Tensor decode_attn_mask_; + at::Tensor fia_attn_mask_; at::Tensor at_placeholder_; diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp index 0d58d2ca8..1ae3d01cb 100644 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.cpp @@ -17,6 +17,7 @@ limitations under the License. #include +#include #include #include "common/global_flags.h" @@ -24,6 +25,10 @@ limitations under the License. namespace xllm { namespace layer { +namespace { +constexpr int64_t kFiaMaskSeqLen = 2048; +} + static const uint64_t WEIGHT_COUNT_PER_LAYER = 55; NpuQwen3MoeDecoderLayerImpl::NpuQwen3MoeDecoderLayerImpl( @@ -102,6 +107,8 @@ void NpuQwen3MoeDecoderLayerImpl::initialize_basic_parameters( param.isBF16 = args.dtype() == "bfloat16"; param.enableSwiGLU = true; param.isPrefill = is_prefill; + param.isFIA = is_prefill && FLAGS_enable_fia; + param.blockSize = static_cast(FLAGS_block_size); // prefill only feature param.enableLcoc = is_prefill; // false; @@ -271,6 +278,7 @@ void NpuQwen3MoeDecoderLayerImpl::merge_loaded_weights() { } int64_t NpuQwen3MoeDecoderLayerImpl::init_layer() { + CHECK_OPERATION_STATUS_RETURN(init_attn_mask()); name_ = "qwen3_moe_decoder_layer " + std::to_string(layer_id_); model_name_ = "Qwen3_Moe"; CHECK_OPERATION_STATUS_RETURN(init_node(prefill_node_, prefill_param_)); @@ -279,6 +287,19 @@ int64_t NpuQwen3MoeDecoderLayerImpl::init_layer() { return atb::NO_ERROR; } +int64_t NpuQwen3MoeDecoderLayerImpl::init_attn_mask() { + if (prefill_param_.isFIA) { + const auto fia_mask_options = + torch::TensorOptions().dtype(torch::kBool).device(device_); + fia_attn_mask_ = torch::triu(torch::ones({kFiaMaskSeqLen, kFiaMaskSeqLen}, + fia_mask_options), + /*diagonal=*/1) + .contiguous(); + } + + return atb::NO_ERROR; +} + int64_t NpuQwen3MoeDecoderLayerImpl::init_node( atb_speed::Model::Node& node, atb_speed::qwen::MoeDecoderLayerParam& param) { @@ -357,6 +378,15 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack( KVCache& kv_cache, const ModelInputParams& input_params, bool is_prefill) { + if (is_prefill) { + *prefill_param_.bs = std::max(1, input_params.num_sequences); + } + + auto* effective_attn_mask = &attn_mask; + if (is_prefill && prefill_param_.isFIA) { + effective_attn_mask = &fia_attn_mask_; + } + internal_tensor_ = atb_speed::Utils::AtTensor2Tensor(x); int32_t input_idx = 0; auto& dp_ep_padding = input_params.dp_ep_padding_data; @@ -376,7 +406,7 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack( node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6) = atb_speed::Utils::AtTensor2Tensor(sin_pos); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7) = - atb_speed::Utils::AtTensor2Tensor(attn_mask); + atb_speed::Utils::AtTensor2Tensor(*effective_attn_mask); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 8) = atb_speed::Utils::AtTensor2Tensor(kv_cache.get_k_cache()); node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 9) = @@ -420,8 +450,8 @@ void NpuQwen3MoeDecoderLayerImpl::build_node_variant_pack( } input_idx = WEIGHT_COUNT_PER_LAYER + 16; - if (is_prefill && - (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) { + if (is_prefill && (prefill_param_.isFIA || FLAGS_enable_chunked_prefill || + FLAGS_enable_prefix_cache)) { node.variantPack.inTensors.at(input_idx++) = atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens); node.variantPack.inTensors.at(input_idx - 1).hostData = diff --git a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h index ed96af006..3dfc754b6 100644 --- a/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h +++ b/xllm/core/layers/npu/npu_qwen3_moe_decoder_layer_impl.h @@ -94,6 +94,8 @@ class NpuQwen3MoeDecoderLayerImpl : public BaseLayer { void initialize_quantization_parameters( atb_speed::qwen::MoeDecoderLayerParam& param); + int64_t init_attn_mask(); + int64_t init_node(atb_speed::Model::Node& node, atb_speed::qwen::MoeDecoderLayerParam& param); @@ -140,6 +142,7 @@ class NpuQwen3MoeDecoderLayerImpl : public BaseLayer { torch::Tensor slot_tensor_placeholder_; torch::Tensor int_tensor_placeholder_; torch::Tensor decode_attn_mask_; + torch::Tensor fia_attn_mask_; torch::Tensor expert_group_; torch::Tensor quant_add_norm_scaling_; torch::Tensor quant_add_norm_offset_;