@@ -587,16 +587,14 @@ void FusedBlhaGlobalVar::update_out_encoder(const phi::CustomContext &dev_ctx,
587587 for (auto i = 0 ; i < batch_size; ++i) {
588588 if (seqlens_encoder[i] > 0 ) {
589589 in_offset += seqlens_encoder[i] * emb_dim;
590+ out_offset = i * emb_dim;
590591 ACL_CHECK (aclrtMemcpyAsync (
591592 out_data + out_offset * sizeof (phi::float16),
592593 emb_dim * sizeof (phi::float16),
593594 in_data + (in_offset - emb_dim) * sizeof (phi::float16),
594595 emb_dim * sizeof (phi::float16),
595596 ACL_MEMCPY_DEVICE_TO_DEVICE,
596597 reinterpret_cast <aclrtStream>(dev_ctx.stream ())));
597- out_offset += emb_dim;
598- } else if (seqlens_decoder[i] > 0 ) {
599- out_offset += emb_dim;
600598 }
601599 }
602600}
@@ -622,9 +620,8 @@ void FusedBlhaGlobalVar::update_out_decoder(const phi::CustomContext &dev_ctx,
622620
623621 int64_t in_offset = 0 , out_offset = 0 ;
624622 for (auto i = 0 ; i < batch_size; ++i) {
625- if (seqlens_encoder[i] > 0 ) {
626- out_offset += emb_dim;
627- } else if (seqlens_decoder[i] > 0 ) {
623+ if (seqlens_decoder[i] > 0 ) {
624+ out_offset = i * emb_dim;
628625 ACL_CHECK (
629626 aclrtMemcpyAsync (out_data + out_offset * sizeof (phi::float16),
630627 emb_dim * sizeof (phi::float16),
@@ -633,7 +630,6 @@ void FusedBlhaGlobalVar::update_out_decoder(const phi::CustomContext &dev_ctx,
633630 ACL_MEMCPY_DEVICE_TO_DEVICE,
634631 reinterpret_cast <aclrtStream>(dev_ctx.stream ())));
635632 in_offset += emb_dim;
636- out_offset += emb_dim;
637633 }
638634 }
639635}
0 commit comments