[Draft] Support for score_mod and score_mod_bprop in cuDNN's sdpa #2767

vcherepanov-nv wants to merge 23 commits into NVIDIA:main
Conversation
Greptile Summary

This PR adds end-to-end plumbing for score_mod and score_mod_bprop callbacks through cuDNN's fused-attention (sdpa) path, from the PyTorch frontend down to the cuDNN graph API.

Key changes:
Issues found:
Confidence Score: 2/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant PY as Python (PyTorch)
    participant FWD as FusedAttnFunc.forward
    participant CPP as pytorch/csrc/attention.cpp
    participant NVTE as nvte_fused_attn_fwd
    participant IMPL as fused_attn_arbitrary_seqlen_fwd_impl
    participant CACHE as FADescriptor_v1 cache
    participant CUDNN as cuDNN graph
    PY->>FWD: forward(q,k,v,...,score_mod,score_mod_tensors)
    FWD->>CPP: fused_attn_fwd(..., score_mod, score_mod_tensors)
    CPP->>NVTE: nvte_fused_attn_fwd(..., score_mod.ptr(), score_mod_tensors.ptr())
    NVTE->>IMPL: fused_attn_arbitrary_seqlen_fwd(..., score_mod, score_mod_tensors)
    IMPL->>CACHE: lookup FADescriptor_v1 (includes score_mod_id=uintptr(score_mod))
    alt cache miss
        IMPL->>IMPL: make_attention_score_modifier(score_mod, score_mod_tensors)
        Note over IMPL: acquires GIL, imports cudnn,<br/>calls Python callback to build graph nodes
        IMPL->>CUDNN: sdpa_options.set_score_mod(lambda)
        CUDNN-->>IMPL: graph built + plan compiled
        IMPL->>CACHE: store (graph, extra_tensor_attrs)
    else cache hit
        CACHE-->>IMPL: cached graph + extra_tensor_attrs
    end
    IMPL->>IMPL: extend_variant_pack_with_extra_tensors<br/>(DLPack tensors → void* pointers)
    IMPL->>CUDNN: graph.execute(variant_pack)
    CUDNN-->>IMPL: output O, aux tensors
    IMPL-->>FWD: O, Aux_CTX_Tensors
    FWD-->>PY: output (+ ctx.score_mod* saved for bwd)
```
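As a rough illustration of the callback contract in the flow above, a score_mod is a callable that receives attention scores and returns (possibly modified) scores; the identity modifier used by the new test leaves them unchanged. This is a hypothetical Python sketch operating on plain lists — the real callback works on cuDNN graph tensor handles, and the names here are illustrative, not the actual TE/cuDNN API.

```python
# Hypothetical sketch of the score_mod contract; the real callback
# operates on cuDNN graph tensor handles, not Python lists.
def identity_score_mod(score):
    # Return attention scores unchanged; a real modifier could add a
    # bias or apply extra masking here.
    return score

def relative_bias_score_mod(score, bias):
    # Illustrative non-trivial modifier: add a per-position bias term.
    return [s + b for s, b in zip(score, bias)]

scores = [1, 2, 3]
assert identity_score_mod(scores) == scores
assert relative_bias_score_mod(scores, [10, 10, 10]) == [11, 12, 13]
```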
```cpp
window_size_left, window_size_right, bottom_right_diagonal, nullptr,
workspace_tensor.data(), stream);
```

Missing score_mod_tensors nullptr in JAX forward calls

The new nvte_fused_attn_fwd signature requires two new void* parameters (score_mod and score_mod_tensors) before workspace, but both JAX forward call sites insert only one nullptr. This shifts all subsequent arguments by one position, causing a compilation error from the mismatched argument count.

Lines 195-196 (workspace query) pass nullptr, query_workspace_tensor.data(), nullptr — only three arguments where four are needed after bottom_right_diagonal. Lines 332-333 (the actual forward call) pass nullptr, workspace_tensor.data(), stream — the same issue.

```diff
-    window_size_left, window_size_right, bottom_right_diagonal, nullptr,
+    window_size_left, window_size_right, bottom_right_diagonal, nullptr, nullptr,
     workspace_tensor.data(), stream);
```
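The argument-shift failure mode described in this comment can be modeled in a few lines of Python, with None standing in for nullptr. The function name and parameters loosely mirror the C API but are illustrative only:

```python
# Model of the argument-shift bug: the new signature takes two extra
# parameters (score_mod, score_mod_tensors) before workspace, so a call
# site that inserts only one placeholder mis-aligns everything after it.
def nvte_fused_attn_fwd_sketch(bottom_right_diagonal, score_mod,
                               score_mod_tensors, workspace, stream):
    return (score_mod, score_mod_tensors, workspace, stream)

# Fixed call: two placeholders keep workspace and stream aligned.
ok = nvte_fused_attn_fwd_sketch(True, None, None, "workspace", "stream")
assert ok == (None, None, "workspace", "stream")

# Buggy call: only one placeholder, so the argument count no longer
# matches (a compilation error in C++; a TypeError here).
try:
    nvte_fused_attn_fwd_sketch(True, None, "workspace", "stream")
    raise AssertionError("should have failed")
except TypeError:
    pass
```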
```cpp
window_size_left, window_size_right, bottom_right_diagonal, nullptr,
query_workspace_tensor.data(), nullptr);
```

Missing score_mod_tensors nullptr in JAX forward workspace query

Same issue as the forward impl call: this site also adds only one nullptr but needs two (score_mod and score_mod_tensors).

```diff
-    window_size_left, window_size_right, bottom_right_diagonal, nullptr,
+    window_size_left, window_size_right, bottom_right_diagonal, nullptr, nullptr,
     query_workspace_tensor.data(), nullptr);
```
```cpp
bottom_right_diagonal, deterministic, false, nullptr, nullptr,
workspace_tensor.data(), stream);
```

Missing score_mod_tensors/score_mod_bprop_tensors nullptrs in JAX backward calls

The new nvte_fused_attn_bwd signature requires four new void* parameters (score_mod, score_mod_bprop, score_mod_tensors, score_mod_bprop_tensors) before workspace, but both JAX backward call sites insert only two nullptrs. This will cause a compilation error.

This call (the impl) needs four nullptrs, not two:

```diff
-    bottom_right_diagonal, deterministic, false, nullptr, nullptr,
-    workspace_tensor.data(), stream);
+    bottom_right_diagonal, deterministic, false, nullptr, nullptr,
+    nullptr, nullptr, workspace_tensor.data(), stream);
```

The workspace query call at lines 483-484 has the same issue and also needs four nullptrs:

```cpp
window_size_right, bottom_right_diagonal, deterministic, false, nullptr,
nullptr, nullptr, nullptr, query_workspace_tensor.data(), nullptr);
```

- tests/pytorch/attention/test_attention.py: add test_dpa_score_mod, testing DotProductAttention with an identity score_mod callable (requires cuDNN >= 9.7.0, the F16_arbitrary_seqlen backend, and no_mask). Pass score_mod_bprop alongside score_mod to satisfy the training assertion added in this branch.
- utils.h: add has_score_mod/has_score_mod_bprop bool fields to FADescriptor_v1 and include them in operator< so graphs built with vs. without a score_mod callback get separate cache entries (fixes the score_mod callback never being invoked on a cache hit).
- fused_attn_f16_arbitrary_seqlen.cu: replace the pybind11/PyGraph trampoline with a PyCapsule-based trampoline (matching the simpler approach from the cudnn-score-mod branch). Raw C++ pointers are wrapped as PyCapsule objects; the GIL is acquired via PyGILState_Ensure. Removes the pybind11 and pygraph.h dependencies.
- CMakeLists.txt: remove find_package(pybind11), the pybind11::headers link, and the pygraph include dir that were only needed for the old trampoline. Python_INCLUDE_DIRS (for <Python.h>) is retained.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: