Fix the retrieval of overwrite_main_grad #2764
Greptile Summary
This PR fixes the retrieval of
Confidence Score: 3/5
Important Files Changed
Flowchart
%%{init: {'theme': 'neutral'}}%%
flowchart TD
A[Backward Pass] --> B{FP8/FP4 Quantization?}
B -->|Yes| C[weight = quantized FP8/FP4 tensor]
B -->|No| D[weight = original tensor]
C --> E[origin_weight = original tensor from ctx]
D --> E
E --> F["getattr(origin_weight, 'overwrite_main_grad', False)"]
F -->|True| G[accumulate = False]
F -->|False| H[accumulate = accumulate_wgrad_into_param_main_grad]
G --> I[WGRAD GEMM]
H --> I
Last reviewed commit: e0792f5
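The decision in the flowchart above can be sketched in a few lines of Python. This is a minimal illustration, not the code from the PR: `resolve_accumulate` and the `_Weight` stand-in class are hypothetical names, and only the `getattr(..., "overwrite_main_grad", False)` check and the `accumulate_wgrad_into_param_main_grad` flag come from the diff.

```python
def resolve_accumulate(origin_weight, accumulate_wgrad_into_param_main_grad):
    """Return the `accumulate` flag passed to the wgrad GEMM (hypothetical helper).

    The flag is read from the original weight, not from its quantized copy.
    """
    if getattr(origin_weight, "overwrite_main_grad", False):
        # The gradient overwrites main_grad, so it must not be accumulated.
        return False
    return accumulate_wgrad_into_param_main_grad


class _Weight:
    """Hypothetical stand-in for a parameter tensor that accepts ad hoc attributes."""


plain = _Weight()
flagged = _Weight()
flagged.overwrite_main_grad = True

print(resolve_accumulate(plain, True))    # follows the module-level setting
print(resolve_accumulate(flagged, True))  # forced off by overwrite_main_grad
```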
  "accumulate": (
      accumulate_wgrad_into_param_main_grad
-     if not getattr(fc1_weight, "overwrite_main_grad", False)
+     if not getattr(origin_fc1_weight, "overwrite_main_grad", False)
Swapped fc1/fc2 weight reference
This line checks origin_fc1_weight.overwrite_main_grad but it is inside the fc2 wgrad GEMM kwargs (note fc2_grad_weight_quantizer on line 1224 and origin_fc2_weight.main_grad on line 1231). It should be checking origin_fc2_weight instead.
Similarly, the fc1 wgrad section at line 1474 checks origin_fc2_weight but should check origin_fc1_weight.
This cross-reference was present in the original code (fc1_weight / fc2_weight were already swapped), but since this PR specifically targets fixing the overwrite_main_grad retrieval, it would be good to fix the swap at the same time.
- if not getattr(origin_fc1_weight, "overwrite_main_grad", False)
+ if not getattr(origin_fc2_weight, "overwrite_main_grad", False)
  "accumulate": (
      accumulate_wgrad_into_param_main_grad
-     if not getattr(fc2_weight, "overwrite_main_grad", False)
+     if not getattr(origin_fc2_weight, "overwrite_main_grad", False)
Swapped fc1/fc2 weight reference
Same issue as the fc2 wgrad section above: this checks origin_fc2_weight but it is inside the fc1 wgrad GEMM kwargs (note fc1_grad_weight_quantizer on line 1471 and origin_fc1_weight.main_grad on line 1478). It should be checking origin_fc1_weight.
- if not getattr(origin_fc2_weight, "overwrite_main_grad", False)
+ if not getattr(origin_fc1_weight, "overwrite_main_grad", False)
…weight instead of the fp8 (fp4) weight, since the fp8 (fp4) weight may not inherit the required attributes during creation.
Signed-off-by: jianbinc <shjwudp@gmail.com>
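The commit message's point about attributes not being inherited can be illustrated with plain Python objects. This is a sketch under stated assumptions and does not reproduce how TE creates FP8/FP4 tensors: `OriginalWeight`, `QuantizedWeight`, and `quantize` are hypothetical stand-ins, and the int8-style scaling is only there to make the example concrete.

```python
class OriginalWeight:
    """Hypothetical stand-in for the high-precision parameter."""

    def __init__(self, values):
        self.values = list(values)


class QuantizedWeight:
    """Hypothetical stand-in for the low-precision copy used in the GEMM."""

    def __init__(self, values, scale):
        self.values = values
        self.scale = scale


def quantize(weight):
    """Build a new object from the data only; ad hoc attributes are not copied."""
    scale = max(abs(v) for v in weight.values) or 1.0
    return QuantizedWeight([round(v / scale * 127) for v in weight.values], scale)


origin_weight = OriginalWeight([0.5, -1.0, 0.25])
origin_weight.overwrite_main_grad = True  # attribute attached after creation

weight = quantize(origin_weight)

# The quantized copy silently falls back to the default.
print(getattr(weight, "overwrite_main_grad", False))
# The original weight still carries the attribute.
print(getattr(origin_weight, "overwrite_main_grad", False))
```

Because `getattr` with a default never raises, reading the flag from the quantized copy fails quietly, which is why the lookup has to target the original weight.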
e0792f5 to d8285e1
TE modules now fetch overwrite_main_grad from the original weight instead of the FP8 (FP4) weight, since the FP8/FP4 weight might not inherit the required attributes during creation. This change fixes a potential issue where overwrite_main_grad would not be applied correctly.
Description
Please include a brief summary of the changes, relevant motivation and context.
Fixes # (issue)
Type of change
Changes
Please list the changes introduced in this PR:
Checklist: