
[Draft] Support for score_mod and score_mod_bprop in cuDNN's sdpa #2767

Open

vcherepanov-nv wants to merge 23 commits into NVIDIA:main from vcherepanov-nv:cudnn-score-mod-cdx

Conversation

@vcherepanov-nv
Collaborator

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Plumbing for score_mod and score_mod_bprop in TE/common
  • Plumbing for score_mod and score_mod_bprop in TE/pytorch

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@greptile-apps
Contributor

greptile-apps bot commented Mar 16, 2026

Greptile Summary

This PR adds end-to-end plumbing for score_mod and score_mod_bprop — cuDNN's flexible-graph score modifier callbacks — through the full TE stack: public C API (fused_attn.h), the F16_arbitrary_seqlen CUDA backend, PyTorch cpp_extensions, FusedAttnFunc, DotProductAttention, and MultiheadAttention. The JAX call sites are also updated to pass the required nullptr placeholders for the new parameters.

Key changes:

  • C++ backend (fused_attn_f16_arbitrary_seqlen.cu): Adds a DLPack tensor bridge, a pybind11-based Python callback trampoline (make_attention_score_modifier), and extends FADescriptor_v1 / graph caching to include score modifier identity.
  • Build system (CMakeLists.txt): Fetches dlpack v1.1 via FetchContent and links pybind11/Python into the shared library.
  • Python stack: All four parameters (score_mod, score_mod_bprop, score_mod_tensors, score_mod_bprop_tensors) are plumbed from MultiheadAttention → DotProductAttention → FusedAttention → FusedAttnFunc with appropriate backend-selection guards (forces F16_arbitrary_seqlen; disables Flash/Unfused/FP8-DPA and context-parallel paths).
  • Tests: Four new pytest tests cover identity score_mod, causal masking via score_mod, and external variant-pack tensors.
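The backend-selection guard described above can be sketched at the Python level as follows. This is an illustrative model only; the names (`AttentionParams`, `Backend`, `available_backends`) are hypothetical stand-ins, not Transformer Engine's actual API.

```python
from dataclasses import dataclass
from enum import Enum, auto

class Backend(Enum):
    FLASH = auto()
    UNFUSED = auto()
    FUSED_F16_ARBITRARY_SEQLEN = auto()
    FUSED_FP8 = auto()

@dataclass
class AttentionParams:
    has_score_mod: bool = False
    has_score_mod_bprop: bool = False

def available_backends(params: AttentionParams) -> set:
    """Model of the guard: score_mod requires cuDNN flexible graphs,
    so only the F16_arbitrary_seqlen fused backend remains eligible."""
    backends = set(Backend)
    if params.has_score_mod or params.has_score_mod_bprop:
        backends = {Backend.FUSED_F16_ARBITRARY_SEQLEN}
    return backends
```

In this sketch, passing a score_mod collapses the candidate set to a single backend, which mirrors the "forces F16_arbitrary_seqlen" behavior the summary describes.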

Issues found:

  • import cudnn at module level in test_attention.py will cause an ImportError that breaks every test in the file on systems without the cudnn Python package.
  • The FADescriptor_v1 cache key uses the raw PyObject address of score_mod as score_mod_id. Python's GC can reuse addresses, causing a stale cache hit that silently returns the wrong graph for a different score modifier.
  • The score_mod callback lambda captures the raw PyObject * without Py_INCREF, leaving a dangling pointer inside the cached std::function after the Python object could be GC'd.
  • score_mod_tensors signature hash (get_extra_tensor_signature) does not include the data pointer (dl_tensor.data), so two tensors with identical shape/dtype but different data are indistinguishable in the cache key.
  • MultiheadAttention.forward docstring is missing entries for score_mod_tensors and score_mod_bprop_tensors.

Confidence Score: 2/5

  • Not safe to merge as-is: module-level cudnn import will break the full test suite in standard CI, and the cache key design issues risk silent correctness failures in production.
  • The architectural plumbing is sound and well-structured, but there are two concrete bugs that need resolution before merge: (1) the unconditional import cudnn at test-module level is a regression that will cause ImportError for all ~1400 existing tests in any environment without the cudnn Python package, (2) the score_mod cache key uses a raw PyObject address without holding a reference, creating both a potential use-after-free (dangling lambda in cached graph) and a cache-key collision vector via Python GC address reuse. The tensor-signature omission of the data pointer is an additional correctness concern.
  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu (cache key design and raw PyObject ownership) and tests/pytorch/attention/test_attention.py (module-level cudnn import)

Important Files Changed

  • transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu — Adds ~300 lines of new C++ glue code implementing DLPack-based tensor bridging, a Python callback trampoline via pybind11, and cuDNN flexible-graph score modifier plumbing. Two significant issues: (1) the cache key uses the raw PyObject address without an incref, risking silent stale-graph hits if Python's GC reuses the address; (2) the tensor signature hash omits the data pointer, causing distinct tensors with the same metadata to collide in the descriptor cache.
  • tests/pytorch/attention/test_attention.py — Adds four new score_mod test functions with good coverage; however, import cudnn is placed at module level, which will raise ImportError and break every test in this large file on environments without the cudnn Python package.
  • transformer_engine/pytorch/attention/dot_product_attention/utils.py — Adds has_score_mod / has_score_mod_bprop flags to AttentionParams and get_attention_backend; correctly disables FlashAttention, UnfusedAttention, FP8-DPA, and non-F16_arbitrary_seqlen fused backends when score_mod is present.
  • transformer_engine/pytorch/attention/dot_product_attention/backends.py — Properly threads score_mod params through FusedAttnFunc forward/backward and saves them on ctx; adds a correct assertion blocking context-parallel + score_mod usage. Minor: no assertion blocking score_mod + cuda_graph (but the interaction is safe in practice).
  • transformer_engine/jax/csrc/extensions/attention.cpp — Correctly adapts both forward and backward call sites (workspace query + impl) to the new API, with 2 nullptrs for fwd and 4 nullptrs for bwd; issues flagged in earlier review threads are resolved.
  • transformer_engine/common/CMakeLists.txt — Adds FetchContent for dlpack v1.1, finds pybind11, and links Python::Module + pybind11::headers + dlpack into transformer_engine; straightforward and correct.

Sequence Diagram

sequenceDiagram
    participant PY as Python (PyTorch)
    participant FWD as FusedAttnFunc.forward
    participant CPP as pytorch/csrc/attention.cpp
    participant NVTE as nvte_fused_attn_fwd
    participant IMPL as fused_attn_arbitrary_seqlen_fwd_impl
    participant CACHE as FADescriptor_v1 cache
    participant CUDNN as cuDNN graph

    PY->>FWD: forward(q,k,v,...,score_mod,score_mod_tensors)
    FWD->>CPP: fused_attn_fwd(..., score_mod, score_mod_tensors)
    CPP->>NVTE: nvte_fused_attn_fwd(..., score_mod.ptr(), score_mod_tensors.ptr())
    NVTE->>IMPL: fused_attn_arbitrary_seqlen_fwd(..., score_mod, score_mod_tensors)
    IMPL->>CACHE: lookup FADescriptor_v1 (includes score_mod_id=uintptr(score_mod))
    alt cache miss
        IMPL->>IMPL: make_attention_score_modifier(score_mod, score_mod_tensors)
        Note over IMPL: acquires GIL, imports cudnn,<br/>calls Python callback to build graph nodes
        IMPL->>CUDNN: sdpa_options.set_score_mod(lambda)
        CUDNN-->>IMPL: graph built + plan compiled
        IMPL->>CACHE: store (graph, extra_tensor_attrs)
    else cache hit
        CACHE-->>IMPL: cached graph + extra_tensor_attrs
    end
    IMPL->>IMPL: extend_variant_pack_with_extra_tensors<br/>(DLPack tensors → void* pointers)
    IMPL->>CUDNN: graph.execute(variant_pack)
    CUDNN-->>IMPL: output O, aux tensors
    IMPL-->>FWD: O, Aux_CTX_Tensors
    FWD-->>PY: output (+ ctx.score_mod* saved for bwd)

Comments Outside Diff (4)

  1. tests/pytorch/attention/test_attention.py, line 9 (link)

    Module-level import cudnn breaks entire test file

    import cudnn is placed at the top level of the test module. If cudnn Python bindings are not installed in the test environment, every test in this file will fail with ModuleNotFoundError when the module is imported — not just the new score_mod tests. This would completely break the existing test suite for environments that lack the cudnn Python package.

    The cudnn module is only needed by the helper functions _causal_score_mod and _causal_score_mod_external, so the import should be guarded or deferred. Alternatively, move the import cudnn inside those two helper functions (or into the test functions themselves, behind the @pytest.mark.skipif guard).
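One way to implement the guard is a small optional-import helper, so that importing the test module never fails just because an optional dependency is missing. This is a generic sketch, not the PR's actual fix:

```python
import importlib

def optional_import(name):
    """Return the named module if importable, else None.
    Lets module import succeed even when an optional dependency
    (here, the cudnn Python bindings) is absent."""
    try:
        return importlib.import_module(name)
    except ImportError:
        return None

# None on systems without the cudnn Python package; tests that need it
# can then be skipped with a pytest.mark.skipif(cudnn is None, ...) guard.
cudnn = optional_import("cudnn")
```

Tests that touch cudnn would check `cudnn is None` in a skipif marker instead of assuming the import succeeded.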

  2. transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu, line 780-808 (link)

    Cache key collision due to Python object address reuse

    get_extra_tensor_signature hashes tensor metadata (shape, dtype, device) but never hashes the data pointer of each tensor. Two different tensors with the same shape/dtype/device but different data will produce the same signature, causing FADescriptor_v1 to consider them the same graph configuration and return a stale cached graph.

    The same issue applies to score_mod_id in the descriptor (reinterpret_cast<std::uintptr_t>(score_mod)): Python's memory allocator can reuse a freed PyObject address. If a user creates a new lambda with a different graph structure after the old one is GC'd, and it happens to receive the same address, the cache will return the graph built by the old lambda, silently producing incorrect attention outputs.

    To mitigate both:

    1. Include dl_tensor.data (the raw data pointer) in the tensor signature — this makes distinct tensors distinguishable even when metadata matches.
    2. Consider either holding a py::object (incrementing refcount) inside the cached graph_and_tensors tuple for the score_mod callback, or using a monotonically increasing ID assigned at Python call time rather than the PyObject address.
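Both mitigations can be sketched at the Python level. The field names mirror DLPack's DLTensor metadata, and `score_mod_cache_id` replaces the raw `id()`/PyObject address with a monotonically increasing identifier that is never reused after an object is collected; this is an illustrative model, not TE's implementation.

```python
import itertools
from weakref import WeakKeyDictionary

def tensor_signature(shape, dtype, device, data_ptr):
    """Mitigation 1: include the raw data pointer in the signature so
    tensors with identical metadata but different storage hash apart."""
    return hash((tuple(shape), dtype, device, data_ptr))

# Mitigation 2: hand out monotonic ids keyed weakly by the callback
# object. Unlike its memory address, an id is never recycled when the
# callback is garbage-collected and a new object lands at the same address.
_ids = WeakKeyDictionary()
_counter = itertools.count(1)

def score_mod_cache_id(callback):
    if callback not in _ids:
        _ids[callback] = next(_counter)
    return _ids[callback]
```

The same callback always maps to the same id, while two distinct callbacks can never collide, even across garbage-collection cycles.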
  3. transformer_engine/common/fused_attn/fused_attn_f16_arbitrary_seqlen.cu, line 870-891 (link)

    Captured callback_ptr not reference-counted — use-after-free risk in cached graphs

    The lambda returned by make_attention_score_modifier captures callback_ptr as a raw void * (a non-owning copy of the PyObject *) without calling Py_INCREF. This lambda is stored inside the cuDNN graph which is placed in the thread_local cache.

    If the caller's Python object that backs callback_ptr is garbage-collected while the graph is still cached, the lambda holds a dangling pointer. Although the lambda is only invoked during the initial build_plans call (cache miss), the raw pointer is retained inside the std::function for the lifetime of the cache entry — and if the cache entry is ever evicted and rebuilt (or if any future cuDNN path calls the modifier again), the behavior is undefined.

    The callback pointer (and extra_tensors_ptr) should be Py_INCREF'd before capture, and Py_DECREF'd when the lambda/graph is destroyed (e.g., via a custom deleter or by holding a py::object instead of a raw void *).
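A Python-level analogue shows why holding a strong reference (the counterpart of Py_INCREF / a py::object member) keeps the callback alive for the lifetime of the cache entry. `ScoreMod` is a stand-in for the user's callback object:

```python
import weakref

class ScoreMod:
    """Stand-in for the Python callback captured by the cached graph."""
    def __call__(self, score):
        return score

cb = ScoreMod()
probe = weakref.ref(cb)

# Analogue of Py_INCREF: the cache entry holds a strong reference,
# so the callback outlives the caller's local variable.
cache_entry = {"score_mod": cb}
del cb
assert probe() is not None  # still alive: pinned by the cache

# Dropping the cache entry releases the last strong reference;
# CPython's refcounting then frees the object immediately.
del cache_entry
assert probe() is None
```

Capturing only the raw address (the current code's behavior) corresponds to skipping the `cache_entry` pin entirely: the object dies with the caller's reference and the cached pointer dangles.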

  4. tests/pytorch/attention/test_attention.py, line 321-346 (link)

    Backward correctness of causal score_mod without score_mod_bprop

    test_dpa_score_mod_causal calls block(...) with score_mod=_causal_score_mod but without a corresponding score_mod_bprop. The test then asserts the gradients (dq_sm, dk_sm, dv_sm) match the reference computed via cuDNN's built-in causal masking.

    When score_mod_bprop=None is passed to nvte_fused_attn_bwd, sdpa_backward_options.set_score_mod_bprop() is never called. cuDNN reconstructs the masked attention weights from the saved softmax statistics, but for correctness the backward graph should mirror the forward mask. The test tolerance (atol=1.5e-2, rtol=1.5e-2) is quite loose — it's worth verifying (e.g., with a tighter tolerance or an explicit cuDNN documentation reference) that omitting score_mod_bprop still produces numerically correct gradients when the forward uses a score_mod causal mask.

Last reviewed commit: ff8826e

Comment on lines +332 to +333
window_size_left, window_size_right, bottom_right_diagonal, nullptr,
workspace_tensor.data(), stream);

Missing score_mod_tensors nullptr in JAX forward calls

The new nvte_fused_attn_fwd signature requires two new void* parameters (score_mod, score_mod_tensors) before workspace, but both JAX forward call sites only insert one nullptr. This shifts all subsequent arguments by one position, causing a compilation error due to mismatched argument count.

Line 195-196 (workspace query) passes nullptr, query_workspace_tensor.data(), nullptr — only 3 args where 4 are needed after bottom_right_diagonal. Line 332-333 (actual forward) passes nullptr, workspace_tensor.data(), stream — same issue.

Suggested change
window_size_left, window_size_right, bottom_right_diagonal, nullptr,
workspace_tensor.data(), stream);
window_size_left, window_size_right, bottom_right_diagonal, nullptr, nullptr,
workspace_tensor.data(), stream);

Comment on lines +195 to +196
window_size_left, window_size_right, bottom_right_diagonal, nullptr,
query_workspace_tensor.data(), nullptr);

Missing score_mod_tensors nullptr in JAX forward workspace query

Same issue as the forward impl call: this site also only adds one nullptr but needs two (score_mod and score_mod_tensors).

Suggested change
window_size_left, window_size_right, bottom_right_diagonal, nullptr,
query_workspace_tensor.data(), nullptr);
window_size_left, window_size_right, bottom_right_diagonal, nullptr, nullptr,
query_workspace_tensor.data(), nullptr);

Comment on lines +610 to +611
bottom_right_diagonal, deterministic, false, nullptr, nullptr,
workspace_tensor.data(), stream);

Missing score_mod_tensors/score_mod_bprop_tensors nullptrs in JAX backward calls

The new nvte_fused_attn_bwd signature requires four new void* parameters (score_mod, score_mod_bprop, score_mod_tensors, score_mod_bprop_tensors) before workspace, but both JAX backward call sites only insert two nullptrs. This will cause a compilation error.

This call (impl) needs 4 nullptrs, not 2:

Suggested change
bottom_right_diagonal, deterministic, false, nullptr, nullptr,
workspace_tensor.data(), stream);
bottom_right_diagonal, deterministic, false, nullptr, nullptr,
nullptr, nullptr, workspace_tensor.data(), stream);

The workspace query call at lines 483-484 has the same issue and also needs 4 nullptrs:

        window_size_right, bottom_right_diagonal, deterministic, false, nullptr,
        nullptr, nullptr, nullptr, query_workspace_tensor.data(), nullptr);

vcherepanov-nv and others added 21 commits March 16, 2026 20:09
- tests/pytorch/attention/test_attention.py: add test_dpa_score_mod
  testing DotProductAttention with an identity score_mod callable
  (requires cuDNN >= 9.7.0, F16_arbitrary_seqlen backend, no_mask).
  Pass score_mod_bprop alongside score_mod to satisfy the training
  assertion added in this branch.

- utils.h: add has_score_mod/has_score_mod_bprop bool fields to
  FADescriptor_v1 and include them in operator< so graphs built with
  vs. without a score_mod callback get separate cache entries (fixes
  the score_mod callback never being invoked on cache hit).

- fused_attn_f16_arbitrary_seqlen.cu: replace pybind11/PyGraph
  trampoline with PyCapsule-based trampoline (matching the simpler
  approach from cudnn-score-mod branch). Raw C++ pointers are wrapped
  as PyCapsule objects; GIL is acquired via PyGILState_Ensure. Removes
  the pybind11 and pygraph.h dependencies.

- CMakeLists.txt: remove find_package(pybind11), the pybind11::headers
  link, and the pygraph include dir that were only needed for the old
  trampoline. Python_INCLUDE_DIRS (for <Python.h>) is retained.

Co-Authored-By: Claude Sonnet 4.6 <noreply@anthropic.com>
Signed-off-by: Vladimir Cherepanov <vcherepanov@nvidia.com>
…ooleans

pre-commit-ci bot and others added 2 commits March 16, 2026 20:25