
[PyTorch] Add FA4 Support #2432

Open
yaox12 wants to merge 2 commits into NVIDIA:main from yaox12:xiny/fa4

Conversation

Member

@yaox12 yaox12 commented Nov 28, 2025

Description

  • Add FA4 support
  • Add tests

Help is needed to install flash-attn-4 in the CI container so the FA4 tests can run.

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12 yaox12 marked this pull request as ready for review March 19, 2026 22:15
Contributor

greptile-apps bot commented Mar 19, 2026

Greptile Summary

This PR integrates Flash Attention v4 (flash-attn-4) as a new attention backend alongside the existing FA2/FA3 paths. It adds FA4-specific backend detection and filtering in get_attention_backend (compute capability, dtype, FP8, KV cache, head-dim, dropout, context parallelism, ALiBi guards), a new FA4 forward dispatch in FlashAttention.forward(), and comprehensive tests covering base configs, MLA (mixed head dims), sliding window, variable-length sequences, and attention mask types.

Key observations:

  • The FA4/FA3 version check in backends.py is correctly made mutually exclusive with a bounded range check (3.0.0b < version < 4.0.0 for FA3), addressing potential dispatch ambiguity.
  • set_flash_attention_4_params only sets v4_is_installed = True without validating a minimum version (unlike FA3's v3_0_0_beta check), which can silently fall back to FA2 when a pre-release FA4 beta is installed.
  • After FA4 is explicitly disabled for context parallelism, the immediately following if context_parallel condition still references use_flash_attention_4 and FlashAttentionUtils.v4_is_installed, which at that point are unreachable dead code.
  • The SM100 backward kernel MLA bug workaround (dK_reduce_ncol misalignment) is well-documented and has a dedicated test case (fa4_mla_4).
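The missing version validation flagged above can be sketched as follows. This is a hypothetical guard (function name and minimum version string are illustrative, not Transformer Engine's actual API) that mirrors the intent of FA3's v3_0_0_beta check, using the third-party packaging library for PEP 440 comparisons:

```python
from packaging.version import Version

# Hypothetical guard: reject pre-release FA4 builds below a chosen
# minimum instead of silently setting v4_is_installed = True and
# later falling back to FA2. The "4.0.0b1" floor is an assumption.
def fa4_meets_min_version(installed: str, minimum: str = "4.0.0b1") -> bool:
    return Version(installed) >= Version(minimum)
```

With such a check, an old beta install would be rejected up front, making the fallback to FA2 explicit rather than silent.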

Confidence Score: 4/5

  • This PR is safe to merge with minor code quality issues; no functional correctness bugs in mainline paths.
  • The FA4 dispatch logic is correct and well-guarded, the FA3/FA4 mutual exclusion fix prevents double-dispatch, and the test coverage is thorough. The two issues found are a dead-code reference in the CP block (harmless but misleading) and a missing version validation in set_flash_attention_4_params (only impacts pre-release FA4 installs, not production). The PR is not yet fully testable in CI (FA4 not installed), so runtime validation is deferred.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py — dead code in CP condition and missing version validation in set_flash_attention_4_params.
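The dead-code issue in the CP condition can be reproduced in miniature. The sketch below uses the variable names from the review description but simplifies everything around them; it is not the actual utils.py logic:

```python
# Minimal reproduction of the flagged pattern: once the first branch
# clears the flag, the second condition can never be satisfied.
def cp_filter(context_parallel: bool, use_flash_attention_4: bool):
    reached_dead_branch = False
    if context_parallel:
        # FA4 is explicitly disabled for context parallelism here.
        use_flash_attention_4 = False
    if context_parallel and use_flash_attention_4:
        # Unreachable: use_flash_attention_4 is always False by now.
        reached_dead_branch = True
    return use_flash_attention_4, reached_dead_branch
```

For every input combination, `reached_dead_branch` stays False, which is why the second condition's FA4 references are dead code.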

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/backends.py Adds FA4 import block, version detection, and dispatch logic in FlashAttention.forward(). FA4/FA3 version check is now correctly made mutually exclusive (bounds check added for FA3 range). The FA4 dispatch correctly prioritizes FA4 over FA3/FA2. No critical issues in the dispatch path; the causal_bottom_right handling is acknowledged in existing review threads.
transformer_engine/pytorch/attention/dot_product_attention/utils.py Adds FA4 filtering throughout get_attention_backend(): compute capability, dtype, FP8, KV cache, head dim, padding, dropout, CP, ALiBi, and arbitrary mask guards. Two issues: (1) dead code — use_flash_attention_4 in the second context_parallel condition is always False after the preceding FA4/CP disable block; (2) set_flash_attention_4_params does not validate the FA4 version, risking silent fallback to FA2 for pre-release FA4 installs.
tests/pytorch/attention/test_attention.py Adds six new FA4 test groups (base, MLA, sliding window, varlen, mask types). Tests include a regression case for the SM100 backward kernel MLA bug (fa4_mla_4). sys.path prepend change ensures the local utils module is resolved before any installed package. No issues found.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[get_attention_backend] --> B{FA4 installed\nv4_is_installed?}
    B -- No --> C[use_flash_attention_4 = env default\nCleared later at line ~1268]
    B -- Yes --> D[Apply FA4 filters]
    D --> D1{compute cap < 8.0?} -- Yes --> D1x[disable FA4]
    D --> D2{dtype not fp16/bf16?} -- Yes --> D2x[disable FA4]
    D --> D3{FP8 DPA?} -- Yes --> D3x[disable FA4]
    D --> D4{inference_params\nKV cache?} -- Yes --> D4x[disable FA4]
    D --> D5{head dims\nunsupported?} -- Yes --> D5x[disable FA4]
    D --> D6{SM100 MLA\nbug?} -- Yes --> D6x[disable FA4 training]
    D --> D7{dropout != 0?} -- Yes --> D7x[disable FA4]
    D --> D8{context parallel?} -- Yes --> D8x[disable FA4]
    D --> D9{ALiBi?} -- Yes --> D9x[disable FA4]

    D1 & D2 & D3 & D4 & D5 & D6 & D7 & D8 & D9 --> E{FA4 still\nenabled?}
    E -- Yes --> F[flash_attention_backend =\nfa4_version]
    E -- No --> G{FA3 enabled\nfa3_version in\n3.0.0b..4.0.0?}
    G -- Yes --> H[flash_attention_backend =\nfa3_version]
    G -- No --> I[flash_attention_backend =\nFA2 version]

    F & H & I --> J[FlashAttention.forward]
    J --> K{flash_attention_backend\n> 4.0.0b?}
    K -- Yes --> L[use_flash_attn_4 = True\ndispatch to\nflash_attn_func_v4 /\nflash_attn_varlen_func_v4]
    K -- No --> M{3.0.0b < backend\n< 4.0.0?}
    M -- Yes --> N[use_flash_attn_3 = True\ndispatch to FA3]
    M -- No --> O[dispatch to FA2]

Last reviewed commit: "fa4 support"

Comment on lines +1037 to +1042
output = func(
    query_layer,
    key_layer,
    value_layer,
    softmax_scale=self.softmax_scale,
    causal="causal" in attn_mask_type,

P2 causal_bottom_right treated identically to causal for FA4

causal="causal" in attn_mask_type evaluates to True for both "causal" and "causal_bottom_right". If FA4's flash_attn_func supports a separate bottom-right diagonal alignment flag (similar to how cuDNN fused attention distinguishes the two), passing only causal=True would produce incorrect results for causal_bottom_right configs.

This is consistent with the existing FA2 path, but since fa4_mask_causal_br is explicitly added as a test case, it is worth verifying that the FA4 causal parameter correctly implements both variants, or adding a dedicated causal_bottom_right kwarg if the FA4 API exposes one.
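The behavior under review is easy to demonstrate in isolation. The helper below is purely illustrative; it just evaluates the same substring expression quoted in the comment:

```python
# The reviewed expression is a substring test, so both "causal" and
# "causal_bottom_right" (and "padding_causal") map to causal=True;
# any bottom-right-specific FA4 flag, if one exists, is never set.
def fa4_causal_flag(attn_mask_type: str) -> bool:
    return "causal" in attn_mask_type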

@yaox12 yaox12 force-pushed the xiny/fa4 branch 3 times, most recently from 1af4a5c to 0708391 on March 19, 2026 at 22:51
Signed-off-by: Xin Yao <xiny@nvidia.com>
@yaox12
Copy link
Member Author

yaox12 commented Mar 19, 2026

/te-ci pytorch

@KshitijLakhani KshitijLakhani requested a review from mk-61 March 19, 2026 23:28