Enable fused RMSNorm dLN + add through cuDNN #2778
CarlosGomes98 wants to merge 7 commits into NVIDIA:main
Conversation
Signed-off-by: CarlosGomes98 <carlosmiguel.gomes@live.com.pt>
Greptile Summary: This PR enables the cuDNN backend for `rmsnorm_bwd_add` (the fused RMSNorm dLN + add backward pattern).
Confidence Score: 4/5
Sequence Diagram

```mermaid
sequenceDiagram
    participant Caller
    participant rmsnorm_bwd_add
    participant NormPlanRegistry
    participant CudnnNormalizationPlan
    participant cuDNN
    Caller->>rmsnorm_bwd_add: dz, x, add, rsigma, gamma, dx, dgamma
    rmsnorm_bwd_add->>rmsnorm_bwd_add: validate dtypes & shapes
    alt use_cudnn_norm_bwd()
        rmsnorm_bwd_add->>NormPlanRegistry: getNormalizationPlan(Cudnn, RMSNorm, BackwardAdd, ...)
        NormPlanRegistry->>CudnnNormalizationPlan: construct (wtype, itype, otype)
        CudnnNormalizationPlan->>CudnnNormalizationPlan: NVTE_CHECK(cudnnGetVersion >= 92100)
        CudnnNormalizationPlan->>cuDNN: rmsnorm_backward(dz, x, gamma, rsigma)
        CudnnNormalizationPlan->>cuDNN: pointwise ADD(_dx, _add) → _dx_with_add
        CudnnNormalizationPlan->>cuDNN: _build() / check_support / build_plans
        NormPlanRegistry-->>rmsnorm_bwd_add: plan (cached)
        rmsnorm_bwd_add->>CudnnNormalizationPlan: execute(x, gamma, rsigma, dx, dz, add, dgamma)
        CudnnNormalizationPlan->>CudnnNormalizationPlan: bind variant_pack (_x, _rsigma, _dz, _dgamma, _dx=_dx_with_add, _add)
        CudnnNormalizationPlan->>cuDNN: graph.execute(variant_pack, workspace)
        cuDNN-->>Caller: dx (= rmsnorm_dx + add), dgamma
    else TE backend
        rmsnorm_bwd_add->>NormPlanRegistry: getNormalizationPlan(Te, RMSNorm, BackwardAdd, ...)
        NormPlanRegistry-->>rmsnorm_bwd_add: plan (cached)
        rmsnorm_bwd_add->>Caller: dx, dgamma (TE kernel)
    end
```
Last reviewed commit: "remove dangling todo"
```cpp
if (_norm_stage == NVTE_Norm_Stage::BackwardAdd) {
  NVTE_CHECK(cudnnGetVersion() >= 92100,
             "Fused BackwardAdd requires cuDNN >= 9.21.0, but found ", cudnnGetVersion());
```
Version check throws inside constructor with no fallback path
The NVTE_CHECK(cudnnGetVersion() >= 92100, ...) fires during CudnnNormalizationPlan construction, which is called from getNormalizationPlan. If a user sets NVTE_NORM_BWD_USE_CUDNN=1 on a system with cuDNN < 9.21, the entire rmsnorm_bwd_add call will throw a hard error rather than gracefully falling back to the TE kernel. Unlike typical cuDNN support checks elsewhere in the codebase that gate entry into the cuDNN path, this check is buried inside plan construction where there is no recovery mechanism.
Consider validating the cuDNN version at the top of rmsnorm_bwd_add in rmsnorm_api.cpp (inside the use_cudnn_norm_bwd() branch) and falling back to the TE backend when the version requirement isn't met, consistent with how other backend-selection logic is structured.
Not sure what the best behaviour here is. If NVTE_NORM_BWD_USE_CUDNN=1 is set, and we cannot use CUDNN, I would expect it to fail early, and for the user to remove that env var, rather than silently falling back to TE.
You're right – if a user explicitly sets NVTE_NORM_BWD_USE_CUDNN=1, they're requesting a specific backend and should expect it to fail loudly rather than silently fall back. The explicit opt-in makes the current behavior reasonable.
One clarification: does the regular (non-BackwardAdd) cuDNN normalization work with cuDNN < 9.21? If so, a user might set NVTE_NORM_BWD_USE_CUDNN=1 for general cuDNN norm support and be surprised when BackwardAdd specifically fails. In that case, the error message could be slightly more actionable:
"Fused BackwardAdd requires cuDNN >= 9.21.0 (found " + std::to_string(cudnnGetVersion()) +
"). Either upgrade cuDNN or unset NVTE_NORM_BWD_USE_CUDNN to use the TE backend."But if BackwardAdd is only used when explicitly requested (not auto-selected based on the operation pattern), then the current behavior and message are fine as-is.
```cpp
_variant_pack = {
    {_x, x_dptr}, {_rsigma, rsigma_dptr}, {_dz, dz_dptr}, {_dgamma, dgamma_dptr}, {_dx, dx_dptr}};
```
_dgamma mapped in variant pack but not an output tensor for BackwardAdd stage
In the variant pack initialization {_dgamma, dgamma_dptr} is always included. For the BackwardAdd stage, _dgamma->set_output(true) is called from line 417 (_dgamma->set_output(true).set_data_type(...)), so cuDNN will write to dgamma_dptr. This looks fine.
However, a concern worth verifying: in the BackwardAdd path, _dx (the member) has been reassigned to _dx_with_add. The original intermediate rmsnorm _dx tensor (the input to the pointwise add) was set to set_output(false). It is not in the _variant_pack, which is correct — cuDNN handles it as an internal virtual tensor. No binding is required for it. This is working correctly but is subtle; a code comment explaining that the intermediate dx does not need a binding would improve maintainability.
Description
Starting with version 9.21, cuDNN provides kernel support for fusing the dLN + add backward pattern, which is useful for the residual-fork + RMSNorm pattern in networks.
The fusion pattern is already expressible through te.ops.Sequential with backward_add_rmsnorm.py; this PR just enables the cuDNN backend for it.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: