Skip to content

Commit c8ed610

Browse files
authored
[NPU] fix llama_infer, cherry-pick #1324 (#1326)
1 parent 4a07d97 commit c8ed610

File tree

2 files changed

+4
-8
lines changed

2 files changed

+4
-8
lines changed

backends/npu/custom_op/llama_infer/atb_ops/fused_blha_layer_op_utils.cc

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -587,16 +587,14 @@ void FusedBlhaGlobalVar::update_out_encoder(const phi::CustomContext &dev_ctx,
587587
for (auto i = 0; i < batch_size; ++i) {
588588
if (seqlens_encoder[i] > 0) {
589589
in_offset += seqlens_encoder[i] * emb_dim;
590+
out_offset = i * emb_dim;
590591
ACL_CHECK(aclrtMemcpyAsync(
591592
out_data + out_offset * sizeof(phi::float16),
592593
emb_dim * sizeof(phi::float16),
593594
in_data + (in_offset - emb_dim) * sizeof(phi::float16),
594595
emb_dim * sizeof(phi::float16),
595596
ACL_MEMCPY_DEVICE_TO_DEVICE,
596597
reinterpret_cast<aclrtStream>(dev_ctx.stream())));
597-
out_offset += emb_dim;
598-
} else if (seqlens_decoder[i] > 0) {
599-
out_offset += emb_dim;
600598
}
601599
}
602600
}
@@ -622,9 +620,8 @@ void FusedBlhaGlobalVar::update_out_decoder(const phi::CustomContext &dev_ctx,
622620

623621
int64_t in_offset = 0, out_offset = 0;
624622
for (auto i = 0; i < batch_size; ++i) {
625-
if (seqlens_encoder[i] > 0) {
626-
out_offset += emb_dim;
627-
} else if (seqlens_decoder[i] > 0) {
623+
if (seqlens_decoder[i] > 0) {
624+
out_offset = i * emb_dim;
628625
ACL_CHECK(
629626
aclrtMemcpyAsync(out_data + out_offset * sizeof(phi::float16),
630627
emb_dim * sizeof(phi::float16),
@@ -633,7 +630,6 @@ void FusedBlhaGlobalVar::update_out_decoder(const phi::CustomContext &dev_ctx,
633630
ACL_MEMCPY_DEVICE_TO_DEVICE,
634631
reinterpret_cast<aclrtStream>(dev_ctx.stream())));
635632
in_offset += emb_dim;
636-
out_offset += emb_dim;
637633
}
638634
}
639635
}

backends/npu/custom_op/llama_infer/atb_ops/remove_padding_op.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ std::vector<paddle::Tensor> RemovePaddingOp(const paddle::Tensor& x,
5858
paddle::experimental::DeviceContextPool::Instance().Get(place));
5959

6060
auto x_shape = x.shape();
61-
const int bsz = x_shape[0];
61+
const int bsz = seqlen.numel();
6262
const int padding_len = x_shape[1];
6363

6464
auto seqlen_host = seqlen.copy_to(paddle::CPUPlace(), true);

0 commit comments

Comments (0)