-
Notifications
You must be signed in to change notification settings - Fork 166
feat: support FIA for qwen model on npu device. #1147
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
cd8edf2
6291248
e07f3e7
cca9d14
9c84299
b9e09a6
eee739e
25d8a9e
5ad1ae5
7ee2dc8
7444c3b
4a0d40b
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -19,6 +19,7 @@ limitations under the License. | |
| #include <glog/logging.h> | ||
| #include <mstx/ms_tools_ext.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <map> | ||
|
|
||
| #include "common/global_flags.h" | ||
|
|
@@ -30,6 +31,10 @@ limitations under the License. | |
| namespace xllm { | ||
| namespace layer { | ||
|
|
||
| namespace { | ||
| constexpr int64_t kFiaMaskSeqLen = 2048; | ||
| } | ||
|
|
||
| enum DecoderLayerTensorId : int { | ||
| IN_NORM_WEIGHT = 0, // weight | ||
| IN_NORM_BIAS = 1, // bias | ||
|
|
@@ -117,13 +122,13 @@ void NpuQwen2DecoderLayerImpl::param_from_args( | |
| static_cast<int>(TransposeType::INVALID), | ||
| static_cast<int>(TransposeType::NOT_TRANSPOSE)}; | ||
| param.kvQuant = false; | ||
| param.blockSize = static_cast<int64_t>(FLAGS_block_size); | ||
| param.quantGroupSize = 0; | ||
| param.rmsNormEps = args.rms_norm_eps(); | ||
| param.worldSize = parallel_args.world_size(); | ||
| param.numAttentionHeadsPerRank = args.n_heads() / param.worldSize; | ||
| param.hiddenSizePerAttentionHead = args.hidden_size() / args.n_heads(); | ||
| param.enableIntraLayerAddNorm = false; | ||
| param.enableInterLayerAddNorm = false; | ||
| // param.numKeyValueHeadsPerRank = args.n_kv_heads(); | ||
| std::optional<long int> optionalValue = args.n_kv_heads(); | ||
| param.numKeyValueHeadsPerRank = | ||
|
|
@@ -134,6 +139,7 @@ void NpuQwen2DecoderLayerImpl::param_from_args( | |
| param.rank = parallel_args.rank(); | ||
| param.backend = "lccl"; | ||
| param.enableLogN = false; | ||
| param.isFIA = isPrefill && FLAGS_enable_fia; | ||
| } | ||
|
|
||
| NpuQwen2DecoderLayerImpl::NpuQwen2DecoderLayerImpl(const ModelContext& context) | ||
|
|
@@ -215,10 +221,72 @@ int64_t NpuQwen2DecoderLayerImpl::init_attn_mask() { | |
| torch::Dtype dtype = | ||
| prefill_param_.isBF16 ? torch::kBFloat16 : torch::kFloat16; | ||
| decode_attn_mask_ = torch::zeros({1}).to(device_).to(dtype); | ||
| if (FLAGS_enable_fia) { | ||
| const auto fia_mask_options = | ||
| torch::TensorOptions().dtype(torch::kBool).device(device_); | ||
| fia_attn_mask_ = torch::triu(torch::ones({kFiaMaskSeqLen, kFiaMaskSeqLen}, | ||
| fia_mask_options), | ||
| /*diagonal=*/1) | ||
| .contiguous(); | ||
| } | ||
|
|
||
| return atb::NO_ERROR; | ||
| } | ||
|
|
||
| void NpuQwen2DecoderLayerImpl::build_fia_index_tensors( | ||
| const ModelInputParams& input_params, | ||
| int64_t total_tokens) { | ||
| CHECK_GT(input_params.num_sequences, 0) | ||
| << "FIA prefill requires num_sequences > 0"; | ||
| CHECK_EQ(input_params.q_seq_lens_vec.size(), | ||
| static_cast<size_t>(input_params.num_sequences)) | ||
| << "q_seq_lens_vec size mismatch for FIA prefill"; | ||
| CHECK_GT(total_tokens, 0) << "FIA prefill requires non-empty hidden states"; | ||
|
|
||
| const int32_t target_seq_len = *std::max_element( | ||
| input_params.q_seq_lens_vec.begin(), input_params.q_seq_lens_vec.end()); | ||
| CHECK_GT(target_seq_len, 0) << "FIA prefill requires target_seq_len > 0"; | ||
|
|
||
| std::vector<int32_t> padding_idx(input_params.num_sequences * target_seq_len); | ||
| std::vector<int32_t> unpadding_idx; | ||
| unpadding_idx.reserve(total_tokens); | ||
|
|
||
| int64_t source_token_offset = 0; | ||
| for (int32_t seq_idx = 0; seq_idx < input_params.num_sequences; ++seq_idx) { | ||
| const int32_t seq_len = input_params.q_seq_lens_vec[seq_idx]; | ||
| CHECK_GT(seq_len, 0) << "FIA sequence length must be positive"; | ||
| CHECK_LE(seq_len, target_seq_len) | ||
| << "FIA sequence length must be <= target_seq_len"; | ||
| CHECK_LE(source_token_offset + seq_len, total_tokens) | ||
| << "FIA sequence lengths exceed current hidden-state tokens"; | ||
|
|
||
| const int32_t padded_offset = seq_idx * target_seq_len; | ||
| const int32_t fallback_idx = static_cast<int32_t>(source_token_offset); | ||
| for (int32_t pos = 0; pos < target_seq_len; ++pos) { | ||
| const int32_t padded_idx = padded_offset + pos; | ||
| if (pos < seq_len) { | ||
| padding_idx[padded_idx] = | ||
| static_cast<int32_t>(source_token_offset + pos); | ||
| unpadding_idx.push_back(padded_idx); | ||
| } else { | ||
| padding_idx[padded_idx] = fallback_idx; | ||
| } | ||
| } | ||
| source_token_offset += seq_len; | ||
| } | ||
|
|
||
| CHECK_EQ(source_token_offset, total_tokens) | ||
| << "FIA padding indices do not cover all hidden-state tokens"; | ||
| CHECK_EQ(unpadding_idx.size(), static_cast<size_t>(total_tokens)) | ||
| << "FIA unpadding indices size mismatch"; | ||
|
|
||
| const auto index_options = torch::TensorOptions().dtype(torch::kInt32); | ||
| fia_padding_idx_ = | ||
| torch::tensor(padding_idx, index_options).to(device_).contiguous(); | ||
| fia_unpadding_idx_ = | ||
| torch::tensor(unpadding_idx, index_options).to(device_).contiguous(); | ||
| } | ||
|
|
||
| int64_t NpuQwen2DecoderLayerImpl::init_node( | ||
| atb_speed::Model::Node& node, | ||
| atb_speed::qwen::DecoderLayerParam& param) { | ||
|
|
@@ -300,34 +368,51 @@ void NpuQwen2DecoderLayerImpl::build_node_variant_pack( | |
| KVCache& kv_cache, | ||
| ModelInputParams& input_params, | ||
| bool is_prefill) { | ||
| if (is_prefill) { | ||
| *prefill_param_.bs = std::max(1, input_params.num_sequences); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| } | ||
|
|
||
| auto* effective_attn_mask = &attn_mask; | ||
| if (is_prefill && FLAGS_enable_fia) { | ||
| effective_attn_mask = &fia_attn_mask_; | ||
| // build_fia_index_tensors(input_params, x.size(0)); // bsnd | ||
| } | ||
|
|
||
| internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_; | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) = | ||
| size_t input_offset = WEIGHT_COUNT_PER_LAYER; | ||
| node.variantPack.inTensors.at(input_offset++) = internal_tensors_; | ||
| node.variantPack.inTensors.at(input_offset++) = | ||
| atb_speed::Utils::AtTensor2Tensor(cos_pos); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) = | ||
| node.variantPack.inTensors.at(input_offset++) = | ||
| atb_speed::Utils::AtTensor2Tensor(sin_pos); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) = | ||
| atb_speed::Utils::AtTensor2Tensor(attn_mask); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 4) = | ||
| node.variantPack.inTensors.at(input_offset++) = | ||
| atb_speed::Utils::AtTensor2Tensor(*effective_attn_mask); | ||
| node.variantPack.inTensors.at(input_offset++) = | ||
| atb_speed::Utils::AtTensor2Tensor(kv_cache.get_k_cache()); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 5) = | ||
| node.variantPack.inTensors.at(input_offset++) = | ||
| atb_speed::Utils::AtTensor2Tensor(kv_cache.get_v_cache()); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6) = | ||
| node.variantPack.inTensors.at(input_offset) = | ||
| atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6).hostData = | ||
| node.variantPack.inTensors.at(input_offset++).hostData = | ||
| input_params.kv_seq_lens_vec.data(); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7) = placeholder_; | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 7).hostData = | ||
| node.variantPack.inTensors.at(input_offset) = placeholder_; | ||
| node.variantPack.inTensors.at(input_offset++).hostData = | ||
| placeholder_vec_.data(); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 8) = placeholder_; | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 9) = | ||
| node.variantPack.inTensors.at(input_offset++) = placeholder_; | ||
| node.variantPack.inTensors.at(input_offset++) = | ||
| atb_speed::Utils::AtTensor2Tensor(input_params.block_tables); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 10) = | ||
| node.variantPack.inTensors.at(input_offset++) = | ||
| atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots); | ||
| if (is_prefill && FLAGS_enable_chunked_prefill) { | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11) = | ||
| if (is_prefill && FLAGS_enable_fia) { | ||
| node.variantPack.inTensors.at(input_offset++) = // bsnd padding_idx | ||
| placeholder_; | ||
| node.variantPack.inTensors.at(input_offset++) = // bsnd unpadding_idx | ||
| placeholder_; | ||
| } | ||
|
Comment on lines
+406
to
+411
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. The FIA index tensors are not being correctly passed to the variant pack; if (is_prefill && FLAGS_enable_fia) {
build_fia_index_tensors(input_params, x.size(0));
node.variantPack.inTensors.at(input_offset++) =
atb_speed::Utils::AtTensor2Tensor(fia_padding_idx_);
node.variantPack.inTensors.at(input_offset++) =
atb_speed::Utils::AtTensor2Tensor(fia_unpadding_idx_);
} |
||
| if (is_prefill && (FLAGS_enable_fia || FLAGS_enable_chunked_prefill)) { | ||
| node.variantPack.inTensors.at(input_offset) = | ||
| atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 11).hostData = | ||
| node.variantPack.inTensors.at(input_offset).hostData = | ||
| input_params.q_seq_lens_vec.data(); | ||
| } | ||
|
|
||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
|
|
@@ -18,6 +18,7 @@ limitations under the License. | |
| #include <glog/logging.h> | ||
| #include <mstx/ms_tools_ext.h> | ||
|
|
||
| #include <algorithm> | ||
| #include <map> | ||
|
|
||
| #include "common/global_flags.h" | ||
|
|
@@ -30,6 +31,10 @@ limitations under the License. | |
| namespace xllm { | ||
| namespace layer { | ||
|
|
||
| namespace { | ||
| constexpr int64_t kFiaMaskSeqLen = 2048; | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| } | ||
|
|
||
| const uint64_t WEIGHT_COUNT_PER_LAYER = 56; | ||
|
|
||
| void NpuQwen3DecoderLayerImpl::param_from_args( | ||
|
|
@@ -48,6 +53,8 @@ void NpuQwen3DecoderLayerImpl::param_from_args( | |
| param.rmsnormQKNorm = true; | ||
| param.isPrefill = isPrefill; | ||
| param.isBF16 = args.dtype() == "bfloat16"; | ||
| param.isFIA = isPrefill && FLAGS_enable_fia; | ||
| param.blockSize = static_cast<int64_t>(FLAGS_block_size); | ||
| param.enableSplitFuse = FLAGS_enable_chunked_prefill && isPrefill; | ||
| param.loraEnableGMM = false; | ||
| param.enableXattention = is_rec_multi_round_mode(); | ||
|
|
@@ -197,6 +204,14 @@ int64_t NpuQwen3DecoderLayerImpl::init_attn_mask() { | |
| torch::Dtype dtype = | ||
| prefill_param_.isBF16 ? torch::kBFloat16 : torch::kFloat16; | ||
| decode_attn_mask_ = torch::zeros({1}).to(device_).to(dtype); | ||
| if (prefill_param_.isFIA) { | ||
| const auto fia_mask_options = | ||
| torch::TensorOptions().dtype(torch::kBool).device(device_); | ||
| fia_attn_mask_ = torch::triu(torch::ones({kFiaMaskSeqLen, kFiaMaskSeqLen}, | ||
| fia_mask_options), | ||
| /*diagonal=*/1) | ||
| .contiguous(); | ||
| } | ||
|
|
||
| return atb::NO_ERROR; | ||
| } | ||
|
|
@@ -279,14 +294,23 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack( | |
| ModelInputParams& input_params, | ||
| bool is_prefill, | ||
| int node_id) { | ||
| if (is_prefill) { | ||
| *prefill_param_.bs = std::max(1, input_params.num_sequences); | ||
|
Contributor
There was a problem hiding this comment. Choose a reason for hiding this comment. The reason will be displayed to describe this comment to others. Learn more. |
||
| } | ||
|
|
||
| auto* effective_attn_mask = &attn_mask; | ||
| if (is_prefill && prefill_param_.isFIA) { | ||
| effective_attn_mask = &fia_attn_mask_; | ||
| } | ||
|
|
||
| internal_tensors_ = atb_speed::Utils::AtTensor2Tensor(x); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER) = internal_tensors_; | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 1) = | ||
| atb_speed::Utils::AtTensor2Tensor(cos_pos); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 2) = | ||
| atb_speed::Utils::AtTensor2Tensor(sin_pos); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 3) = | ||
| atb_speed::Utils::AtTensor2Tensor(attn_mask); | ||
| atb_speed::Utils::AtTensor2Tensor(*effective_attn_mask); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6) = | ||
| atb_speed::Utils::AtTensor2Tensor(input_params.kv_seq_lens); | ||
| node.variantPack.inTensors.at(WEIGHT_COUNT_PER_LAYER + 6).hostData = | ||
|
|
@@ -334,8 +358,8 @@ void NpuQwen3DecoderLayerImpl::build_node_variant_pack( | |
| atb_speed::Utils::AtTensor2Tensor(input_params.new_cache_slots); | ||
| } | ||
|
|
||
| if (is_prefill && | ||
| (FLAGS_enable_chunked_prefill || FLAGS_enable_prefix_cache)) { | ||
| if (is_prefill && (prefill_param_.isFIA || FLAGS_enable_chunked_prefill || | ||
| FLAGS_enable_prefix_cache)) { | ||
| node.variantPack.inTensors.at(input_idx++) = | ||
| atb_speed::Utils::AtTensor2Tensor(input_params.q_seq_lens); | ||
| node.variantPack.inTensors.at(input_idx - 1).hostData = | ||
|
|
||
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hardcoding
`kFiaMaskSeqLen` to 2048 limits the maximum sequence length supported when FIA is enabled. If the input sequence length exceeds 2048, the `fia_attn_mask_` (initialized at line 229) will have incorrect dimensions for the attention operation, potentially leading to incorrect results or crashes. This should be made dynamic or set to a sufficiently large value supported by the model's context window.