
GEMM + Swiglu fused Grouped MLP for MXFP8 #2769

Open
ksivaman wants to merge 9 commits into NVIDIA:main from ksivaman:fused_mxfp8_grouped_mlp

Conversation

@ksivaman
Member

Description

Please include a brief summary of the changes, relevant motivation and context.

Fixes # (issue)

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

Please list the changes introduced in this PR:

  • Change A
  • Change B

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@greptile-apps
Contributor

greptile-apps bot commented Mar 17, 2026

Greptile Summary

This PR introduces a fused GEMM+SwiGLU forward and backward kernel for MXFP8 Grouped MLP on NVIDIA Blackwell GPUs (SM100+). It adds two new FusedOperation classes — ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 and BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8 — that call experimental CuTe DSL kernels from the cuDNN front-end to replace three separate ops (FC1 GroupedLinear + ScaledSwiGLU + FC2 GroupedLinear) with a single fused launch. The change also extends GroupedTensorStorage / GroupedTensor with a with_gemm_swizzled_scales field used to skip re-swizzling when scales are already in GEMM-optimal layout.
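The three-ops-into-one replacement described above works by pattern-matching a window of ops in the fuser. A generic sketch of that sliding-window idea (hypothetical class and function names, not the PR's actual fuser code):

```python
class FC1: ...
class SwiGLU: ...
class FC2: ...

class FusedMLP:
    """Stand-in for a fused op that absorbs a matched window of basic ops."""
    def __init__(self, ops):
        self.ops = ops

def fuse_forward_ops(ops, pattern=("FC1", "SwiGLU", "FC2")):
    """Replace any consecutive [FC1, SwiGLU, FC2] triple with one FusedMLP."""
    out, window = [], []
    for op in ops:
        window.append(op)
        # Keep the window no longer than the pattern.
        if len(window) > len(pattern):
            out.append(window.pop(0))
        # On an exact match, emit the fused op instead of the triple.
        if tuple(type(o).__name__ for o in window) == pattern:
            out.append(FusedMLP(window))
            window = []
    out.extend(window)
    return out

fused = fuse_forward_ops([FC1(), SwiGLU(), FC2()])
print(len(fused))  # 1
```

The real implementation additionally checks quantization recipe, dtype, and hardware support before fusing; this sketch shows only the window-matching mechanics.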

Key issues found:

  • Duplicate quantizer.set_usage call for FC2 non-grouped weights (forward_grouped_mlp.py lines 229–238): set_usage is called unconditionally before the if not is_quantized_tensor check and again inside it, inconsistently with the FC1 branch. This mutates quantizer state for already-quantized weights without applying quantization.
  • Missing validation that sum(split_sizes) % 128 == 0 (forward_grouped_mlp.py ~line 272, backward_grouped_mlp.py ~line 245): scale tensor views use integer division by 128 on the total token count without verifying divisibility, which would silently produce incorrect shapes or confusing runtime errors for odd split sizes.

Confidence Score: 3/5

  • Significant new functionality with a few correctness issues that should be addressed before merging.
  • The core fused kernel integration is well-structured and the new forward/backward ops follow existing patterns in the codebase. However, two issues reduce confidence: (1) a duplicate quantizer.set_usage call in the FC2 branch mutates quantizer state for already-quantized weights unexpectedly, and (2) there is no validation that the total token dimension is divisible by 128 before the scale tensor view operations, which can produce confusing failures for non-aligned inputs. Several issues noted in prior review threads (missing elif in backward _get_kernel_constants, hardcoded sf_vec_size=32, hard asserts on is_supported()) also remain unaddressed.
  • transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py and transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py need the most attention.

Important Files Changed

Filename Overview
transformer_engine/pytorch/ops/fused/forward_grouped_mlp.py New fused forward op (ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8) combining GroupedLinear + SwiGLU + GroupedLinear for MXFP8 on SM100+. Has a duplicate quantizer.set_usage call in the FC2 non-grouped-parameter branch and missing validation that sum(split_sizes) % 128 == 0 before scale tensor view operations.
transformer_engine/pytorch/ops/fused/backward_grouped_mlp.py New fused backward op (BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8) for dSwiGLU + FC2 dgrad + FC1 dgrad/wgrad. Missing validation that out_shape[0] % 128 == 0 before scale tensor reshaping; also missing the elif guard for norm_const_tensor in _get_kernel_constants (already noted in prior review threads); sf_vec_size hardcoded to 32 rather than MXFP8_BLOCK_SCALING_SIZE.
transformer_engine/pytorch/tensor/storage/grouped_tensor_storage.py Extended GroupedTensorStorage with a with_gemm_swizzled_scales field and related initialization; an unreachable assignment after return in __new__ was noted in prior review threads.
tests/pytorch/test_fusible_ops.py New tests for MXFP8 fused grouped MLP (numerics and CUDA-graph-safe path). Lines 3461-3462 use hard assert on is_supported() inside the fusion check block instead of pytest.skip, which will cause test failures (not skips) on systems where MXFP8 is available but the specific fused kernel is not (already noted in prior threads).
transformer_engine/pytorch/ops/basic/grouped_linear.py Added assert guards with debug-style messages (!!!!) in fuser_forward; underlying logic is correct. Assert messages noted in prior threads.

Sequence Diagram

sequenceDiagram
    participant Input as Input (sum_m × K)
    participant GQ1 as group_quantize (FC1 input)
    participant FK1 as grouped_gemm_swiglu_kernel (SM100)
    participant GFC2 as general_grouped_gemm (FC2)
    participant Output as Output (sum_m × N)

    Input->>GQ1: MXFP8 quantize rowwise+colwise
    GQ1->>FK1: fc1_x_data, fc1_x_scales (permuted)
    Note over FK1: FC1 GEMM (MXFP8) +\nSwiGLU + post-scale\n(discrete_col_sfd=True)
    FK1-->>FK1: c_tensor (BF16, swiglu_in saved for bwd)
    FK1-->>FK1: d_tensor + sfd_row + sfd_col (FP8 + MXFP8 scales)
    FK1->>GFC2: grouped_fc2_x (pre-swizzled scales)
    GFC2->>Output: FC2 GEMM (MXFP8, layout=TN)

    Note over Input,Output: Backward pass (BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8)
    Output->>GQ1: group_quantize (FC2 grad output)
    GQ1->>FK1: grouped_gemm_dswiglu_kernel (SM100)
    Note over FK1: FC2 dgrad GEMM +\ndSwiGLU + grad_scales
    FK1-->>GFC2: grouped_fc1_dy (FP8 + MXFP8 scales)
    GFC2->>Input: FC1 dgrad GEMM (layout=NN)
    GFC2-->>GFC2: FC1/FC2 wgrad GEMMs (layout=NT)

Last reviewed commit: bf7af9f

import os
import functools
import math
from pickle import TRUE

P1 Unused import from pickle import TRUE

TRUE is imported from the pickle module but is never used anywhere in this file. pickle.TRUE is an internal pickle opcode byte string (b'I01\n'), not a Python True value. This is almost certainly a leftover from development and should be removed.

Suggested change
- from pickle import TRUE
  from typing import Optional
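For reference, pickle.TRUE really is an opcode byte string rather than the boolean True, which a quick standalone check confirms (not part of the PR):

```python
import pickle

# pickle.TRUE is the protocol-0 pickle opcode sequence for True,
# not the Python bool True.
print(repr(pickle.TRUE))    # b'I01\n'
print(pickle.TRUE is True)  # False
```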

Comment on lines +202 to +205
return instance

self.with_gemm_swizzled_scales = with_gemm_swizzled_scales


P1 Unreachable code after return

The line self.with_gemm_swizzled_scales = with_gemm_swizzled_scales is placed after return instance in __new__, so it will never execute. The intent seems to be handled already in _initialize_storage_fields (where instance.with_gemm_swizzled_scales = with_gemm_swizzled_scales is correctly set), so this line is both unreachable and redundant.

Suggested change
  return instance
- self.with_gemm_swizzled_scales = with_gemm_swizzled_scales
def has_data(self) -> bool:
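The unreachable-after-return behavior is easy to confirm with a minimal stand-in class (hypothetical names, not the PR's actual storage class):

```python
class Storage:
    def __new__(cls):
        instance = super().__new__(cls)
        return instance
        # Anything below the return never executes:
        instance.with_gemm_swizzled_scales = True

s = Storage()
print(hasattr(s, "with_gemm_swizzled_scales"))  # False
```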

Comment on lines +484 to +500
global global_alpha_tensor
alpha_tensor = self._mxfp8_alpha_tensor
norm_const_tensor = self._mxfp8_norm_const_tensor
if (
    alpha_tensor is None
    or alpha_tensor.numel() != num_groups
    or alpha_tensor.dtype != dtype
    or alpha_tensor.device != device
):
    if global_alpha_tensor is None:
        global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device)
    alpha_tensor = global_alpha_tensor
    norm_const_tensor = alpha_tensor[:1]
    self._mxfp8_alpha_tensor = alpha_tensor
    self._mxfp8_norm_const_tensor = norm_const_tensor
elif (
    norm_const_tensor is None

P1 global_alpha_tensor stale for multiple instances with different num_groups

In _get_kernel_constants, the module-level global_alpha_tensor is created only once (when it is None). If a second ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8 instance is created with a different num_groups, the stale global (wrong size) is assigned to self._mxfp8_alpha_tensor. On every subsequent call, the condition alpha_tensor.numel() != num_groups will remain True, and the wrong-sized tensor from the global will keep being assigned—without ever being recreated.

Concretely:

  1. Instance A (4 groups): creates global_alpha_tensor with 4 elements ✓
  2. Instance B (8 groups): global_alpha_tensor is None is False → skips creation → assigns the 4-element tensor as alpha_tensor → wrong size passed to the kernel.

The fix is to always recreate the global when the cached size/dtype/device doesn't match:

if (
    global_alpha_tensor is None
    or global_alpha_tensor.numel() != num_groups
    or global_alpha_tensor.dtype != dtype
    or global_alpha_tensor.device != device
):
    global_alpha_tensor = torch.ones(num_groups, dtype=dtype, device=device)
alpha_tensor = global_alpha_tensor

The same issue exists in backward_grouped_mlp.py's _get_kernel_constants.
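A minimal sketch of the suggested fix, with hypothetical names and assuming the cache is keyed only on size, dtype, and device:

```python
import torch

_global_alpha_tensor = None  # module-level cache (hypothetical stand-in)

def _get_alpha_tensor(num_groups, dtype, device):
    """Return a cached all-ones tensor, recreating it whenever the
    requested size, dtype, or device no longer matches the cache."""
    global _global_alpha_tensor
    cached = _global_alpha_tensor
    if (
        cached is None
        or cached.numel() != num_groups
        or cached.dtype != dtype
        or cached.device != device
    ):
        cached = torch.ones(num_groups, dtype=dtype, device=device)
        _global_alpha_tensor = cached
    return cached

# Instance A (4 groups) then instance B (8 groups) both get
# correctly sized tensors, unlike the create-once-on-None version.
a = _get_alpha_tensor(4, torch.float32, torch.device("cpu"))
b = _get_alpha_tensor(8, torch.float32, torch.device("cpu"))
```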

Comment on lines +541 to +542
assert hasattr(weight_param, "main_grad"), "MAIN GRAD NOT FOUND !!!!"
assert weight_param.main_grad is not None, "MAIN GRAD IS NONE !!!!"

P2 Debug assertion messages should use proper error messages

These assert statements with !!!! in the messages look like debugging leftovers. While the assertions themselves are reasonable sanity checks, the message style is not production-appropriate. Consider using RuntimeError or a cleaner assert message:

Suggested change
  if self._accumulate_into_main_grad:
-     assert hasattr(weight_param, "main_grad"), "MAIN GRAD NOT FOUND !!!!"
-     assert weight_param.main_grad is not None, "MAIN GRAD IS NONE !!!!"
+     if not hasattr(weight_param, "main_grad"):
+         raise RuntimeError(
+             "Expected 'main_grad' attribute on weight parameter, but it was not found."
+         )
+     if weight_param.main_grad is None:
+         raise RuntimeError("'main_grad' on weight parameter is None.")


norm_const_tensor=norm_const_tensor,
d_dtype=torch.float8_e4m3fn,
cd_major="n",
sf_vec_size=32,

P2 sf_vec_size hardcoded instead of using the shared constant

The backward kernel call uses sf_vec_size=32 hardcoded, while the forward kernel in forward_grouped_mlp.py correctly uses sf_vec_size=MXFP8_BLOCK_SCALING_SIZE (which equals 32). Both values happen to be the same today, but for consistency and maintainability, the backward should also import and use MXFP8_BLOCK_SCALING_SIZE.

Suggested change
- sf_vec_size=32,
+ sf_vec_size=MXFP8_BLOCK_SCALING_SIZE,

(add from ...constants import MXFP8_BLOCK_SCALING_SIZE to the imports)

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman requested a review from vthumbe1503 March 17, 2026 04:55
ksivaman and others added 4 commits March 16, 2026 21:59
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
@ksivaman ksivaman marked this pull request as ready for review March 17, 2026 05:38
Comment on lines +3617 to +3637
    dtype=torch.float32,
)
fc1_weight.main_grad.fill_(value)
fc2_weight.main_grad.fill_(value)

def _collect_main_grads() -> tuple[torch.Tensor, torch.Tensor]:
    if single_grouped_parameter:
        fc1_main_grad = fc1.weight.main_grad.detach().clone()
        fc2_main_grad = fc2.weight.main_grad.detach().clone()
    else:
        fc1_main_grad = torch.stack(
            [
                getattr(fc1, f"weight{group_idx}").main_grad.detach().clone()
                for group_idx in range(group_size)
            ],
            dim=0,
        )
        fc2_main_grad = torch.stack(
            [
                getattr(fc2, f"weight{group_idx}").main_grad.detach().clone()
                for group_idx in range(group_size)

P1 Hard assert on is_supported() causes test failure instead of skip

assert te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported() will raise AssertionError on any Blackwell system where the cudnn front-end package is installed for MXFP8 quantization but the specific grouped-GEMM-SwiGLU kernel is not available. In that scenario the test should be skipped, not failed. The maybe_skip_quantization guard only checks whether MXFP8 quantization in general is supported — it does not gate on the availability of the fused kernel.

Consider wrapping the fusion-presence block in a skip guard:

if (
    quantization == "mxfp8"
    and dtype in (torch.bfloat16, torch.float16)
    and not bias
    and glu_interleave_size == 32
):
    if not te_ops.fused.ForwardGroupedMLP_CuTeGEMMSwiGLU_MXFP8.is_supported():
        pytest.skip("MXFP8 fused grouped MLP kernel not available on this system")
    forward_ops = module._module_groups[0]._forward_ops
    ...

Comment on lines +722 to +742
            window = [op]
        else:
            # Shift window if window doesn't match pattern
            out.extend(window[:-2])
            window = window[-2:]

    # Adjust window to expected size
    out.extend(window[:-3])
    window = window[-3:]
    while ops and len(window) < 3:
        window.append(ops[0])
        ops = ops[1:]

    # Return list of ops
    out.extend(window)
    return out


# Register fusion if available
if BackwardGroupedMLP_CuTeGEMMDSwiGLU_MXFP8.is_supported():
    register_backward_fusion(fuse_backward_ops, prepend=True)

P2 norm_const_tensor not refreshed when only alpha_tensor is stale

The backward's _get_kernel_constants is missing the elif guard that the forward version has. If self._mxfp8_alpha_tensor is somehow valid but self._mxfp8_norm_const_tensor is None (e.g., after a partial object state restore or future refactoring), the function returns None for norm_const_tensor, which will crash the kernel call.

The forward version correctly handles this with:

elif (
    norm_const_tensor is None
    or norm_const_tensor.numel() != 1
    or norm_const_tensor.dtype != dtype
    or norm_const_tensor.device != device
):
    norm_const_tensor = alpha_tensor[:1]
    self._mxfp8_norm_const_tensor = norm_const_tensor

Add the same defensive elif branch here for consistency and safety.

Signed-off-by: Kirthi Shankar Sivamani <ksivamani@nvidia.com>
Comment on lines +267 to +280
fc1_x_data = grouped_fc1_x.rowwise_data.view(in_shape[0], in_shape[1])
fc1_x_data = fc1_x_data.view(dtype=torch.float8_e4m3fn)
fc1_x_data = fc1_x_data.unsqueeze(0).permute(1, 2, 0)
fc1_x_scales = grouped_fc1_x.scale_inv
fc1_x_scales = fc1_x_scales.view(dtype=torch.float8_e8m0fnu)
fc1_x_scales = fc1_x_scales.view(
    1,
    in_shape[0] // 128,
    in_shape[1] // 128,
    32,
    4,
    4,
)
fc1_x_scales = fc1_x_scales.permute(3, 4, 1, 5, 2, 0)

P1 No validation that total token count is divisible by 128

The scale tensor view at lines 272–279 uses integer division in_shape[0] // 128 to reshape the MXFP8 scale buffer. If in_shape[0] (i.e., sum(split_sizes)) is not divisible by 128, the view shape product will not match the actual buffer size and either produce incorrect behavior (wrong permute dimensions) or a runtime error with a confusing message.

The constructor checks that in_features % 256 == 0 and out_features % 256 == 0, but nothing validates that the token dimension sum(split_sizes) is divisible by 128 (required by the MXFP8 block-scaling layout). A user passing split sizes like [64, 65] would hit this silently.

The same assumption appears in the backward pass at backward_grouped_mlp.py lines 243–250.

Consider adding a guard before the view:

if in_shape[0] % 128 != 0:
    raise ValueError(
        f"Total token count must be divisible by 128 for MXFP8 fused kernel, "
        f"but got sum(split_sizes)={in_shape[0]}."
    )
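To see why the divisibility matters, here is a standalone sketch with illustrative shapes (buffer sizes chosen for the example, not taken from the kernel):

```python
import torch

# Both dimensions are multiples of 128, so the scale view is exact.
tokens, features = 256, 256
num_scales = (tokens // 128) * (features // 128) * 32 * 4 * 4  # 2048
scales = torch.zeros(num_scales, dtype=torch.uint8)
view = scales.view(1, tokens // 128, features // 128, 32, 4, 4)  # OK

# With tokens = 129 (e.g. split_sizes = [64, 65]), tokens // 128 truncates
# to 1, so the view's element count (1024) no longer matches the buffer;
# torch.Tensor.view then raises a generic shape RuntimeError rather than a
# clear "split sizes must sum to a multiple of 128" message.
```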
