[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD#2596
sudhakarsingh27 wants to merge 11 commits into `NVIDIA:main` from `flash_attn_pad_bw_seqs`
Conversation
Greptile Summary

This PR enables `pad_between_seqs` support for A2A and P2P context parallelism with FA3 + THD.

One logic concern: in `cp_p2p_fwd_flash_attn`, the global `cu_seqlens_q_padded` is passed to FA3 where rank-local offsets are expected, so it mismatches the local `seqused_q` for batch > 1 (see the ⚠️ note in the sequence diagram below).

Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant User as User / DotProductAttention
    participant FA as FlashAttention.forward
    participant CP as attn_forward_func_with_cp
    participant P2P as AttnFuncWithCPAndKVP2P
    participant PrepQKV as cp_p2p_fwd_prepare_qkv
    participant FlashFwd as cp_p2p_fwd_flash_attn (FA3)
    participant A2A as AttnFuncWithCPAndQKVOA2A

    User->>FA: q, k, v, cu_seqlens_q, cu_seqlens_q_padded, pad_between_seqs=True
    FA->>CP: forward args (global cu_seqlens_q_padded when pad_between_seqs)
    CP->>P2P: apply(cu_seqlens_q [global actual], cu_seqlens_q_padded [global padded])
    P2P->>PrepQKV: get_cu_seqlens_on_cp_rank(cu_seqlens_q, cu_seqlens_q_padded, rank)
    PrepQKV-->>P2P: cu_seqlens_q_per_step [local actual]
    P2P->>FlashFwd: flash_attn_inputs (cu_seqlens_q_padded=GLOBAL), prepare_outputs (cu_seqlens_q_per_step=LOCAL)
    Note over FlashFwd: seqused_q = per_step[1:] - per_step[:-1] (local actual)<br/>cu_seqlens_q_ = cu_seqlens_q_padded (GLOBAL ⚠️ should be LOCAL)
    FlashFwd->>FlashFwd: FA3 forward with seqused_q + cu_seqlens_q_ (mismatch for batch>1)
    CP->>A2A: apply(cu_seqlens_q [global actual], cu_seqlens_q_padded [global padded], pad_between_seqs)
    Note over A2A: seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]<br/>fa_cu_seqlens_q = cu_seqlens_q_padded (local context, ✓)
    A2A->>A2A: FA3 forward with seqused_q + padded cu_seqlens
```
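A minimal numeric illustration of the ⚠️ note above, using hypothetical numbers and plain Python lists in place of the tensors, and assuming for illustration that each CP rank holds `padded_len / cp_size` tokens of every sequence:

```python
# Hypothetical setup: batch of 2 sequences, each padded to 8 tokens globally.
cp_size = 2
cu_seqlens_q_padded_global = [0, 8, 16]

# Each rank stores 8 // cp_size = 4 tokens per sequence, so the rank-local
# padded offsets (what FA3 should index the local buffer with) are:
cu_seqlens_q_padded_local = [x // cp_size for x in cu_seqlens_q_padded_global]
print(cu_seqlens_q_padded_local)  # [0, 4, 8] -> local buffer holds 8 tokens

# With the GLOBAL offsets instead, FA3 would start sequence 1 at token 8,
# i.e. at the end of the 8-token local buffer -- hence the batch>1 mismatch.
print(cu_seqlens_q_padded_global[1])  # 8
```

For batch size 1 the two offset vectors only differ by the final entry, which FA3 never uses as a start offset, which is why the mismatch would only surface for batch > 1.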
```python
# if `pad_between_seqs` is True, provide flash_attn_3 with `seqused_q` and `seqused_k`
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
# padding positions.
if pad_between_seqs:
    fa_3_optional_forward_kwargs["seqused_q"] = (
        cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    )
    fa_3_optional_forward_kwargs["seqused_k"] = (
        cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]
    )
```
style: verify that flash_attn_3 with `seqused_q`/`seqused_k` truly avoids writing to padding positions. The related issue #2391 mentions "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes output in C++ for THD format). If flash_attn_3 doesn't zero these internally, the output may have garbage values in padded positions. Have you verified that flash_attn_3 correctly handles padding internally with these parameters?
## TLDR
Enable `pad_between_seqs=True` for A2A and P2P context parallelism comm types
with FlashAttention 3 and THD format. Previously `pad_between_seqs` was only
supported with FusedAttention.
## Problem
When using THD format with variable-length sequences, sequences are padded for
divisibility across CP ranks. With `pad_between_seqs=True`, the attention kernel
needs to know actual (unpadded) token counts so it doesn't compute attention over
padding tokens. FusedAttention already handled this via `cu_seqlens_q_padded`, but
FlashAttention (both FA2 and FA3) had `pad_between_seqs` hardcoded to `False` in
the CP path, and FA2 was entirely disabled for `pad_between_seqs + thd`. FA3 can
natively handle this via its `seqused_q`/`seqused_k` mechanism.
## Solution
Use FA3's `seqused_q`/`seqused_k` tensors to communicate actual token counts per
batch element. Pass `cu_seqlens_q_padded` for tensor memory layout while deriving
`seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]` from the real `cu_seqlens`.
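Concretely, the arithmetic looks like this (illustrative lengths, plain Python prefix sums standing in for the `cu_seqlens` tensors):

```python
# Three sequences with actual lengths 5, 3, 7, each padded up to a
# multiple of 4 in the THD buffer (padded lengths 8, 4, 8).
actual_lens = [5, 3, 7]
padded_lens = [8, 4, 8]

def cumsum0(lens):
    """Exclusive prefix sum, the same layout cu_seqlens tensors use."""
    out = [0]
    for n in lens:
        out.append(out[-1] + n)
    return out

cu_seqlens_q = cumsum0(actual_lens)          # [0, 5, 8, 15]
cu_seqlens_q_padded = cumsum0(padded_lens)   # [0, 8, 12, 20]

# FA3 receives the padded offsets for memory layout, and the actual
# per-sequence token counts via seqused_q:
seqused_q = [b - a for a, b in zip(cu_seqlens_q[:-1], cu_seqlens_q[1:])]
print(seqused_q)  # [5, 3, 7]
```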
## Changes
### context_parallel.py
- `get_fa_args()`: Add `seqused_q`/`seqused_k` parameters, pass through to FA3
forward and backward positional arg lists (replacing hardcoded `None`s).
- `cp_p2p_fwd_flash_attn()` / `cp_p2p_bwd_flash_attn()`: Accept `pad_between_seqs`,
`cu_seqlens_q_padded`, `cu_seqlens_kv_padded`. When enabled, derive `seqused`
tensors and override `cu_seqlens` to padded versions (with half-padding for
lower-triangle/upper-triangle sections).
- `AttnFuncWithCPAndKVP2P`: Thread `pad_between_seqs` and padded cu_seqlens
through all forward/backward `cp_p2p_fwd/bwd_flash_attn` call sites. Save
`ctx.pad_between_seqs` for backward.
- `AttnFuncWithCPAndQKVOA2A.forward()`: Add `pad_between_seqs` parameter. When
enabled with FA3+THD, derive `seqused` and swap `cu_seqlens` for padded versions
before calling `get_fa_args()`.
- `AttnFuncWithCPAndQKVOA2A.backward()`: Same seqused/cu_seqlens override.
Use `zeros_like` (not `empty_like`) for gradient init when `pad_between_seqs`
since FA3 skips padding positions. Add extra `None` in return tuple for the
new `pad_between_seqs` gradient slot.
- `attn_forward_func_with_cp()`: Pass `pad_between_seqs` in A2A args list.
### backends.py
- `FlashAttention.forward()`: Accept `cu_seqlens_q_padded`/`cu_seqlens_kv_padded`.
Detect `pad_between_seqs` by comparing padded vs actual cu_seqlens. Pass padded
cu_seqlens to CP path. For non-CP FA3 path, derive and pass `seqused_q`/`seqused_k`.
### dot_product_attention.py
- Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` through to `FlashAttention`.
### utils.py
- Only disable FA2 (not FA3) when `pad_between_seqs + thd`. FA3 handles this
natively via `seqused`.
### test_attention_with_cp.py
- Add `@pytest.mark.parametrize("pad_between_seqs", [False, True])` to flash
attention CP tests.
- Skip `pad_between_seqs=True` for non-THD formats, when FA3 is not installed,
and for `a2a+p2p` comm type (not yet supported).
### run_attention_with_cp.py
- Thread `pad_between_seqs` through `generate_input_shapes()` and `run_dpa_with_cp()`.
- When `pad_between_seqs`, set `cu_seqlens_q` to actual lengths (not just for
FusedAttention).
- Handle FA3 backward NaN at padding positions: `nan_to_num(nan=0.0)`.
- Zero padding positions explicitly before comparison (FA3 doesn't guarantee zeros
at padding slots).
- Add tensor names to NaN/Inf assertion messages for debuggability.
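A sketch of what the NaN and padding fixes in the test harness amount to (hypothetical helper name, plain Python lists standing in for tensors): replace NaNs, then zero everything between each sequence's actual end and its padded end before comparing outputs.

```python
import math

def sanitize_thd_output(out, cu_seqlens_padded, seqused):
    """Zero padding slots (and any NaNs FA3 left there) in a flat THD output."""
    # equivalent of nan_to_num(nan=0.0) for a plain list
    out = [0.0 if math.isnan(x) else x for x in out]
    for i, used in enumerate(seqused):
        start, end = cu_seqlens_padded[i], cu_seqlens_padded[i + 1]
        for t in range(start + used, end):  # padding region of sequence i
            out[t] = 0.0
    return out

# Two sequences padded to 4 tokens each; actual lengths 3 and 2.
out = [1.0, 2.0, 3.0, float("nan"), 5.0, 6.0, float("nan"), 8.0]
print(sanitize_thd_output(out, [0, 4, 8], [3, 2]))
# [1.0, 2.0, 3.0, 0.0, 5.0, 6.0, 0.0, 0.0]
```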
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
/te-ci pytorch L2
Enable FlashAttention backend in test_attention.py to use padded cu_seqlens and `pad_between_seqs` parameter, matching FusedAttention's test path. FA3 natively supports `pad_between_seqs` via `seqused_q`/`seqused_k`.

- Group FlashAttention with FusedAttention for padded input/output handling
- Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` for FlashAttention backend
- Pass `pad_between_seqs` to DPA call
- Add `pad_between_seqs=True` to parametrize with thd-only skip

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (outdated, resolved)
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py (outdated, resolved)
```python
pad_between_seqs = False
if qkv_format == "thd" and cu_seqlens_q_padded is not None:
    pad_between_seqs = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
```
Can `pad_between_seqs` be decided ahead of time, passed by the user or something? This wouldn't be CUDA Graph-compatible, right?
This pattern exists in `dpa.py` as well. But yes, it's definitely redundant here.
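One way to sidestep the data-dependent check (a hypothetical sketch, not what the PR implements; helper name and arguments are invented): accept an explicit `pad_between_seqs` argument and only fall back to comparing the two `cu_seqlens` tensors when the caller leaves it unset. Graph-captured paths can then pass the flag and skip the host-side comparison entirely.

```python
def resolve_pad_between_seqs(pad_between_seqs, qkv_format, cu_seqlens, cu_seqlens_padded):
    """Hypothetical helper: prefer an explicit caller-provided flag.

    The tensor comparison forces a host-side, data-dependent branch, which is
    what makes the original detection CUDA-Graph-unfriendly; it is only used
    as a fallback when the caller passed None.
    """
    if pad_between_seqs is not None:
        return pad_between_seqs
    if qkv_format == "thd" and cu_seqlens_padded is not None:
        # fallback: mirrors `not torch.equal(...)` on plain lists
        return cu_seqlens_padded != cu_seqlens
    return False

print(resolve_pad_between_seqs(True, "thd", [0, 5], [0, 8]))   # True (explicit)
print(resolve_pad_between_seqs(None, "thd", [0, 5], [0, 8]))   # True (fallback)
print(resolve_pad_between_seqs(None, "thd", [0, 8], [0, 8]))   # False
```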
/te-ci pytorch L1
# Description

## TLDR

Enable `pad_between_seqs=True` for FlashAttention 3 with THD format, both for context parallelism (A2A and P2P comm types) and non-CP paths. Previously `pad_between_seqs` was only supported with FusedAttention.

## Problem

When using THD format with variable-length sequences, sequences are padded for divisibility across CP ranks. With `pad_between_seqs=True`, the attention kernel needs to know actual (unpadded) token counts so it doesn't compute attention over padding tokens. FusedAttention already handled this via `cu_seqlens_q_padded`, but FlashAttention (both FA2 and FA3) had `pad_between_seqs` hardcoded to `False` in the CP path, and FA2 was entirely disabled for `pad_between_seqs + thd`. FA3 can natively handle this via its `seqused_q`/`seqused_k` mechanism.

## Solution

Use FA3's `seqused_q`/`seqused_k` tensors to communicate actual token counts per batch element. Pass `cu_seqlens_q_padded` for tensor memory layout while deriving `seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]` from the real `cu_seqlens`. This applies to both the CP path (A2A and P2P) and the non-CP path.

Fixes #2399

## Changes

### context_parallel.py

- `get_fa_args()`: Add `seqused_q`/`seqused_k` parameters, pass through to FA3 forward and backward positional arg lists (replacing hardcoded `None`s).
- `cp_p2p_fwd_flash_attn()` / `cp_p2p_bwd_flash_attn()`: Accept `pad_between_seqs`, `cu_seqlens_q_padded`, `cu_seqlens_kv_padded`. When enabled, derive `seqused` tensors and override `cu_seqlens` to padded versions (with half-padding for lower-triangle/upper-triangle sections).
- `AttnFuncWithCPAndKVP2P`: Thread `pad_between_seqs` and padded cu_seqlens through all forward/backward `cp_p2p_fwd/bwd_flash_attn` call sites. Save `ctx.pad_between_seqs` for backward.
- `AttnFuncWithCPAndQKVOA2A.forward()`: Add `pad_between_seqs` parameter. When enabled with FA3+THD, derive `seqused` and swap `cu_seqlens` for padded versions before calling `get_fa_args()`.
- `AttnFuncWithCPAndQKVOA2A.backward()`: Same seqused/cu_seqlens override. Use `zeros_like` (not `empty_like`) for gradient init when `pad_between_seqs` since FA3 skips padding positions. Add extra `None` in return tuple for the new `pad_between_seqs` gradient slot.
- `attn_forward_func_with_cp()`: Pass `pad_between_seqs` in A2A args list.

### backends.py

- `FlashAttention.forward()`: Accept `cu_seqlens_q_padded`/`cu_seqlens_kv_padded`. Detect `pad_between_seqs` by comparing padded vs actual cu_seqlens. Pass padded cu_seqlens to CP path. For non-CP FA3 path, derive and pass `seqused_q`/`seqused_k`.

### dot_product_attention.py

- Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` through to `FlashAttention`.

### utils.py

- Only disable FA2 (not FA3) when `pad_between_seqs + thd`. FA3 handles this natively via `seqused`.

### test_attention_with_cp.py

- Add `@pytest.mark.parametrize("pad_between_seqs", [False, True])` to flash attention CP tests.
- Skip `pad_between_seqs=True` for non-THD formats, when FA3 is not installed, and for `a2a+p2p` comm type (not yet supported).

### run_attention_with_cp.py

- Thread `pad_between_seqs` through `generate_input_shapes()` and `run_dpa_with_cp()`.
- When `pad_between_seqs`, set `cu_seqlens_q` to actual lengths (not just for FusedAttention).
- Handle FA3 backward NaN at padding positions: `nan_to_num(nan=0.0)`.
- Zero padding positions explicitly before comparison (FA3 doesn't guarantee zeros at padding slots).

### test_attention.py

- Use padded inputs/cu_seqlens for FlashAttention in `_run_dot_product_attention()` (previously FlashAttention used original unpadded inputs).
- Pass `cu_seqlens_q_padded`/`cu_seqlens_kv_padded` and `pad_between_seqs` to DPA call for FlashAttention backend.
- Add `pad_between_seqs=True` to parametrize with skip for non-THD formats.

## New Tests

### CP tests (`test_attention_with_cp.py`)

Added `@pytest.mark.parametrize("pad_between_seqs", [False, True])` to `test_cp_with_flash_attention`. Skip conditions: non-THD formats, FA3 not installed, `a2a+p2p` comm type.

5 new tests that run (all `pad_between_seqs=True`, thd, bf16):

- `True-p2p-thd-cp_1_0-bf16`
- `True-p2p-thd-cp_2_1-bf16`
- `True-a2a-thd-cp_1_0-bf16`
- `True-a2a-thd-cp_1_2-bf16`
- `True-a2a-thd-cp_2_1-bf16`

### Non-CP tests (`test_attention.py`)

Added `True` to `@pytest.mark.parametrize("pad_between_seqs", [False, True])` on `test_dot_product_attention`, with skip for non-THD. Also changed `_run_dot_product_attention` so FlashAttention uses padded inputs/cu_seqlens and receives `pad_between_seqs=True`.

48 new test IDs collected, but all are skipped because the main parametrize uses `qkv_layout=None` (defaults to sbhd, not thd). The non-CP `pad_between_seqs` + FA3 code path is exercised indirectly when other test functions call `test_dot_product_attention` with `qkv_layout="thd_thd_thd"` (e.g., `test_dpa_softmax_thd`).