GEMM + SwiGLU fused Grouped MLP for MXFP8 #2769
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Greptile Summary

This PR introduces a fused GEMM+SwiGLU forward and backward kernel for MXFP8 Grouped MLP on NVIDIA Blackwell GPUs (SM100+). It adds two new

Key issues found:
Confidence Score: 3/5
Important Files Changed
Sequence Diagram

```mermaid
sequenceDiagram
    participant Input as Input (sum_m × K)
    participant GQ1 as group_quantize (FC1 input)
    participant FK1 as grouped_gemm_swiglu_kernel (SM100)
    participant GFC2 as general_grouped_gemm (FC2)
    participant Output as Output (sum_m × N)
    Input->>GQ1: MXFP8 quantize rowwise+colwise
    GQ1->>FK1: fc1_x_data, fc1_x_scales (permuted)
    Note over FK1: FC1 GEMM (MXFP8) +\nSwiGLU + post-scale\n(discrete_col_sfd=True)
    FK1-->>FK1: c_tensor (BF16, swiglu_in saved for bwd)
    FK1-->>FK1: d_tensor + sfd_row + sfd_col (FP8 + MXFP8 scales)
    FK1->>GFC2: grouped_fc2_x (pre-swizzled scales)
    GFC2->>Output: FC2 GEMM (MXFP8, layout=TN)
    Note over Input,Output: Backward pass (BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8)
    Output->>GQ1: group_quantize (FC2 grad output)
    GQ1->>FK1: grouped_gemm_dswiglu_kernel (SM100)
    Note over FK1: FC2 dgrad GEMM +\ndSwiGLU + grad_scales
    FK1-->>GFC2: grouped_fc1_dy (FP8 + MXFP8 scales)
    GFC2->>Input: FC1 dgrad GEMM (layout=NN)
    GFC2-->>GFC2: FC1/FC2 wgrad GEMMs (layout=NT)
```

Last reviewed commit: bf7af9f
```python
import os
import functools
import math
from pickle import TRUE
```
Unused import

`from pickle import TRUE`

`TRUE` is imported from the `pickle` module but is never used anywhere in this file. `pickle.TRUE` is an internal pickle opcode byte string (`b'I01\n'`), not a Python `True` value. This is almost certainly a leftover from development and should be removed.
Suggested change:

```diff
-from pickle import TRUE
 from typing import Optional
```
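For readers unfamiliar with the quirk, a quick standalone check (illustrative snippet, not from the PR) confirms that `pickle.TRUE` is a protocol opcode byte string rather than the boolean singleton:

```python
import pickle

# pickle.TRUE is the protocol-0 pickle opcode sequence for True:
# a bytes object, not the boolean True
opcode = pickle.TRUE
print(type(opcode))  # <class 'bytes'>
```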
```python
return instance

self.with_gemm_swizzled_scales = with_gemm_swizzled_scales
```
The line `self.with_gemm_swizzled_scales = with_gemm_swizzled_scales` is placed after `return instance` in `__new__`, so it will never execute. The intent seems to be handled already in `_initialize_storage_fields` (where `instance.with_gemm_swizzled_scales = with_gemm_swizzled_scales` is correctly set), so this line is both unreachable and redundant.
Suggested change:

```diff
-return instance
-self.with_gemm_swizzled_scales = with_gemm_swizzled_scales
+return instance

 def has_data(self) -> bool:
```
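The unreachable-code pattern is easy to reproduce in isolation; a toy class (hypothetical, not the PR's actual `__new__`) makes the behavior concrete:

```python
class Demo:
    def __new__(cls, flag=False):
        instance = super().__new__(cls)
        instance.flag = flag
        return instance
        # everything below the return is dead code and never runs
        instance.unreachable = True

d = Demo(flag=True)
```

The attribute assignment after `return` is silently skipped, so `d` ends up without an `unreachable` attribute. Linters flag this as unreachable code, which is why removing the line is the right fix.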
```python
global global_alpha_tensor
alpha_tensor = self._mxfp8_alpha_tensor
norm_const_tensor = self._mxfp8_norm_const_tensor
if (
    alpha_tensor is None
    or alpha_tensor.numel() != num_groups
    or alpha_tensor.dtype != dtype
    or alpha_tensor.device != device
):
    if global_alpha_tensor is None:
        global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device)
    alpha_tensor = global_alpha_tensor
    norm_const_tensor = alpha_tensor[:1]
    self._mxfp8_alpha_tensor = alpha_tensor
    self._mxfp8_norm_const_tensor = norm_const_tensor
elif (
    norm_const_tensor is None
```
global_alpha_tensor stale for multiple instances with different num_groups
In _get_kernel_constants, the module-level global_alpha_tensor is created only once (when it is None). If a second ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 instance is created with a different num_groups, the stale global (wrong size) is assigned to self._mxfp8_alpha_tensor. On every subsequent call, the condition alpha_tensor.numel() != num_groups will remain True, and the wrong-sized tensor from the global will keep being assigned—without ever being recreated.
Concretely:

- Instance A (4 groups): creates `global_alpha_tensor` with 4 elements ✓
- Instance B (8 groups): `global_alpha_tensor is None` is `False` → skips creation → assigns the 4-element tensor as `alpha_tensor` → wrong size passed to the kernel.

The fix is to always recreate the global when the cached size/dtype/device doesn't match:

```python
if (
    global_alpha_tensor is None
    or global_alpha_tensor.numel() != num_groups
    or global_alpha_tensor.dtype != dtype
    or global_alpha_tensor.device != device
):
    global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device)
alpha_tensor = global_alpha_tensor
```

The same issue exists in backward_grouped_mlp.py's `_get_kernel_constants`.
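The staleness can be demonstrated without GPUs; a toy reproduction (plain lists standing in for tensors, all names hypothetical) follows the same control flow as the buggy branch:

```python
global_alpha = None  # module-level cache, mirrors global_alpha_tensor

def get_alpha_buggy(cached, num_groups):
    # mirrors the bug: the global is only created when None,
    # never recreated when its size no longer matches num_groups
    global global_alpha
    alpha = cached
    if alpha is None or len(alpha) != num_groups:
        if global_alpha is None:
            global_alpha = [1.0] * num_groups
        alpha = global_alpha
    return alpha

a = get_alpha_buggy(None, 4)  # instance A: creates a 4-element global
b = get_alpha_buggy(None, 8)  # instance B: silently reuses the 4-element global
```

The second call returns the stale 4-element cache even though 8 groups were requested, which is exactly the wrong-size hand-off described above.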
```python
assert hasattr(weight_param, "main_grad"), "MAIN GRAD NOT FOUND !!!!"
assert weight_param.main_grad is not None, "MAIN GRAD IS NONE !!!!"
```
Debug assertion messages should use proper error messages
These assert statements with !!!! in the messages look like debugging leftovers. While the assertions themselves are reasonable sanity checks, the message style is not production-appropriate. Consider using RuntimeError or a cleaner assert message:
Suggested change:

```diff
-assert hasattr(weight_param, "main_grad"), "MAIN GRAD NOT FOUND !!!!"
-assert weight_param.main_grad is not None, "MAIN GRAD IS NONE !!!!"
+if self._accumulate_into_main_grad:
+    if not hasattr(weight_param, "main_grad"):
+        raise RuntimeError(
+            "Expected 'main_grad' attribute on weight parameter, but it was not found."
+        )
+    if weight_param.main_grad is None:
+        raise RuntimeError("'main_grad' on weight parameter is None.")
```
Note: If this suggestion doesn't match your team's coding style, reply to this and let me know. I'll remember it for next time!
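Beyond message style, a practical reason to prefer explicit raises here is that `assert` statements are stripped when Python runs with `-O`. A small standalone sketch (hypothetical helper and class names) shows the check that survives both modes:

```python
class FakeParam:
    """Stand-in for a weight parameter that lacks a main_grad attribute."""

def validate_main_grad(weight_param):
    # explicit checks survive `python -O`, unlike assert statements
    if not hasattr(weight_param, "main_grad"):
        raise RuntimeError("Expected 'main_grad' attribute on weight parameter.")
    if weight_param.main_grad is None:
        raise RuntimeError("'main_grad' on weight parameter is None.")

try:
    validate_main_grad(FakeParam())
except RuntimeError as err:
    message = str(err)
```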
```python
norm_const_tensor=norm_const_tensor,
d_dtype=torch.float8_e4m3fn,
cd_major="n",
sf_vec_size=32,
```
sf_vec_size hardcoded instead of using the shared constant
The backward kernel call uses sf_vec_size=32 hardcoded, while the forward kernel in forward_grouped_mlp.py correctly uses sf_vec_size=MXFP8_BLOCK_SCALING_SIZE (which equals 32). Both values happen to be the same today, but for consistency and maintainability, the backward should also import and use MXFP8_BLOCK_SCALING_SIZE.
Suggested change:

```diff
-sf_vec_size=32,
+sf_vec_size=MXFP8_BLOCK_SCALING_SIZE,
```

(add `from ...constants import MXFP8_BLOCK_SCALING_SIZE` to the imports)
```python
    dtype=torch.float32,
)
fc1_weight.main_grad.fill_(value)
fc2_weight.main_grad.fill_(value)


def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]:
    if single_grouped_parameter:
        fc1_main_grad = fc1.weight.main_grad.detach().clone()
        fc2_main_grad = fc2.weight.main_grad.detach().clone()
    else:
        fc1_main_grad = torch.stack(
            [
                getattr(fc1, f"weight{group_idx}").main_grad.detach().clone()
                for group_idx in range(group_size)
            ],
            dim=0,
        )
        fc2_main_grad = torch.stack(
            [
                getattr(fc2, f"weight{group_idx}").main_grad.detach().clone()
                for group_idx in range(group_size)
```
Hard assert on `is_supported()` causes test failure instead of skip

`assert te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported()` will raise `AssertionError` on any Blackwell system where the cudnn front-end package is installed for MXFP8 quantization but the specific grouped-GEMM-SwiGLU kernel is not available. In that scenario the test should be skipped, not failed. The `maybe_skip_quantization` guard only checks whether MXFP8 quantization in general is supported — it does not gate on the availability of the fused kernel.
Consider wrapping the fusion-presence block in a skip guard:

```python
if (
    quantization == "mxfp8"
    and dtype in (torch.bfloat16, torch.float16)
    and not bias
    and glu_interleave_size == 32
):
    if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
        pytest.skip("MXFP8 fused grouped MLP kernel not available on this system")
    forward_ops = module._module_groups[0]._forward_ops
    ...
```

```python
            window = [op]
        else:
            # Shift window if window doesn't match pattern
            out.extend(window[:-2])
            window = window[-2:]

    # Adjust window to expected size
    out.extend(window[:-3])
    window = window[-3:]
    while ops and len(window) < 3:
        window.append(ops[0])
        ops = ops[1:]

    # Return list of ops
    out.extend(window)
    return out


# Register fusion if available
if BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
    register_backward_fusion(fuse_backward_ops, prepend=True)
```
norm_const_tensor not refreshed when only alpha_tensor is stale
The backward's _get_kernel_constants is missing the elif guard that the forward version has. If self._mxfp8_alpha_tensor is somehow valid but self._mxfp8_norm_const_tensor is None (e.g., after a partial object state restore or future refactoring), the function returns None for norm_const_tensor, which will crash the kernel call.
The forward version correctly handles this with:

```python
elif (
    norm_const_tensor is None
    or norm_const_tensor.numel() != 1
    or norm_const_tensor.dtype != dtype
    or norm_const_tensor.device != device
):
    norm_const_tensor = alpha_tensor[:1]
    self._mxfp8_norm_const_tensor = norm_const_tensor
```

Add the same defensive elif branch here for consistency and safety.
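The two-branch refresh logic can be sketched with plain lists (a toy stand-in for the tensor caches; names hypothetical) to show why the defensive `elif` matters:

```python
def get_constants(alpha, norm_const, num_groups):
    # first branch: rebuild both caches when alpha is missing or mis-sized
    if alpha is None or len(alpha) != num_groups:
        alpha = [1.0] * num_groups
        norm_const = alpha[:1]
    # defensive branch: alpha is valid but norm_const is stale, so
    # rebuild only the one-element slice instead of returning None
    elif norm_const is None or len(norm_const) != 1:
        norm_const = alpha[:1]
    return alpha, norm_const

alpha, norm_const = get_constants([1.0] * 4, None, 4)
```

Without the `elif`, the call above would return `None` for `norm_const` because the first branch is skipped when `alpha` already looks valid.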
```python
fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1])
fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn)
fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0)
fc1_x_scales = grouped_fc1_x.scale_inv
fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu)
fc1_x_scales = fc1_x_scales.view(
    1,
    in_shape[0] // 128,
    in_shape[1] // 128,
    32,
    4,
    4,
)
fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0)
```
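As a sanity check on the six-dimensional view, the element counts line up: MXFP8 stores one E8M0 scale per 32 values along the contraction dimension, and the view's shape multiplies out to exactly that count when both dimensions are multiples of 128 (M and K below are hypothetical sizes, not from the PR):

```python
# hypothetical token count M and feature count K, both multiples of 128
M, K = 256, 512

# flat scale buffer: one E8M0 scale per 32-element block along K
flat_scales = M * (K // 32)

# product of the six-dimensional swizzled view used before the permute:
# (1, M // 128, K // 128, 32, 4, 4)
viewed_scales = 1 * (M // 128) * (K // 128) * 32 * 4 * 4
```

If M is not a multiple of 128, the integer division truncates and the two counts diverge, which is precisely the silent failure mode the review comment below describes.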
No validation that total token count is divisible by 128
The scale tensor view at lines 272–279 uses integer division in_shape[0] // 128 to reshape the MXFP8 scale buffer. If in_shape[0] (i.e., sum(split_sizes)) is not divisible by 128, the view shape product will not match the actual buffer size and either produce incorrect behavior (wrong permute dimensions) or a runtime error with a confusing message.
The constructor checks that in_features % 256 == 0 and out_features % 256 == 0, but nothing validates that the token dimension sum(split_sizes) is divisible by 128 (required by the MXFP8 block-scaling layout). A user passing split sizes like [64, 65] would hit this silently.
The same assumption appears in the backward pass at backward_grouped_mlp.py lines 243–250.
Consider adding a guard before the view:

```python
if in_shape[0] % 128 != 0:
    raise ValueError(
        f"Total token count must be divisible by 128 for MXFP8 fused kernel, "
        f"but got sum(split_sizes)={in_shape[0]}."
    )
```