Skip to content

fix: scope get_full_cu_seqlens cache key by device and inference mode#2728

Open
DmCarpe93 wants to merge 5 commits intoNVIDIA:mainfrom
DmCarpe93:fix/get_full_cu_seqlens_cache_key_error
Open

fix: scope get_full_cu_seqlens cache key by device and inference mode#2728
DmCarpe93 wants to merge 5 commits intoNVIDIA:mainfrom
DmCarpe93:fix/get_full_cu_seqlens_cache_key_error

Conversation

@DmCarpe93
Copy link

@DmCarpe93 DmCarpe93 commented Mar 3, 2026

Description

Fixed an issue where the cu_seqlen tensor was incorrectly retrieved from the cache.

  • Currently, only (batch_size, max_seqlen) were used as the cache key when retrieving cu_seqlens.
  • This coud result in error especially for Knowledge Distillation training, because teacher and student model can be run on same node.
    • When teacher model run first, cu_seqlens tensor would be created and cached.
    • After that, when student model trains on the same node, the cached cu_seqlens tensor would be used if same (batch_size, max_seqlen) is used.
    • Since cached cu_seqlens tensor from teacher model could have different inference mode and device, it could result in error.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • The cache key for retrieving cu_seqlens was updated from (batch_size, max_seqlen) to include both the device and inference mode.
  • Added testcases for cu_seqlens cache.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Copy link
Contributor

greptile-apps bot commented Mar 3, 2026

Greptile Summary

This PR fixes a bug in get_full_cu_seqlens where the global _cu_seqlens_cache used only (batch_size, max_seqlen) as the lookup key, causing tensors created for one device or execution context (e.g., a teacher model running under torch.inference_mode()) to be incorrectly returned for a different device or execution context (e.g., a student model in training mode). The fix extends the cache key to (batch_size, max_seqlen, device, is_inference_mode).

Key changes:

  • utils.py: Cache key extended from (batch_size, max_seqlen)(batch_size, max_seqlen, device, is_inference) using torch.is_inference_mode_enabled(). torch.device objects (always returned with an explicit index from tensor.device) are hashable and compare correctly, so the dict semantics are sound.
  • test_cu_seqlens_cache.py: New test file with a proper autouse fixture that clears the global cache before/after each test. Two tests cover the two fixed dimensions: multi-device isolation and inference-vs-training-mode isolation. The training-mode test also runs backward(), directly reproducing the scenario where an inference-mode tensor (torch.int32, but still marked as an "inference tensor") would have caused an autograd error if returned from the cache.
  • Minor note: the multi-device test uses torch.no_grad() (not torch.inference_mode()), so both runs have is_inference=False; the test validates device-key isolation but does not exercise inference-mode tensors on separate devices simultaneously.

Confidence Score: 5/5

  • This PR is safe to merge — it is a narrowly scoped, correct bug fix with no regressions.
  • The change is minimal (5 lines in utils.py), the fix is logically sound (torch.device is hashable with correct equality semantics; torch.is_inference_mode_enabled() is the correct API), and the existing ONNX export bypass is preserved. The new tests reproduce the described failure scenarios. No public API is changed.
  • No files require special attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/utils.py Extends the _cu_seqlens_cache key from (batch_size, max_seqlen) to (batch_size, max_seqlen, device, is_inference_mode) using torch.is_inference_mode_enabled(). The fix is minimal, correct, and does not introduce regressions. torch.device is hashable and returned consistently from tensor .device attributes (always with explicit index), so dict key semantics are sound.
tests/pytorch/attention/test_cu_seqlens_cache.py New test file with an autouse fixture that properly clears the global cache around each test. Two test cases cover device isolation and inference-vs-training-mode isolation. The multi-device test uses torch.no_grad() rather than torch.inference_mode(), which is functionally different — both runs get is_inference=False — but still validates device-scoped key separation. The inference-vs-training test is the stronger regression test for the described KD bug.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["get_full_cu_seqlens called"] --> B{ONNX export mode?}
    B -- Yes --> C["Skip cache, create tensor directly"]
    B -- No --> D["Read torch.is_inference_mode_enabled()"]
    D --> E["Form tuple: batch_size + max_seqlen + device + inference_flag"]
    E --> F{Tuple in cache dict?}
    F -- No --> G["Allocate cu_seqlens via torch.arange"]
    G --> H["Write to global cache"]
    H --> I["Return tensor"]
    F -- Yes --> I["Return cached tensor"]
Loading

Comments Outside Diff (1)

  1. tests/pytorch/attention/test_cu_seqlens_cache.py, line 53-58 (link)

    torch.no_grad() does not set inference mode — device isolation test uses wrong context

    Inside torch.no_grad(), torch.is_inference_mode_enabled() returns False, so both dev0 and dev1 runs will produce cache keys with is_inference=False: (2, 8, dev0, False) and (2, 8, dev1, False). The test correctly verifies device isolation, but the context manager used here (torch.no_grad()) does not reflect the scenario described in the PR (Knowledge Distillation with teacher running in torch.inference_mode()).

    This means the multi-device test does not exercise the case where one device runs in inference mode and another in training mode simultaneously. If such a scenario were to occur (e.g., teacher on cuda:1 in torch.inference_mode(), student on cuda:0 in training mode), the keys would be (2, 8, cuda:0, False) and (2, 8, cuda:1, True) — different for two reasons. The test as written only validates the device dimension of isolation. This is fine for its stated purpose, but a comment documenting the intentional use of torch.no_grad() (rather than torch.inference_mode()) would improve clarity.

    Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!

Last reviewed commit: 319fd26

Copy link
Contributor

@greptile-apps greptile-apps bot left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

2 files reviewed, no comments

Edit Code Review Agent Settings | Greptile

@ptrendx ptrendx requested a review from cyanguwa March 3, 2026 18:54
@DmCarpe93
Copy link
Author

@cyanguwa When you have a moment, could you please take a look at this PR? Thanks:)

@DmCarpe93
Copy link
Author

@cyanguwa This PR is pretty straightforward. Would you mind taking a quick look? Thank you:)

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

1 participant