
add blackwell support filter for 9.7<=cudnn<9.18.1#2775

Open
sudhakarsingh27 wants to merge 8 commits into NVIDIA:main from sudhakarsingh27:fix_fp8_determinism_check

Conversation

@sudhakarsingh27
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
pre-commit-ci bot and others added 2 commits March 17, 2026 23:09
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 17, 2026

Greptile Summary

This PR adds a targeted determinism filter to get_attention_backend() that disables the F16_arbitrary_seqlen FusedAttention backend during training on Blackwell GPUs (compute capability ≥ 10.0) when the installed cuDNN version falls in the range [9.7.0, 9.18.1), working around a cuDNN bug that prevents deterministic execution on that hardware/version combination.

Key observations:

  • The new if block is correctly nested inside the existing if use_fused_attention and deterministic: guard (line 1060), so no extra deterministic re-check is needed.
  • The filter is scoped to is_training, consistent with the other determinism guards in the same block (lines 1070–1092), indicating the issue is backward-pass specific.
  • Both version bounds use 3-element tuples (9, 7, 0) and (9, 18, 1), consistent with the rest of the file.
  • The log message contains a typo: "FusedAtttention" (three ts) should be "FusedAttention".
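The version bounds work because Python compares tuples lexicographically, element by element. A minimal sketch of the range check behavior (the helper name is hypothetical; only the comparison semantics mirror the PR):

```python
# Sketch of the tuple-based version range check used by the filter.
# The function name is hypothetical; the bounds match the PR's
# [9.7.0, 9.18.1) range.
def in_buggy_range(cudnn_version: tuple) -> bool:
    """True when 9.7.0 <= cuDNN < 9.18.1 (lexicographic tuple order)."""
    return (9, 7, 0) <= cudnn_version < (9, 18, 1)

# Patch releases sort correctly under lexicographic comparison:
print(in_buggy_range((9, 7, 0)))    # True  - first affected release
print(in_buggy_range((9, 18, 0)))   # True  - still inside the range
print(in_buggy_range((9, 18, 1)))   # False - first version past the bound
print(in_buggy_range((9, 6, 5)))    # False - predates the range
```

Note that a 2-tuple lower bound such as `(9, 7)` (as in the diff itself) behaves the same here, since Python orders a shorter tuple before any longer tuple it prefixes: `(9, 7) <= (9, 7, 0)` is true.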

Confidence Score: 4/5

  • Safe to merge — the logic is sound and consistent with existing patterns; the only issue is a minor log-message typo.
  • The change is small and well-scoped: it adds a single guard block inside an existing determinism filter, uses correct 3-tuple version comparisons, properly includes is_training, and matches the surrounding code conventions. The only defect is a cosmetic typo in a debug log string that has no runtime impact. No tests or documentation changes accompany the fix, but determinism filter changes are difficult to unit-test in isolation.
  • No files require special attention beyond the typo fix in transformer_engine/pytorch/attention/dot_product_attention/utils.py.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/utils.py Adds a determinism filter inside the existing if use_fused_attention and deterministic: guard to disable FusedAttention (F16_arbitrary_seqlen backend) during training on Blackwell GPUs (compute capability ≥ 10.0) when cuDNN is in the range [9.7.0, 9.18.1). Logic is correct and consistent with surrounding patterns; only issue is a typo ("FusedAtttention") in the debug log message.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["get_attention_backend()"] --> B{"use_fused_attention\nAND deterministic?"}
    B -- No --> Z["Continue to flash-attn selection"]
    B -- Yes --> C{"softmax_type\n≠ 'vanilla'?"}
    C -- Yes --> D["Disable FusedAttention\n(non-deterministic softmax)"]
    C -- No --> E{"backend == FP8\nAND is_training\nAND (sm < 9.0 OR cuDNN < 9.19.0)?"}
    E -- Yes --> F["Disable FusedAttention\n(FP8 determinism)"]
    E -- No --> G{"backend == F16_arbitrary_seqlen\nAND is_training\nAND (sm < 9.0 OR bias_grad OR cuDNN < 8.9.5)?"}
    G -- Yes --> H["Disable FusedAttention\n(F16 legacy determinism)"]
    G -- No --> I{"backend == F16_arbitrary_seqlen\nAND is_training\nAND 9.7.0 ≤ cuDNN < 9.18.1\nAND sm ≥ 10.0 (Blackwell)?"}
    I -- Yes --> J["Disable FusedAttention\n🆕 Blackwell + cuDNN bug filter"]
    I -- No --> Z
    D --> Z
    F --> Z
    H --> Z
    J --> Z

Last reviewed commit: "update the error log"

Comment on lines +1093 to +1103
        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and deterministic
            and (cudnn_version >= (9, 7) and cudnn_version < (9, 18, 1))
            and device_compute_capability >= (10, 0)
        ):
            logger.debug(
                "Determinism not supported on Blackwell for BF16 with 9.7 <= cuDNN < 9.18.1"
            )
            use_fused_attention = False
            fused_attention_backend = None
Contributor

P1: Missing is_training guard; this may incorrectly disable FusedAttention for inference

Every other determinism filter in this same if use_fused_attention and deterministic: block guards on is_training (see lines 1070–1080 for FP8 and 1081–1092 for F16_arbitrary_seqlen), indicating that those non-determinism issues are backward-pass specific. The new Blackwell/cuDNN-version filter does not include and is_training, so it will also disable FusedAttention during deterministic inference on Blackwell GPUs with cuDNN in [9.7.0, 9.18.1).

If the cuDNN bug only manifests during training (backward pass), the filter is overly broad and will unnecessarily fall back to a slower backend during inference. If it truly affects the forward pass as well, a comment explaining that would help reviewers and future maintainers understand the deviation from the existing pattern.

Consider either:

        if (
            fused_attention_backend == FusedAttnBackend["F16_arbitrary_seqlen"]
            and is_training
            and (cudnn_version >= (9, 7) and cudnn_version < (9, 18, 1))
            and device_compute_capability >= (10, 0)
        ):

or, if inference is also affected, add a comment explaining why is_training is deliberately omitted.
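As an illustration only, the suggested guard can be factored into a standalone predicate to make the is_training scoping explicit. All names below are hypothetical; they mirror the variables in get_attention_backend() but this is a sketch, not the actual implementation:

```python
# Hypothetical standalone version of the suggested guard. Parameter names
# mirror the variables in get_attention_backend(); this is only a sketch.
def blackwell_determinism_filter_applies(
    backend: str,
    is_training: bool,
    cudnn_version: tuple,
    device_compute_capability: tuple,
) -> bool:
    """True when FusedAttention should be disabled for determinism."""
    return (
        backend == "F16_arbitrary_seqlen"
        and is_training  # backward-pass specific, per the review comment
        and (9, 7) <= cudnn_version < (9, 18, 1)
        and device_compute_capability >= (10, 0)  # Blackwell and newer
    )

# Training on Blackwell with an affected cuDNN: filter fires.
print(blackwell_determinism_filter_applies(
    "F16_arbitrary_seqlen", True, (9, 10, 0), (10, 0)))   # True
# Same configuration at inference time: filter stays off.
print(blackwell_determinism_filter_applies(
    "F16_arbitrary_seqlen", False, (9, 10, 0), (10, 0)))  # False
```

Factoring the condition out this way also makes the guard trivially unit-testable, which partially addresses the "determinism filters are hard to test" caveat from the review summary.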

Collaborator Author

updated, check again
