
fix: scope get_full_cu_seqlens cache key by device and inference mode#2728

Merged
cyanguwa merged 11 commits into NVIDIA:main from DmCarpe93:fix/get_full_cu_seqlens_cache_key_error
Apr 23, 2026

Conversation

@DmCarpe93
Contributor

@DmCarpe93 DmCarpe93 commented Mar 3, 2026

Description

Fixed an issue where the cu_seqlens tensor could be incorrectly retrieved from the cache.

  • Currently, only (batch_size, max_seqlen) was used as the cache key when retrieving cu_seqlens.
  • This could result in errors, especially during Knowledge Distillation training, because the teacher and student models may run on the same node (see the sketch after this list).
    • When the teacher model runs first, its cu_seqlens tensor is created and cached.
    • When the student model then trains on the same node, the cached cu_seqlens tensor is reused whenever the same (batch_size, max_seqlen) is requested.
    • Since the tensor cached by the teacher model may have a different device or inference mode, reusing it can raise an error.
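
For illustration, here is a minimal sketch of the collision, assuming a module-level dict cache and a get_full_cu_seqlens(batch_size, max_seqlen, device) signature (the exact names and structure in TransformerEngine may differ):

```python
import torch

_cache = {}  # hypothetical module-level cache keyed only by (batch_size, max_seqlen)


def get_full_cu_seqlens(batch_size, max_seqlen, device):
    """Return cumulative sequence lengths [0, max_seqlen, 2*max_seqlen, ...]."""
    key = (batch_size, max_seqlen)  # device and autograd mode are not part of the key
    if key not in _cache:
        _cache[key] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cache[key]


# Teacher forward pass under inference mode populates the cache first.
with torch.inference_mode():
    teacher_cu = get_full_cu_seqlens(4, 128, torch.device("cpu"))

# The student training step with the same (batch_size, max_seqlen) gets the
# cached tensor back: it is an inference-mode tensor, and if the teacher had
# run on a different GPU it would also live on the wrong device.
student_cu = get_full_cu_seqlens(4, 128, torch.device("cpu"))
print(student_cu.is_inference())  # True, even though inference mode is no longer active
```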

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • The cache key for retrieving cu_seqlens was extended from (batch_size, max_seqlen) to also include the device and inference mode (see the sketch below).
  • Added test cases for the cu_seqlens cache.
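
A hedged sketch of the extended key, under the same assumptions as the example above (the real helper in utils.py may structure its cache differently):

```python
import torch

_cache = {}


def get_full_cu_seqlens(batch_size, max_seqlen, device):
    # Scope the key by device and inference mode so a tensor cached by one
    # caller (e.g. the teacher under torch.inference_mode() on one GPU) is
    # never handed to an incompatible caller (the student training on
    # another GPU) just because batch_size and max_seqlen happen to match.
    key = (batch_size, max_seqlen, device, torch.is_inference_mode_enabled())
    if key not in _cache:
        _cache[key] = torch.arange(
            0,
            (batch_size + 1) * max_seqlen,
            step=max_seqlen,
            dtype=torch.int32,
            device=device,
        )
    return _cache[key]
```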

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps Bot commented Mar 3, 2026

Greptile Summary

This PR fixes a cache-key collision in get_full_cu_seqlens where only (batch_size, max_seqlen) was used as the key, causing tensors created on one device or under inference mode to be silently reused by callers on a different device or in a different autograd mode (e.g. teacher vs. student in Knowledge Distillation). The fix adds device and torch.is_inference_mode_enabled() to the key, and two focused pytest cases validate both isolation scenarios.

Confidence Score: 5/5

Safe to merge — the fix is minimal, targeted, and well-tested with no regressions introduced.

The change is a one-liner key extension with clear semantics. torch.device is hashable and comparable by value, and torch.is_inference_mode_enabled() is a stable API. The two new tests cover exactly the described failure scenarios. No P0 or P1 issues were found.

No files require special attention.

Important Files Changed

  • transformer_engine/pytorch/attention/dot_product_attention/utils.py: Extends the get_full_cu_seqlens cache key from (batch_size, max_seqlen) to (batch_size, max_seqlen, device, is_inference), correctly isolating cached tensors across devices and inference modes.
  • tests/pytorch/attention/test_cu_seqlens_cache.py: New test file covering both multi-device isolation and inference-vs-training isolation for the cu_seqlens cache; uses an autouse fixture to clear the cache before/after each test.
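
A rough sketch of what the inference-vs-training isolation test could look like, assuming a get_full_cu_seqlens(batch_size, max_seqlen, device) signature and the import path from the file list above (the actual test file and its cache-clearing fixture may differ):

```python
import torch

from transformer_engine.pytorch.attention.dot_product_attention.utils import (
    get_full_cu_seqlens,
)


def test_inference_mode_does_not_leak_into_training():
    # Populate the cache while inference mode is enabled ...
    with torch.inference_mode():
        cu_inference = get_full_cu_seqlens(2, 64, torch.device("cpu"))
    assert cu_inference.is_inference()

    # ... then request the same (batch_size, max_seqlen) in normal training
    # mode; with the scoped key this must be a fresh, non-inference tensor.
    cu_training = get_full_cu_seqlens(2, 64, torch.device("cpu"))
    assert not cu_training.is_inference()
```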

Reviews (10). Last reviewed commit: "Merge branch 'main' into fix/get_full_cu..."

Contributor

@greptile-apps greptile-apps Bot left a comment


2 files reviewed, no comments


@ptrendx ptrendx requested a review from cyanguwa March 3, 2026 18:54
@DmCarpe93
Contributor Author

@cyanguwa When you have a moment, could you please take a look at this PR? Thanks:)

@DmCarpe93
Contributor Author

@cyanguwa This PR is pretty straightforward. Would you mind taking a quick look? Thank you:)

@DmCarpe93
Contributor Author

DmCarpe93 commented Apr 1, 2026

@cyanguwa Hi:) could you look into this PR? thank you.

@DmCarpe93
Contributor Author

@ptrendx The review hasn’t been progressing—would it be possible to change the reviewer?
The same issue keeps occurring, and while we can work around it by modifying the training script used by our team, it’s inconvenient to apply this workaround every time.
It would be great if the fix could be properly reviewed and merged.

@ptrendx ptrendx added the community-contribution label Apr 21, 2026
@cyanguwa
Collaborator

/te-ci torch L1

Collaborator

@cyanguwa cyanguwa left a comment


Thanks for the PR and sorry about the delay in reviewing! I'll run the CI and merge it. Will make another small PR to properly integrate the new test into our qa/ scripts later.

@cyanguwa cyanguwa merged commit ab60f4c into NVIDIA:main Apr 23, 2026
46 of 53 checks passed
YigongQin pushed a commit to YigongQin/TransformerEngine that referenced this pull request Apr 23, 2026
…NVIDIA#2728)

* fix: scope get_full_cu_seqlens cache key by device and inference mode

Signed-off-by: Dongmin Ra <dongmin.ra@navercorp.com>

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

---------

Signed-off-by: Dongmin Ra <dongmin.ra@navercorp.com>
Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Charlene Yang <8636796+cyanguwa@users.noreply.github.com>

Labels

2.16.0, community-contribution
