
[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD #2596

Open
sudhakarsingh27 wants to merge 11 commits into NVIDIA:main from sudhakarsingh27:flash_attn_pad_bw_seqs

Conversation


@sudhakarsingh27 sudhakarsingh27 commented Jan 14, 2026

Description

TLDR

Enable pad_between_seqs=True for FlashAttention 3 with THD format — both for context parallelism (A2A and P2P comm types) and non-CP paths. Previously pad_between_seqs was only supported with FusedAttention.

Problem

When using THD format with variable-length sequences, sequences are padded for divisibility across CP ranks. With pad_between_seqs=True, the attention kernel needs to know actual (unpadded) token counts so it doesn't compute attention over padding tokens. FusedAttention already handled this via cu_seqlens_q_padded, but FlashAttention (both FA2 and FA3) had pad_between_seqs hardcoded to False in the CP path, and FA2 was entirely disabled for pad_between_seqs + thd. FA3 can natively handle this via its seqused_q/seqused_k mechanism.

Solution

Use FA3's seqused_q/seqused_k tensors to communicate actual token counts per batch element. Pass cu_seqlens_q_padded for tensor memory layout while deriving seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1] from the real cu_seqlens. This applies to both the CP path (A2A and P2P) and the non-CP path.

Fixes #2399

Type of change

  • New feature (non-breaking change which adds functionality)

Changes


context_parallel.py

  • get_fa_args(): Add seqused_q/seqused_k parameters, pass through to FA3 forward and backward positional arg lists (replacing hardcoded Nones).
  • cp_p2p_fwd_flash_attn() / cp_p2p_bwd_flash_attn(): Accept pad_between_seqs, cu_seqlens_q_padded, cu_seqlens_kv_padded. When enabled, derive seqused tensors and override cu_seqlens to padded versions (with half-padding for lower-triangle/upper-triangle sections).
  • AttnFuncWithCPAndKVP2P: Thread pad_between_seqs and padded cu_seqlens through all forward/backward cp_p2p_fwd/bwd_flash_attn call sites. Save ctx.pad_between_seqs for backward.
  • AttnFuncWithCPAndQKVOA2A.forward(): Add pad_between_seqs parameter. When enabled with FA3+THD, derive seqused and swap cu_seqlens for padded versions before calling get_fa_args().
  • AttnFuncWithCPAndQKVOA2A.backward(): Same seqused/cu_seqlens override. Use zeros_like (not empty_like) for gradient init when pad_between_seqs since FA3 skips padding positions. Add extra None in return tuple for the new pad_between_seqs gradient slot.
  • attn_forward_func_with_cp(): Pass pad_between_seqs in A2A args list.
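The `zeros_like` vs. `empty_like` point can be illustrated with a small sketch (shapes and padding positions are hypothetical, continuing the 12-padded-token example):

```python
import torch

# FA3 with seqused skips padding positions entirely, so a gradient buffer it
# writes into must already hold zeros at those positions.
q = torch.randn(12, 4)  # 12 padded tokens, hidden size 4 (illustrative)

dq_bad = torch.empty_like(q)   # padding rows keep arbitrary memory contents
dq_good = torch.zeros_like(q)  # padding rows are guaranteed to be zero

# If rows 5..7 and 11 are padding, FA3 leaves them untouched; only the
# zero-initialized buffer is safe to accumulate or compare afterwards.
assert torch.all(dq_good[5:8] == 0) and torch.all(dq_good[11] == 0)
```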

backends.py

  • FlashAttention.forward(): Accept cu_seqlens_q_padded/cu_seqlens_kv_padded. Detect pad_between_seqs by comparing padded vs actual cu_seqlens. Pass padded cu_seqlens to CP path. For non-CP FA3 path, derive and pass seqused_q/seqused_k.

dot_product_attention.py

  • Pass cu_seqlens_q_padded/cu_seqlens_kv_padded through to FlashAttention.

utils.py

  • Only disable FA2 (not FA3) when pad_between_seqs + thd. FA3 handles this natively via seqused.
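The narrowed filter can be sketched as follows; the function and flag names here are illustrative stand-ins, not the actual identifiers in utils.py:

```python
# Hypothetical sketch of the backend-filter change: only FA2 is ruled out
# for pad_between_seqs + thd, while FA3 stays eligible.
def filter_backends(pad_between_seqs: bool, qkv_format: str,
                    fa2_ok: bool = True, fa3_ok: bool = True):
    if pad_between_seqs and qkv_format == "thd":
        fa2_ok = False  # FA2 has no seqused_q/seqused_k mechanism
        # fa3_ok stays True: FA3 masks padding natively via seqused
    return fa2_ok, fa3_ok

assert filter_backends(True, "thd") == (False, True)
assert filter_backends(False, "thd") == (True, True)
```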

test_attention_with_cp.py

  • Add @pytest.mark.parametrize("pad_between_seqs", [False, True]) to flash attention CP tests.
  • Skip pad_between_seqs=True for non-THD formats, when FA3 is not installed, and for a2a+p2p comm type (not yet supported).

run_attention_with_cp.py

  • Thread pad_between_seqs through generate_input_shapes() and run_dpa_with_cp().
  • When pad_between_seqs, set cu_seqlens_q to actual lengths (not just for FusedAttention).
  • Handle FA3 backward NaN at padding positions: nan_to_num(nan=0.0).
  • Zero padding positions explicitly before comparison (FA3 doesn't guarantee zeros at padding slots).
  • Add tensor names to NaN/Inf assertion messages for debuggability.
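The test-side cleanup described above can be sketched like this (values are illustrative; FA3 guarantees neither zeros nor finite values at padding slots, so outputs are sanitized before comparison):

```python
import torch

# Two sequences: actual lengths 2 and 1, padded to 4 and 1 (illustrative).
out = torch.tensor([1.0, 2.0, float("nan"), float("nan"), 3.0])
cu_seqlens = torch.tensor([0, 2, 3])         # actual cumulative lengths
cu_seqlens_padded = torch.tensor([0, 4, 5])  # padded layout offsets

out = out.nan_to_num(nan=0.0)  # scrub NaNs FA3 may leave in the backward

# Explicitly zero every position past each sequence's actual length.
for b in range(len(cu_seqlens) - 1):
    start = cu_seqlens_padded[b] + (cu_seqlens[b + 1] - cu_seqlens[b])
    out[start : cu_seqlens_padded[b + 1]] = 0.0

assert out.tolist() == [1.0, 2.0, 0.0, 0.0, 3.0]
```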

test_attention.py

  • Group FlashAttention with FusedAttention for padded input/output handling in _run_dot_product_attention() (previously FlashAttention used original unpadded inputs).
  • Pass cu_seqlens_q_padded/cu_seqlens_kv_padded and pad_between_seqs to DPA call for FlashAttention backend.
  • Add pad_between_seqs=True to parametrize with skip for non-THD formats.

New Tests

CP tests (test_attention_with_cp.py)

Added @pytest.mark.parametrize("pad_between_seqs", [False, True]) to test_cp_with_flash_attention. Skip conditions: non-THD formats, FA3 not installed, a2a+p2p comm type.

5 new tests that run (all pad_between_seqs=True, thd, bf16):

Test ID                    CP comm   Model config
True-p2p-thd-cp_1_0-bf16   P2P       causal, 1 head
True-p2p-thd-cp_2_1-bf16   P2P       causal, 2 heads
True-a2a-thd-cp_1_0-bf16   A2A       causal, 1 head
True-a2a-thd-cp_1_2-bf16   A2A       causal, sliding window
True-a2a-thd-cp_2_1-bf16   A2A       causal, 2 heads

Non-CP tests (test_attention.py)

Added True to @pytest.mark.parametrize("pad_between_seqs", [False, True]) on test_dot_product_attention, with skip for non-THD. Also changed _run_dot_product_attention so FlashAttention uses padded inputs/cu_seqlens and receives pad_between_seqs=True.

48 new test IDs collected, but all are skipped because the main parametrize uses qkv_layout=None (defaults to sbhd, not thd). The non-CP pad_between_seqs + FA3 code path is exercised indirectly when other test functions call test_dot_product_attention with qkv_layout="thd_thd_thd" (e.g., test_dpa_softmax_thd).

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@sudhakarsingh27 sudhakarsingh27 self-assigned this Jan 14, 2026

greptile-apps bot commented Jan 14, 2026

Greptile Summary

This PR enables pad_between_seqs=True for FlashAttention 3 in THD format for both the CP (A2A and P2P) and non-CP paths. The approach is to pass seqused_q/seqused_k tensors (derived from actual cu_seqlens) to FA3 to mask out padding tokens, while providing the padded cu_seqlens for memory layout. FA2 is explicitly disabled for pad_between_seqs (as before), while FA3 is now enabled.

Key changes:

  • context_parallel.py: get_fa_args() gains seqused_q/seqused_k slots; P2P forward/backward and A2A forward/backward derive seqused from per-step actual cu_seqlens and pass padded cu_seqlens as memory layout; backward gradient init uses zeros_like instead of empty_like when pad_between_seqs=True to ensure unwritten padding positions stay zero.
  • backends.py: FlashAttention.forward accepts pad_between_seqs, cu_seqlens_q_padded, cu_seqlens_kv_padded; non-CP FA3 path adds seqused_q/seqused_k kwargs; CP path passes padded cu_seqlens for memory layout.
  • utils.py: Only FA2 is disabled for pad_between_seqs + thd; FA3 remains enabled.
  • Tests: run_attention_with_cp.py handles FA3's non-guarantee of zeroed padding positions by explicit zeroing before comparison; CP and non-CP test parametrize now includes pad_between_seqs=True.

One logic concern: in cp_p2p_fwd_flash_attn and cp_p2p_bwd_flash_attn, the cu_seqlens_q_padded placed in flash_attn_inputs (lines 1688–1689) is the global padded value (sum across all CP ranks), but q_part is the local per-rank tensor. FA3 uses cu_seqlens_q to index into the Q buffer, so for batch_size > 1 this may cause out-of-bounds access beyond the local tensor for batch elements after the first. The per-rank equivalent (cu_seqlens_q_padded // cp_size) should be used instead.
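The suspected mismatch can be shown numerically (illustrative values, cp_size=2, two sequences each padded to 8 tokens globally):

```python
import torch

# Global padded offsets across all CP ranks (illustrative).
cu_seqlens_q_padded = torch.tensor([0, 8, 16])
cp_size = 2
local_tokens = 16 // cp_size  # each rank holds an 8-token local Q buffer

# Using the GLOBAL offsets against the local buffer: batch element 1 would
# start reading at offset 8, i.e. exactly at the end of the local buffer,
# so any seqused_q[1] > 0 reads out of bounds.
assert cu_seqlens_q_padded[1].item() == local_tokens

# The per-rank equivalent keeps every start offset inside the local buffer.
cu_seqlens_q_padded_local = cu_seqlens_q_padded // cp_size
assert cu_seqlens_q_padded_local.tolist() == [0, 4, 8]
```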

Confidence Score: 2/5

  • Not safe to merge without investigating the global vs. per-rank cu_seqlens_q_padded issue in the P2P path.
  • The overall approach (using seqused_q/seqused_k for FA3) is well-designed and the A2A and non-CP paths look correct. However, in cp_p2p_fwd_flash_attn and cp_p2p_bwd_flash_attn, the global cu_seqlens_q_padded is used as the FA3 memory layout for the per-rank local Q tensor. For batch_size > 1, FA3 would compute incorrect memory offsets for batch elements after the first, potentially causing out-of-bounds CUDA memory access. While the 5 new tests appear to pass, they may not exercise the multi-batch element out-of-bounds scenario definitively. Until this is confirmed to be either correct (with explanation) or fixed, the confidence is low.
  • transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py — specifically the flash_attn_inputs construction in AttnFuncWithCPAndKVP2P.forward (lines ~1687–1689) and the cu_seqlens_q_ = cu_seqlens_q_padded assignment in cp_p2p_fwd_flash_attn / cp_p2p_bwd_flash_attn.

Important Files Changed

Filename Overview
transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py Core CP implementation: adds seqused_q/seqused_k support to P2P forward/backward and A2A forward/backward. The A2A path looks correct (uses local cu_seqlens_q before override). The P2P path has a potential bug: global cu_seqlens_q_padded is used as FA3's memory layout while q_part is a local (per-rank) tensor, creating a possible out-of-bounds access for batch elements after the first.
transformer_engine/pytorch/attention/dot_product_attention/backends.py Adds pad_between_seqs, cu_seqlens_q_padded, cu_seqlens_kv_padded parameters to FlashAttention.forward. Correctly wires padded cu_seqlens to the CP path and seqused_q/seqused_k kwargs to the non-CP FA3 path. Minor: seqused_q/seqused_k block missing a qkv_format == "thd" guard.
transformer_engine/pytorch/attention/dot_product_attention/utils.py Correctly narrows the FA disable block to only FA2 when pad_between_seqs=True, leaving FA3 enabled. The change is minimal and consistent with the rest of the filter logic.
tests/pytorch/attention/run_attention_with_cp.py Threads pad_between_seqs through test harness. Adds tensor names to NaN assertion messages, handles FA3 non-guarantee of zeroed padding via explicit zeroing before comparison, and switches from empty_like-style semantics to zeros_like for gradient init. Minor inefficiency: unconditional .clone() in the THD+training path applies even when pad_between_seqs=False.
tests/pytorch/attention/test_attention_with_cp.py Adds pad_between_seqs parametrize with correct skip conditions (non-THD, FA3 not installed, unsupported a2a+p2p comm type). Clean integration with existing subprocess-based test runner.

Sequence Diagram

sequenceDiagram
    participant User as User / DotProductAttention
    participant FA as FlashAttention.forward
    participant CP as attn_forward_func_with_cp
    participant P2P as AttnFuncWithCPAndKVP2P
    participant PrepQKV as cp_p2p_fwd_prepare_qkv
    participant FlashFwd as cp_p2p_fwd_flash_attn (FA3)
    participant A2A as AttnFuncWithCPAndQKVOA2A

    User->>FA: q, k, v, cu_seqlens_q, cu_seqlens_q_padded, pad_between_seqs=True
    FA->>CP: forward args (global cu_seqlens_q_padded when pad_between_seqs)
    CP->>P2P: apply(cu_seqlens_q [global actual], cu_seqlens_q_padded [global padded])
    P2P->>PrepQKV: get_cu_seqlens_on_cp_rank(cu_seqlens_q, cu_seqlens_q_padded, rank)
    PrepQKV-->>P2P: cu_seqlens_q_per_step [local actual]
    P2P->>FlashFwd: flash_attn_inputs (cu_seqlens_q_padded=GLOBAL), prepare_outputs (cu_seqlens_q_per_step=LOCAL)
    Note over FlashFwd: seqused_q = per_step[1:] - per_step[:-1] (local actual)<br/>cu_seqlens_q_ = cu_seqlens_q_padded (GLOBAL ⚠️ should be LOCAL)
    FlashFwd->>FlashFwd: FA3 forward with seqused_q + cu_seqlens_q_ (mismatch for batch>1)

    CP->>A2A: apply(cu_seqlens_q [global actual], cu_seqlens_q_padded [global padded], pad_between_seqs)
    Note over A2A: seqused_q = cu_seqlens_q[1:] - cu_seqlens_q[:-1]<br/>fa_cu_seqlens_q = cu_seqlens_q_padded (local context, ✓)
    A2A->>A2A: FA3 forward with seqused_q + padded cu_seqlens

Comments Outside Diff (2)

  1. transformer_engine/pytorch/attention/dot_product_attention/context_parallel.py, lines 960–965

    Global vs. per-rank cu_seqlens_q_padded mismatch for FA3 memory layout

    When pad_between_seqs=True, the cu_seqlens_q_padded placed in flash_attn_inputs is the global padded cumulative sequence lengths (total across all CP ranks), but q_part in P2P is the local per-rank Q tensor (approximately total_tokens / cp_size tokens, produced by thd_get_partitioned_indices + index_select).

    FA3 uses cu_seqlens_q to index into the Q buffer: for batch element b, it reads q[cu_seqlens_q[b] : cu_seqlens_q[b] + seqused_q[b]]. With global cu_seqlens_q_padded[1] ≈ total_padded_tokens / 2 and a local Q buffer of size total_padded_tokens / 2, batch element 1's access starts at the very end of the local buffer and seqused_q[1] > 0 tokens are read beyond it — an out-of-bounds access.

    Note the contrast with the non-pad_between_seqs path (line 935), which correctly uses cu_seqlens_q_per_step (a per-rank value from cp_p2p_fwd_prepare_qkv):

    cu_seqlens_q_ = cu_seqlens_q_per_step   # per-rank (correct for local Q)

    The fix should use per-rank padded cu_seqlens here too (e.g., the value returned by get_cu_seqlens_on_cp_rank(..., include_padding=True) or roughly cu_seqlens_q_padded // cp_size for uniform distribution):

    # For the diagonal/all section:
    cu_seqlens_q_ = cu_seqlens_q_padded // cp_size   # per-rank padded layout

    The same issue exists in the backward pass in cp_p2p_bwd_flash_attn at the analogous block where cu_seqlens_q_bwd = cu_seqlens_q_padded.

  2. transformer_engine/pytorch/attention/dot_product_attention/backends.py, lines 1038–1044

    Missing qkv_format == "thd" guard for seqused_q/seqused_k computation

    cu_seqlens_q is None for non-THD formats in the FA3 path. If pad_between_seqs=True is ever passed to FlashAttention.forward with a non-THD format (e.g., by bypassing get_attention_backend), the indexing cu_seqlens_q[1:] - cu_seqlens_q[:-1] would raise a TypeError. A defensive guard consistent with the CP-path treatment would be:
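The suggested snippet itself was not captured on this page; a hypothetical reconstruction of such a guard (self-contained, with illustrative tensor values and the surrounding names assumed from the quoted diff) might look like:

```python
import torch

# Assumed surrounding state, mirroring the FA3 call-site quoted in the diff.
qkv_format = "thd"
pad_between_seqs = True
cu_seqlens_q = torch.tensor([0, 5, 8], dtype=torch.int32)
cu_seqlens_kv = torch.tensor([0, 5, 8], dtype=torch.int32)
fa_3_optional_forward_kwargs = {}

# Guard on qkv_format so non-THD calls (where cu_seqlens_q is None) never
# reach the slicing below.
if pad_between_seqs and qkv_format == "thd":
    fa_3_optional_forward_kwargs["seqused_q"] = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    fa_3_optional_forward_kwargs["seqused_k"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]

assert fa_3_optional_forward_kwargs["seqused_q"].tolist() == [5, 3]
```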

Last reviewed commit: "Merge branch 'main' ..."

@greptile-apps greptile-apps bot left a comment


4 files reviewed, 1 comment


Comment on lines +974 to +983
# if `pad_between_seqs` is True, provide flash_attn_3 with `seqused_q` and `seqused_k`
# in addition to `cu_seqlens_q_padded` and `cu_seqlens_kv_padded` to avoid affecting the
# padding positions.
if pad_between_seqs:
    fa_3_optional_forward_kwargs["seqused_q"] = cu_seqlens_q[1:] - cu_seqlens_q[:-1]
    fa_3_optional_forward_kwargs["seqused_k"] = cu_seqlens_kv[1:] - cu_seqlens_kv[:-1]

style: verify that flash_attn_3 with seqused_q/seqused_k truly avoids writing to padding positions - the related issue #2391 mentions "we need to manually set the output of the padded positions to zero" (similar to how FusedAttention zeroes output in C++ for THD format). if flash_attn_3 doesn't zero these internally, output may have garbage values in padded positions. have you verified that flash_attn_3 correctly handles padding internally with these parameters?

@sudhakarsingh27 sudhakarsingh27 force-pushed the flash_attn_pad_bw_seqs branch from ea51821 to e338049 on March 10, 2026 at 23:37
@sudhakarsingh27 (Collaborator, Author)

/te-ci pytorch L2

@sudhakarsingh27 sudhakarsingh27 changed the title from "Flash attn pad bw seqs" to "[PyTorch] Add pad_between_seqs support for A2A and P2P CP with FA3 + THD" on Mar 11, 2026
Enable FlashAttention backend in test_attention.py to use padded
cu_seqlens and pad_between_seqs parameter, matching FusedAttention's
test path. FA3 natively supports pad_between_seqs via seqused_q/seqused_k.

- Group FlashAttention with FusedAttention for padded input/output handling
- Pass cu_seqlens_q_padded/cu_seqlens_kv_padded for FlashAttention backend
- Pass pad_between_seqs to DPA call
- Add pad_between_seqs=True to parametrize with thd-only skip

Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>

pad_between_seqs = False
if qkv_format == "thd" and cu_seqlens_q_padded is not None:
    pad_between_seqs = not torch.equal(cu_seqlens_q_padded, cu_seqlens_q)
Collaborator:
Can pad_between_seqs be decided ahead of time, passed by the user or something? This wouldn't be CUDA Graph-compatible right?

Collaborator, Author:

This pattern exists in dpa.py as well. But yes, it's definitely redundant here

sudhakarsingh27 and others added 5 commits March 19, 2026 20:12
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
Signed-off-by: Sudhakar Singh <sudhakars@nvidia.com>
…ransformerEngine into flash_attn_pad_bw_seqs
@sudhakarsingh27 (Collaborator, Author)

/te-ci pytorch L1

Successfully merging this pull request may close these issues:

  • Support FlashAttention with pad_between_seqs=True