
Fix the retrieval of overwrite_main_grad #2764

Draft
shjwudp wants to merge 2 commits into NVIDIA:main from shjwudp:fix_fp8_fp4_overwrite_main_grad

Conversation

@shjwudp
Contributor

@shjwudp shjwudp commented Mar 16, 2026

TE modules now fetch overwrite_main_grad from the original weight instead of the FP8 (FP4) weight, since the FP8/FP4 weight might not inherit the required attributes during creation. This change fixes a potential issue where overwrite_main_grad would not be applied correctly.
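The failure mode can be sketched with plain Python objects (a hypothetical stand-in, not TE's actual tensor classes): a freshly created quantized weight object does not carry custom attributes set on the original parameter, so a `getattr` with a default silently returns the wrong value.

```python
class Weight:
    """Hypothetical stand-in for a parameter tensor."""

def quantize(weight):
    """Stand-in for FP8/FP4 weight creation: returns a new object and does
    not copy custom Python attributes from the original weight."""
    return Weight()

origin_weight = Weight()
origin_weight.overwrite_main_grad = True  # set by the training framework

fp8_weight = quantize(origin_weight)

# Lookup on the quantized weight silently falls back to the default:
print(getattr(fp8_weight, "overwrite_main_grad", False))     # False
# Lookup on the original weight returns the intended value:
print(getattr(origin_weight, "overwrite_main_grad", False))  # True
```

Reading the flag from the original weight, as this PR does, avoids depending on which attributes the quantized tensor happens to inherit.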

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

@shjwudp shjwudp marked this pull request as draft March 16, 2026 16:54
@greptile-apps
Contributor

greptile-apps bot commented Mar 16, 2026

Greptile Summary

This PR fixes the retrieval of overwrite_main_grad in three TE modules (grouped_linear, layernorm_linear, layernorm_mlp) by reading the attribute from the original weight tensor instead of the FP8/FP4 quantized weight, which may not inherit custom attributes during creation.

  • grouped_linear.py: weights[0] → origin_weights[0] — correct fix
  • layernorm_linear.py: weight → origin_weight — correct fix
  • layernorm_mlp.py: fc1_weight/fc2_weight → origin_fc1_weight/origin_fc2_weight — correct source switch, but the fc1/fc2 references appear swapped (fc2 wgrad checks fc1's attribute and vice versa, a pre-existing issue carried over into this PR)
  • Note: linear.py does not need this fix since its weight variable in the backward pass already refers to the original (non-FP8) weight

Confidence Score: 3/5

  • The core fix is correct and low-risk, but the swapped fc1/fc2 references in layernorm_mlp.py should be addressed before merging.
  • Two of the three files have clean, correct fixes. However, layernorm_mlp.py carries over a pre-existing bug where fc1 and fc2 weight references are swapped when checking overwrite_main_grad. While this may be masked in practice if both weights share the same attribute value, it is semantically incorrect and should be fixed in the same PR that specifically targets this attribute retrieval.
  • transformer_engine/pytorch/module/layernorm_mlp.py — the fc1/fc2 weight cross-reference at lines 1227 and 1474 should be corrected.

Important Files Changed

Filename — Overview
  • transformer_engine/pytorch/module/grouped_linear.py — Correctly changes weights[0] to origin_weights[0] for the overwrite_main_grad lookup. Clean, straightforward fix.
  • transformer_engine/pytorch/module/layernorm_linear.py — Correctly changes weight to origin_weight for the overwrite_main_grad lookup. Clean, straightforward fix.
  • transformer_engine/pytorch/module/layernorm_mlp.py — Switches to origin_fc1_weight/origin_fc2_weight as intended, but the fc1/fc2 weight references remain swapped from a pre-existing bug: fc2 wgrad checks fc1's attribute and vice versa.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A[Backward Pass] --> B{FP8/FP4 Quantization?}
    B -->|Yes| C[weight = quantized FP8/FP4 tensor]
    B -->|No| D[weight = original tensor]
    C --> E[origin_weight = original tensor from ctx]
    D --> E
    E --> F["getattr(origin_weight, 'overwrite_main_grad', False)"]
    F -->|True| G[accumulate = False]
    F -->|False| H[accumulate = accumulate_wgrad_into_param_main_grad]
    G --> I[WGRAD GEMM]
    H --> I
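The decision shown in the flowchart can be condensed into a small helper (an illustrative sketch with a made-up name, not TE's actual code): the flag is read from the original weight saved in the autograd ctx, never from the FP8/FP4 tensor, which may not carry the attribute.

```python
def select_accumulate(origin_weight, accumulate_wgrad_into_param_main_grad):
    """Choose the `accumulate` flag for the wgrad GEMM based on the
    original (non-quantized) weight's overwrite_main_grad attribute."""
    if getattr(origin_weight, "overwrite_main_grad", False):
        return False  # overwrite main_grad instead of accumulating into it
    return accumulate_wgrad_into_param_main_grad

class _W:
    """Minimal stand-in for a weight tensor."""

w = _W()
w.overwrite_main_grad = True
print(select_accumulate(w, True))   # False: main_grad is overwritten

del w.overwrite_main_grad
print(select_accumulate(w, True))   # True: normal accumulation
```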

Last reviewed commit: e0792f5

    "accumulate": (
        accumulate_wgrad_into_param_main_grad
-       if not getattr(fc1_weight, "overwrite_main_grad", False)
+       if not getattr(origin_fc1_weight, "overwrite_main_grad", False)

Swapped fc1/fc2 weight reference

This line checks origin_fc1_weight.overwrite_main_grad but it is inside the fc2 wgrad GEMM kwargs (note fc2_grad_weight_quantizer on line 1224 and origin_fc2_weight.main_grad on line 1231). It should be checking origin_fc2_weight instead.

Similarly, the fc1 wgrad section at line 1474 checks origin_fc2_weight but should check origin_fc1_weight.

This cross-reference was present in the original code (fc1_weight / fc2_weight were already swapped), but since this PR specifically targets fixing the overwrite_main_grad retrieval, it would be good to fix the swap at the same time.

Suggested change
-       if not getattr(origin_fc1_weight, "overwrite_main_grad", False)
+       if not getattr(origin_fc2_weight, "overwrite_main_grad", False)
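The corrected pairing can be illustrated with a hypothetical helper (names mirror the review, not TE's exact code): each wgrad GEMM must read overwrite_main_grad from the same weight whose main_grad it writes into.

```python
def wgrad_kwargs(origin_weight, accumulate_wgrad_into_param_main_grad):
    """Build the accumulate kwarg for one layer's wgrad GEMM; the
    overwrite_main_grad check uses that layer's own original weight."""
    return {
        "accumulate": (
            accumulate_wgrad_into_param_main_grad
            if not getattr(origin_weight, "overwrite_main_grad", False)
            else False
        ),
    }

class _W:
    """Minimal stand-in for a weight tensor."""

origin_fc2_weight = _W()
origin_fc2_weight.overwrite_main_grad = True
# The fc2 wgrad GEMM consults fc2's own flag, so accumulation is disabled:
print(wgrad_kwargs(origin_fc2_weight, True))  # {'accumulate': False}
```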

    "accumulate": (
        accumulate_wgrad_into_param_main_grad
-       if not getattr(fc2_weight, "overwrite_main_grad", False)
+       if not getattr(origin_fc2_weight, "overwrite_main_grad", False)

Swapped fc1/fc2 weight reference

Same issue as the fc2 wgrad section above: this checks origin_fc2_weight but it is inside the fc1 wgrad GEMM kwargs (note fc1_grad_weight_quantizer on line 1471 and origin_fc1_weight.main_grad on line 1478). It should be checking origin_fc1_weight.

Suggested change
-       if not getattr(origin_fc2_weight, "overwrite_main_grad", False)
+       if not getattr(origin_fc1_weight, "overwrite_main_grad", False)

shjwudp added 2 commits March 16, 2026 10:16
…weight instead of the fp8 (fp4) weight, since the fp8 (fp4) weight may not inherit the required attributes during creation.

Signed-off-by: jianbinc <shjwudp@gmail.com>
Signed-off-by: jianbinc <shjwudp@gmail.com>
@shjwudp shjwudp force-pushed the fix_fp8_fp4_overwrite_main_grad branch from e0792f5 to d8285e1 Compare March 16, 2026 17:18