Conversation
Signed-off-by: Xin Yao <xiny@nvidia.com>
Greptile Summary

This PR integrates Flash Attention v4 support into the dot product attention backend selection and dispatch logic.
Confidence Score: 4/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[get_attention_backend] --> B{FA4 installed\nv4_is_installed?}
B -- No --> C[use_flash_attention_4 = env default\nCleared later at line ~1268]
B -- Yes --> D[Apply FA4 filters]
D --> D1{compute cap < 8.0?} -- Yes --> D1x[disable FA4]
D --> D2{dtype not fp16/bf16?} -- Yes --> D2x[disable FA4]
D --> D3{FP8 DPA?} -- Yes --> D3x[disable FA4]
D --> D4{inference_params\nKV cache?} -- Yes --> D4x[disable FA4]
D --> D5{head dims\nunsupported?} -- Yes --> D5x[disable FA4]
D --> D6{SM100 MLA\nbug?} -- Yes --> D6x[disable FA4 training]
D --> D7{dropout != 0?} -- Yes --> D7x[disable FA4]
D --> D8{context parallel?} -- Yes --> D8x[disable FA4]
D --> D9{ALiBi?} -- Yes --> D9x[disable FA4]
D1 & D2 & D3 & D4 & D5 & D6 & D7 & D8 & D9 --> E{FA4 still\nenabled?}
E -- Yes --> F[flash_attention_backend =\nfa4_version]
E -- No --> G{FA3 enabled\nfa3_version in\n3.0.0b..4.0.0?}
G -- Yes --> H[flash_attention_backend =\nfa3_version]
G -- No --> I[flash_attention_backend =\nFA2 version]
F & H & I --> J[FlashAttention.forward]
J --> K{flash_attention_backend\n> 4.0.0b?}
K -- Yes --> L[use_flash_attn_4 = True\ndispatch to\nflash_attn_func_v4 /\nflash_attn_varlen_func_v4]
K -- No --> M{3.0.0b < backend\n< 4.0.0?}
M -- Yes --> N[use_flash_attn_3 = True\ndispatch to FA3]
M -- No --> O[dispatch to FA2]
```
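The filter cascade and version dispatch in the chart can be sketched as plain Python. This is an illustrative reconstruction only: `AttnConfig`, `select_flash_backend`, and `dispatch` are hypothetical names, and the string return values stand in for the real version objects Transformer Engine uses.

```python
from dataclasses import dataclass


@dataclass
class AttnConfig:
    # Illustrative subset of the knobs get_attention_backend inspects.
    compute_capability: float = 9.0
    dtype: str = "bf16"
    fp8_dpa: bool = False
    has_kv_cache: bool = False       # inference_params / KV cache present
    head_dim_supported: bool = True
    dropout: float = 0.0
    context_parallel: bool = False
    alibi: bool = False


def select_flash_backend(cfg: AttnConfig, fa4_installed: bool, fa3_installed: bool) -> str:
    """Mirror of the FA4 filter cascade in the flowchart (sketch, not the real API)."""
    use_fa4 = fa4_installed
    if use_fa4 and cfg.compute_capability < 8.0:
        use_fa4 = False
    if use_fa4 and cfg.dtype not in ("fp16", "bf16"):
        use_fa4 = False
    if use_fa4 and (cfg.fp8_dpa or cfg.has_kv_cache or not cfg.head_dim_supported):
        use_fa4 = False
    if use_fa4 and (cfg.dropout != 0.0 or cfg.context_parallel or cfg.alibi):
        use_fa4 = False
    if use_fa4:
        return "FA4"        # backend version > 4.0.0b -> flash_attn_func_v4
    if fa3_installed:
        return "FA3"        # 3.0.0b < version < 4.0.0 -> FA3 path
    return "FA2"            # fallback
```

Note the ordering: FA4 is only an addition on top of the existing FA3/FA2 fallback chain, so any single disqualifying condition silently degrades to the previous behavior.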
Last reviewed commit: "fa4 support"
```python
output = func(
    query_layer,
    key_layer,
    value_layer,
    softmax_scale=self.softmax_scale,
    causal="causal" in attn_mask_type,
```
`causal_bottom_right` treated identically to `causal` for FA4

`causal="causal" in attn_mask_type` evaluates to `True` for both `"causal"` and `"causal_bottom_right"`. If FA4's `flash_attn_func` supports a separate bottom-right diagonal alignment flag (similar to how cuDNN fused attention distinguishes the two), passing only `causal=True` would produce incorrect results for `causal_bottom_right` configs.

This is consistent with the existing FA2 path, but since `fa4_mask_causal_br` is explicitly added as a test case, it is worth verifying that the FA4 `causal` parameter correctly implements both variants, or adding a dedicated `causal_bottom_right` kwarg if the FA4 API exposes one.
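For reference, the two variants differ only in where the causal diagonal is anchored when `seqlen_q != seqlen_kv`. A minimal illustration in plain Python (not the FA4 API; `causal_mask` is a hypothetical helper):

```python
def causal_mask(sq: int, skv: int, bottom_right: bool = False) -> list[list[bool]]:
    """True = query position q may attend to key position kv.

    Top-left causal anchors the diagonal at (0, 0); bottom-right causal
    anchors it so the last query row sees the full KV sequence.
    """
    offset = (skv - sq) if bottom_right else 0
    return [[kv <= q + offset for kv in range(skv)] for q in range(sq)]


# With 2 queries attending over 4 keys:
top_left = causal_mask(2, 4)                          # rows see 1 and 2 keys
bottom_right = causal_mask(2, 4, bottom_right=True)   # rows see 3 and 4 keys
```

The two masks coincide only when `seqlen_q == seqlen_kv`, which is why a backend that ignores the distinction can still pass equal-length tests while silently mis-masking cross-attention or KV-cache shapes.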
Force-pushed from 1af4a5c to 0708391
Signed-off-by: Xin Yao <xiny@nvidia.com>
/te-ci pytorch
Description
Need help to install `flash-attn-4` in the CI container to enable FA4 tests.

Type of change
Checklist: