[PyTorch][Fused Attn] Add support for cuDNN to return Softmax Stats always and Max when return_max_logit=True #2677
Conversation
Greptile Summary

This PR adapts TransformerEngine's fused attention implementation to always request the `Stats` tensor from cuDNN, and the `Max` tensor when `return_max_logit=True`. Key changes:
Confidence Score: 4/5
Important Files Changed
Sequence Diagram

sequenceDiagram
participant Py as fused_attn_fwd (Python)
participant Cpp as attention.cpp (C++)
participant CUDA as fused_attn_f16_arbitrary_seqlen.cu
participant cuDNN as cuDNN Frontend
Py->>Cpp: tex.fused_attn_fwd(..., return_max_logit)
Cpp->>Cpp: Allocate aux_tensor_pack<br/>[0]=Stats, [1]=Max(if rml), [n]=rng_state
Cpp->>CUDA: nvte_fused_attn_fwd(... Aux_CTX_Tensors)
CUDA->>CUDA: generate_stats=true (always)
CUDA->>cuDNN: sdpa with set_generate_stats(true)<br/>+ set_logit_max(Max) if return_max_logit
cuDNN-->>CUDA: O, Stats, [Max if return_max_logit]
CUDA-->>Cpp: Aux_CTX_Tensors filled:<br/>[Stats, [Max], rng_state, ...]
Cpp-->>Py: output_tensors=[O, Stats, [Max], rng_state, ...]
Note over Py: if return_max_logit:<br/>aux=[Stats, rng_state,...]<br/>max_logit=amax(Max)<br/>else:<br/>aux=output_tensors[1:]
Py->>Cpp: tex.fused_attn_bwd(..., aux_ctx_tensors=[Stats, rng_state,...])
Cpp->>CUDA: Aux_CTX_Tensors[0]=Stats, [1]=rng_state
CUDA->>cuDNN: sdpaBwd(Stats as softmax_stats)
cuDNN-->>CUDA: dQ, dK, dV
Last reviewed commit: ef0d7ec
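A minimal sketch of the Python-side handling described in the note above, assuming a hypothetical helper name and the `output_tensors` layouts shown in the diagram (the actual wrapper lives in `transformer_engine/pytorch/cpp_extensions/fused_attn.py`):

```python
import torch

def split_fwd_outputs(output_tensors, return_max_logit):
    """Hypothetical helper mirroring the diagram's note (not the actual TE wrapper).

    Assumed layout of output_tensors:
      return_max_logit=True : [O, Stats, Max, rng_state, ...]
      return_max_logit=False: [O, Stats, rng_state, ...]
    """
    out = output_tensors[0]
    if return_max_logit:
        stats, max_tensor = output_tensors[1], output_tensors[2]
        # The backward pass only needs Stats (plus rng_state); Max is not forwarded.
        aux_ctx_tensors = [stats] + output_tensors[3:]
        # Reduce the per-row Max tensor to a single max logit, as in the note above.
        max_logit = torch.amax(max_tensor)
        return out, aux_ctx_tensors, max_logit
    # Without return_max_logit, everything after O goes to the backward pass.
    return out, output_tensors[1:], None
```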
Additional Comments (1)
The public docstring still describes
stats = output_tensors[1] + torch.log(output_tensors[2])
# thd:  output_tensors: out [tq, h, d], Stats [tq, h, 1], Max [tq, h, 1]
# bshd: output_tensors: out [b, sq, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1]
# sbhd: output_tensors: out [sq, b, h, d], Stats [b, h, sq, 1], Max [b, h, sq, 1] (there's no typo here)
Do we need the "there's no typo here" :)
I deliberately added it because I didn't believe it and checked the shapes myself :P
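For intuition about those shapes, here is a small standalone reference computation (illustrative only, not TE or cuDNN code) for the bshd layout: `Stats` is the row-wise log-sum-exp and `Max` the row-wise maximum of the scaled scores, so both carry one value per query position per head and come out as `[b, h, sq, 1]` even though the inputs are `[b, s, h, d]`:

```python
import torch

b, sq, skv, h, d = 2, 4, 6, 3, 8
scale = d ** -0.5

# bshd-layout inputs: [batch, seqlen, heads, head_dim]
q = torch.randn(b, sq, h, d)
k = torch.randn(b, skv, h, d)

# Attention scores are computed per head: [b, h, sq, skv]
scores = torch.einsum("bqhd,bkhd->bhqk", q, k) * scale

stats = torch.logsumexp(scores, dim=-1, keepdim=True)  # [b, h, sq, 1]
max_ = scores.amax(dim=-1, keepdim=True)                # [b, h, sq, 1]
print(stats.shape, max_.shape)  # torch.Size([2, 3, 4, 1]) torch.Size([2, 3, 4, 1])
```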
/te-ci L2
Description
cuDNN recently made it possible to return any subset of {`Stats`, `SumExp`, `Max`}. This PR adapts TE to always get `Stats` from cuDNN, and the `Max` tensor as well when `return_max_logit=True`. (Note that `Stats = log(SumExp) + Max`.)
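That identity can be sanity-checked in a few lines of plain PyTorch (an illustrative sketch only, not TE code; it assumes `SumExp` is accumulated with the running max already subtracted, as flash-style kernels do):

```python
import torch

logits = torch.randn(16)                  # one row of attention scores
max_ = logits.max()                       # Max
sum_exp = torch.exp(logits - max_).sum()  # SumExp, with Max subtracted for stability
stats = torch.log(sum_exp) + max_         # Stats = log(SumExp) + Max

# Stats is exactly the log-sum-exp of the row
assert torch.allclose(stats, torch.logsumexp(logits, dim=0))
```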
Type of change

Changes

Please list the changes introduced in this PR:
- `fused_attn_f16_arbitrary_seqlen.cu`
  - Drop the `SumExp` tensor, as it's not needed since cuDNN returns `Stats` by default.
  - Set `generate_stats=True`, which forces cuDNN to always return the `Stats` tensor (needed in the backward pass).
- `transformer_engine/pytorch/cpp_extensions/fused_attn.py`
  - Remove the `Stats = log(SumExp) + Max` reconstruction, since cuDNN returns `Stats` directly and TE doesn't need `SumExp` from cuDNN.

Checklist: