
Enable fused RMSNorm dLN + add through CUDNN #2778

Open
CarlosGomes98 wants to merge 7 commits into NVIDIA:main from CarlosGomes98:cgomes/fuse_dLN_add

Conversation

@CarlosGomes98

Description

Starting with cuDNN 9.21, cuDNN provides kernel support for fusing the dLN + add backward pattern, which is useful for the residual fork + RMSNorm pattern in networks.
The fusion pattern is already expressible through te.ops.Sequential with backward_add_rmsnorm.py; this PR just enables the cuDNN backend for it.
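For context, the pattern being fused computes the RMSNorm input gradient and then adds the residual branch's gradient elementwise in the same kernel. A plain-Python reference of the math (illustrative only; the function name, eps handling, and single-row layout do not mirror TE's actual kernels):

```python
import math

def rmsnorm_bwd_add_ref(dz, x, gamma, add, eps=1e-5):
    """Reference for the fused dLN + add backward: dx = rmsnorm_dx + add.

    Forward was y = x * rsigma * gamma with rsigma = 1/sqrt(mean(x^2) + eps).
    """
    n = len(x)
    rsigma = 1.0 / math.sqrt(sum(v * v for v in x) / n + eps)
    # Contribution of x through rsigma: sum_k dz_k * gamma_k * x_k
    c = sum(dz[k] * gamma[k] * x[k] for k in range(n))
    dx = [rsigma * dz[i] * gamma[i] - (rsigma ** 3) * x[i] * c / n
          for i in range(n)]
    # The fusion appends a pointwise ADD of the residual branch's gradient,
    # instead of launching a separate elementwise kernel for it:
    dx_fused = [dx[i] + add[i] for i in range(n)]
    dgamma = [dz[i] * x[i] * rsigma for i in range(n)]
    return dx_fused, dgamma
```

In the residual-fork pattern, `add` is the gradient flowing through the skip connection, so fusing the add removes one kernel dispatch per layer in the backward pass.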

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or a new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Enable the CUDNN backend for rmsnorm_bwd_add, which was previously hardcoded to TE-only.
  • Add support for the residual add operation through the BackwardAdd stage in building the CUDNN graph.

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

CarlosGomes98 and others added 4 commits March 18, 2026 12:26
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
@greptile-apps
Contributor

greptile-apps bot commented Mar 18, 2026

Greptile Summary

This PR enables the cuDNN backend for rmsnorm_bwd_add (the fused RMSNorm backward + residual-add pattern), previously hardcoded to the TE-only kernel. Starting with cuDNN 9.21, a pointwise ADD node is appended to the cuDNN rmsnorm_backward graph, fusing dx = rmsnorm_dx + add into a single kernel dispatch. The change is consistent with how the existing rmsnorm_bwd cuDNN path is structured and follows the same plan-registry caching, workspace-query, and variant-pack patterns.

Key changes:

  • common.h: Adds _add shared_ptr member to CudnnNormalizationPlan for the BackwardAdd tensor.
  • common.cpp: Builds the cuDNN pointwise-ADD node when _norm_stage == BackwardAdd, gates on cudnnGetVersion() >= 92100, and removes the previous hard-fail on non-null add_dptr. The _add tensor is typed as wtype (= gamma.data.dtype), consistent with the dtype validation in rmsnorm_bwd_add.
  • rmsnorm_api.cpp: Replaces the unconditional NVTE_Norm_Backend::Te with the same cuDNN/TE branch logic already used by rmsnorm_bwd; zero-centered-gamma restriction is now limited to the TE backend path; a mirrored // TODO: add check for GPU ARCH is added for parity with sibling code paths.

Confidence Score: 4/5

  • PR is safe to merge; the implementation correctly follows existing cuDNN-frontend patterns and the new BackwardAdd graph construction is logically sound for the symmetric-dtype case that is enforced by input validation.
  • The change is narrowly scoped, mirrors the established rmsnorm_bwd cuDNN path, and the _add tensor dtype (wtype) correctly matches the upstream dtype validation. Because the cuDNN path requires an explicit opt-in env var, it is acceptable that the version guard (>= 9.21) inside the constructor throws a clear error when the requirement isn't met. The GPU-arch TODO mirrors pre-existing sibling code, and unsupported architectures will surface through check_support at plan build time. No tests were added for the new path, which is the primary risk.
  • No files require special attention beyond confirming test coverage for the new cuDNN BackwardAdd path.

Important Files Changed

Filename Overview
transformer_engine/common/normalization/common.h Adds _add shared_ptr member to CudnnNormalizationPlan for the BackwardAdd fused tensor; minimal, clean header change.
transformer_engine/common/normalization/common.cpp Adds cuDNN graph construction for the fused BackwardAdd stage (pointwise ADD after rmsnorm backward, cuDNN ≥ 9.21 gate) and updates the execute path to bind the _add tensor when applicable; removes the old hard-fail on non-null add_dptr. Logic is sound for common symmetric-type cases.
transformer_engine/common/normalization/rmsnorm/rmsnorm_api.cpp Lifts the hard-coded NVTE_Norm_Backend::Te for rmsnorm_bwd_add, adding a cuDNN branch consistent with rmsnorm_bwd; alignment and zero-centered-gamma checks are now backend-conditional. Mirrors the pre-existing TODO for GPU arch check.

Sequence Diagram

sequenceDiagram
    participant Caller
    participant rmsnorm_bwd_add
    participant NormPlanRegistry
    participant CudnnNormalizationPlan
    participant cuDNN

    Caller->>rmsnorm_bwd_add: dz, x, add, rsigma, gamma, dx, dgamma
    rmsnorm_bwd_add->>rmsnorm_bwd_add: validate dtypes & shapes

    alt use_cudnn_norm_bwd()
        rmsnorm_bwd_add->>NormPlanRegistry: getNormalizationPlan(Cudnn, RMSNorm, BackwardAdd, ...)
        NormPlanRegistry->>CudnnNormalizationPlan: construct (wtype, itype, otype)
        CudnnNormalizationPlan->>CudnnNormalizationPlan: NVTE_CHECK(cudnnGetVersion >= 92100)
        CudnnNormalizationPlan->>cuDNN: rmsnorm_backward(dz, x, gamma, rsigma)
        CudnnNormalizationPlan->>cuDNN: pointwise ADD(_dx, _add) → _dx_with_add
        CudnnNormalizationPlan->>cuDNN: _build() / check_support / build_plans
        NormPlanRegistry-->>rmsnorm_bwd_add: plan (cached)
        rmsnorm_bwd_add->>CudnnNormalizationPlan: execute(x, gamma, rsigma, dx, dz, add, dgamma)
        CudnnNormalizationPlan->>CudnnNormalizationPlan: bind variant_pack (_x, _rsigma, _dz, _dgamma, _dx=_dx_with_add, _add)
        CudnnNormalizationPlan->>cuDNN: graph.execute(variant_pack, workspace)
        cuDNN-->>Caller: dx (= rmsnorm_dx + add), dgamma
    else TE backend
        rmsnorm_bwd_add->>NormPlanRegistry: getNormalizationPlan(Te, RMSNorm, BackwardAdd, ...)
        NormPlanRegistry-->>rmsnorm_bwd_add: plan (cached)
        rmsnorm_bwd_add->>Caller: dx, dgamma (TE kernel)
    end

Last reviewed commit: "remove dangling todo"

Comment on lines +399 to +401
if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) {
  NVTE_CHECK(cudnnGetVersion() >= 92100,
             "Fused BackwardAdd requires cuDNN >= 9.21.0, but found ", cudnnGetVersion());
Contributor


P1 Version check throws inside constructor with no fallback path

The NVTE_CHECK(cudnnGetVersion() >= 92100, ...) fires during CudnnNormalizationPlan construction, which is called from getNormalizationPlan. If a user sets NVTE_NORM_BWD_USE_CUDNN=1 on a system with cuDNN < 9.21, the entire rmsnorm_bwd_add call will throw a hard error rather than gracefully falling back to the TE kernel. Unlike typical cuDNN support checks elsewhere in the codebase that gate entry into the cuDNN path, this check is buried inside plan construction where there is no recovery mechanism.

Consider validating the cuDNN version at the top of rmsnorm_bwd_add in rmsnorm_api.cpp (inside the use_cudnn_norm_bwd() branch) and falling back to the TE backend when the version requirement isn't met, consistent with how other backend-selection logic is structured.
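The trade-off being discussed can be sketched as two selection policies (a toy Python sketch; the function names and integer-version handling are illustrative, not the actual rmsnorm_api.cpp logic):

```python
import os

MIN_FUSED_ADD_VERSION = 92100  # cuDNN 9.21.0 encoded as an integer version

def select_backend_fail_fast(cudnn_version):
    """Current PR behavior: explicit opt-in fails loudly if unsupported."""
    if os.environ.get("NVTE_NORM_BWD_USE_CUDNN") == "1":
        if cudnn_version < MIN_FUSED_ADD_VERSION:
            raise RuntimeError("Fused BackwardAdd requires cuDNN >= 9.21.0")
        return "cudnn"
    return "te"

def select_backend_fallback(cudnn_version):
    """Suggested alternative: silently fall back to the TE kernel."""
    if (os.environ.get("NVTE_NORM_BWD_USE_CUDNN") == "1"
            and cudnn_version >= MIN_FUSED_ADD_VERSION):
        return "cudnn"
    return "te"
```

The fail-fast policy surfaces a misconfiguration immediately; the fallback policy keeps the opt-in usable on older cuDNN but can mask the fact that the fused kernel was never used.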

Author


Not sure what the best behaviour here is. If NVTE_NORM_BWD_USE_CUDNN=1 is set, and we cannot use CUDNN, I would expect it to fail early, and for the user to remove that env var, rather than silently falling back to TE.

Contributor


You're right – if a user explicitly sets NVTE_NORM_BWD_USE_CUDNN=1, they're requesting a specific backend and should expect it to fail loudly rather than silently fall back. The explicit opt-in makes the current behavior reasonable.

One clarification: does the regular (non-BackwardAdd) cuDNN normalization work with cuDNN < 9.21? If so, a user might set NVTE_NORM_BWD_USE_CUDNN=1 for general cuDNN norm support and be surprised when BackwardAdd specifically fails. In that case, the error message could be slightly more actionable:

"Fused BackwardAdd requires cuDNN >= 9.21.0 (found " + std::to_string(cudnnGetVersion()) + 
"). Either upgrade cuDNN or unset NVTE_NORM_BWD_USE_CUDNN to use the TE backend."

But if BackwardAdd is only used when explicitly requested (not auto-selected based on the operation pattern), then the current behavior and message are fine as-is.

Comment on lines 489 to 490
_variant_pack = {
    {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
Contributor


P2 _dgamma mapped in variant pack but not an output tensor for BackwardAdd stage

In the variant pack initialization {_dgamma, dgamma_dptr} is always included. For the BackwardAdd stage, _dgamma->set_output(true) is called from line 417 (_dgamma->set_output(true).set_data_type(...)), so cuDNN will write to dgamma_dptr. This looks fine.

However, a concern worth verifying: in the BackwardAdd path, _dx (the member) has been reassigned to _dx_with_add. The original intermediate rmsnorm _dx tensor (the input to the pointwise add) was set to set_output(false). It is not in the _variant_pack, which is correct — cuDNN handles it as an internal virtual tensor. No binding is required for it. This is working correctly but is subtle; a code comment explaining that the intermediate dx does not need a binding would improve maintainability.

Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
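The "virtual intermediate" point can be illustrated with a toy graph model (purely illustrative; this is not the cuDNN frontend API): only input tensors and tensors marked as outputs need a pointer binding in the variant pack, while intermediates consumed inside the fused graph stay internal to the kernel.

```python
class Tensor:
    def __init__(self, name, output=False):
        self.name = name
        self.output = output  # False => virtual: no device pointer binding needed

def required_bindings(inputs, tensors):
    # Inputs always need pointers; other tensors only if marked as outputs.
    return [t.name for t in inputs] + [t.name for t in tensors if t.output]

# Mirrors the shape of the BackwardAdd graph: the intermediate rmsnorm dx is
# virtual, while dx_with_add and dgamma are real outputs bound to user buffers.
x, dz, gamma, rsigma, add = (Tensor(n) for n in ["x", "dz", "gamma", "rsigma", "add"])
dx_intermediate = Tensor("dx", output=False)      # internal to the fused kernel
dx_with_add = Tensor("dx_with_add", output=True)  # bound to the caller's dx buffer
dgamma = Tensor("dgamma", output=True)

names = required_bindings([x, dz, gamma, rsigma, add],
                          [dx_intermediate, dx_with_add, dgamma])
```

Under this model, `dx` (the pre-add intermediate) never appears in the binding list, which is the behavior the review comment suggests documenting with a code comment.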

Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
