From 28e5f5377c3059c40c73b7c30b0ea584c9ea6943 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Mon, 9 Mar 2026 15:42:48 -0700 Subject: [PATCH 01/23] Refactor to group_sizes per tensor Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 16 +- transformer_engine/jax/cpp_extensions/gemm.py | 202 +++++++++++------- .../jax/csrc/extensions/gemm.cpp | 194 +++++++++++------ transformer_engine/jax/dense.py | 28 ++- 4 files changed, 284 insertions(+), 156 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..02cc05649a 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1787,13 +1787,16 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm + empty_gs = jnp.empty((0,), jnp.int32) prim_out = jax.jit( tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") )( lhs, rhs, - group_sizes, - contracting_dims, + lhs_group_sizes=group_sizes, + rhs_group_sizes=empty_gs, + out_group_sizes=group_sizes, + contracting_dims=contracting_dims, use_async_d2h_group_sizes=True, ) @@ -1825,8 +1828,15 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + empty_gs = jnp.empty((0,), jnp.int32) prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set + lhs, + rhs, + lhs_group_sizes=group_sizes, + rhs_group_sizes=empty_gs, + out_group_sizes=group_sizes, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, ) allclose_dtype = jnp.float8_e4m3fn diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ab2be7f799..c298e19bf0 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1446,12 +1446,12 @@ class GroupedGemmPrimitive(BasePrimitive): Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). """ - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, group_offset, unused_placeholder + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, lhs_group_sizes, rhs_group_sizes, out_group_sizes, group_offset, unused_placeholder name = "te_grouped_gemm_ffi" - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, alpha, beta + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, lhs_group_sizes, rhs_group_sizes, out_group_sizes, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) + impl_static_args = (10, 11, 12, 13, 14, 15, 16) inner_primitive = None outer_primitive = None @@ -1462,17 +1462,15 @@ def abstract( rhs_data_aval, rhs_scale_inv_aval, bias_aval, - group_sizes_aval, + lhs_group_sizes_aval, + rhs_group_sizes_aval, + out_group_sizes_aval, *additional_args, # group_offset_aval, unused_placeholder OR alpha_aval, beta_aval - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, ): @@ -1480,35 +1478,57 @@ def abstract( Grouped GEMM operation. Args: - lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_data: Left-hand side input matrix data, 2D array [rows, cols] lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array - rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_data: Right-hand side input matrix data, 2D array [rows, cols] rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) - group_sizes: 1D array containing the sizes of each group + lhs_group_sizes: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel + rhs_group_sizes: (G,) int32 if rhs first-dim is ragged (wgrad), else empty (0,) sentinel + out_group_sizes: (G,) int32 if output first-dim is ragged, else empty (0,) sentinel additional_args: Either * group_offsets: 1D array containing offsets for each group (not yet implemented) OR * alpha: 1D array of shape (G,) containing alpha values for each group * beta: 1D array of shape (G,) containing beta values for each group - M: Number of rows in the output matrix - N: Number of columns in the output matrix - K: Number of columns in the left-hand side matrix lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed scaling_mode: Scaling mode for the GEMM operations out_dtype: Data type of the output tensors has_bias: Boolean indicating if bias tensors are provided - is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation - where both lhs and rhs are 2D matrices and output is (G, M, N) Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, rhs_data_aval, bias_aval - del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes + del bias_aval + del has_bias, use_async_d2h_group_sizes + + # Determine mode from which group_sizes buffer is non-empty + is_wgrad = rhs_group_sizes_aval.size > 0 + num_groups = ( + lhs_group_sizes_aval.size + or rhs_group_sizes_aval.size + or out_group_sizes_aval.size + or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 + ) - num_groups = group_sizes_aval.size + # lhs_data_aval and rhs_data_aval are now 2D; derive output shape from buffer dims + if is_wgrad: + # lhs shape [K_lhs, M] (lhs_is_trans=True) or [M, K_lhs] (lhs_is_trans=False) + # M is the non-contracting (output) dim + M = lhs_data_aval.shape[1] if lhs_is_trans else lhs_data_aval.shape[0] + N = rhs_data_aval.shape[1] + out_shape = (num_groups, M, N) + else: + # lhs shape [M_total, K] (lhs_is_trans=False) or [K, M_total] (lhs_is_trans=True) + # dim[0] is always total M for fwd/dgrad + M = lhs_data_aval.shape[0] + N = ( + rhs_data_aval.shape[1] + if not rhs_is_trans + else rhs_data_aval.shape[0] // num_groups + ) + out_shape = (M, N) cublas_workspace_aval = jax.core.ShapedArray( shape=( @@ -1519,9 +1539,6 @@ def abstract( dtype=jnp.uint8, ) - out_shape = (M, N) - if is_grouped_dense_wgrad: - out_shape = (num_groups, M, N) out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) if use_v2_ffi: @@ -1597,15 +1614,11 @@ def outer_abstract(*args, **kwargs): def lowering( ctx, *args, - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, ): @@ -1615,26 +1628,18 @@ def lowering( return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, ) ffi_name = GroupedGemmPrimitive.name return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, ) @@ -1645,18 +1650,16 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_group_sizes, + rhs_group_sizes, + out_group_sizes, additional_arg_0, # group_offset (non-graph-safe) OR alpha (graph-safe) additional_arg_1, # unused placeholder (non-graph-safe) OR beta (graph-safe) - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, ): @@ -1671,17 +1674,15 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_group_sizes, + rhs_group_sizes, + out_group_sizes, *additional_args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, ) @@ -2022,10 +2023,24 @@ def _can_use_v2_grouped_gemm( return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias +def _flatten_to_2d(data, flatten_axis): + """Reshape *data* to 2D by splitting at *flatten_axis*. + + Positive flatten_axis: split before that axis index. + Negative flatten_axis: split before (ndim + flatten_axis). + """ + if data.ndim == 2: + return data # Already 2D, no reshape needed + fa = flatten_axis if flatten_axis >= 0 else data.ndim + flatten_axis + return data.reshape(math.prod(data.shape[:fa]), math.prod(data.shape[fa:])) + + def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - group_sizes: jnp.ndarray, + lhs_group_sizes: jnp.ndarray = None, # (G,) int32 if lhs first-dim is ragged, else None/(0,) + rhs_group_sizes: jnp.ndarray = None, # (G,) int32 if rhs first-dim is ragged (wgrad), else None/(0,) + out_group_sizes: jnp.ndarray = None, # (G,) int32 if output first-dim is ragged, else None/(0,) contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -2040,7 +2055,9 @@ def grouped_gemm( Args: lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - group_sizes: 1D array containing the sizes of each group + lhs_group_sizes: (G,) int32 if lhs first-dim is ragged, else None or empty (0,) sentinel + rhs_group_sizes: (G,) int32 if rhs first-dim is ragged (wgrad mode), else None/(0,) + out_group_sizes: (G,) int32 if output first-dim is ragged, else None/(0,) contracting_dims: Tuple of two sequences representing the contracting dimensions bias: Bias tensor of shape (G, N) precision: JAX precision for the GEMM operation @@ -2060,6 +2077,15 @@ def grouped_gemm( # TODO(Phuong): implement the precision del precision + # Replace None sentinels with empty (0,) int32 arrays. + empty_gs = jnp.empty((0,), jnp.int32) + if lhs_group_sizes is None: + lhs_group_sizes = empty_gs + if rhs_group_sizes is None: + rhs_group_sizes = empty_gs + if out_group_sizes is None: + out_group_sizes = empty_gs + if isinstance(lhs, jnp.ndarray): assert isinstance(rhs, jnp.ndarray) out_dtype = lhs.dtype @@ -2074,8 +2100,14 @@ def grouped_gemm( out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape rhs_shape = rhs.original_shape - lhs_data = lhs.data - rhs_data = rhs.data + lhs_fa = lhs.flatten_axis + rhs_fa = rhs.flatten_axis + lhs_data = lhs.data.reshape( + math.prod(lhs_shape[:lhs_fa]), math.prod(lhs_shape[lhs_fa:]) + ) + rhs_data = rhs.data.reshape( + math.prod(rhs_shape[:rhs_fa]), math.prod(rhs_shape[rhs_fa:]) + ) lhs_scale_inv = lhs.scale_inv rhs_scale_inv = rhs.scale_inv assert lhs.scaling_mode == rhs.scaling_mode @@ -2094,14 +2126,9 @@ def grouped_gemm( rhs_is_trans = rhs_contract_dim[0] != 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) - is_grouped_dense_wgrad = False - if len(rhs_shape) == 2: - rhs_is_trans = rhs_contract_dim[0] != 0 - is_grouped_dense_wgrad = True - - # TODO(Hua): thses are for fp16 dense wgrad, any better way to handle this? + # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this? if ( - is_grouped_dense_wgrad + rhs_group_sizes.size > 0 # wgrad mode: rhs first-dim is ragged and not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) ): @@ -2110,6 +2137,15 @@ def grouped_gemm( lhs_flatten_axis = 1 rhs_flatten_axis = 1 + # For MXFP8 block-scaling wgrad with pre-quantized inputs: rhs is colwise quantized, + # so rhs_use_colwise = (is_mxfp8 && !rhs_is_trans) must be True → rhs_is_trans=False. + if ( + rhs_group_sizes.size > 0 # wgrad mode: rhs first-dim is ragged + and isinstance(lhs, GroupedScaledTensor1x) + and scaling_mode.is_1d_block_scaling() + ): + rhs_is_trans = False + if ( not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) @@ -2132,16 +2168,30 @@ def grouped_gemm( quantizer_set.kernel.q_layout = ( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) - lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis) + active_group_sizes = lhs_group_sizes if lhs_group_sizes.size > 0 else rhs_group_sizes + lhs_q = grouped_quantize(lhs, quantizer_set.x, active_group_sizes, lhs_flatten_axis) rhs_q = grouped_quantize( rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) - lhs_data = lhs_q.data - rhs_data = rhs_q.data + # grouped_quantize returns a 1D flat buffer; reshape to 2D using the + # original_shape and flatten_axis stored in each quantized tensor. + lhs_fa = lhs_q.flatten_axis # positive index (adjusted in create_1x) + rhs_fa = rhs_q.flatten_axis + lhs_data = lhs_q.data.reshape( + math.prod(lhs_q.original_shape[:lhs_fa]), + math.prod(lhs_q.original_shape[lhs_fa:]), + ) + rhs_data = rhs_q.data.reshape( + math.prod(rhs_q.original_shape[:rhs_fa]), + math.prod(rhs_q.original_shape[rhs_fa:]), + ) lhs_scale_inv = lhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv lhs_shape = lhs_q.original_shape rhs_shape = rhs_q.original_shape + # Data is already 2D; reset flatten axes so _flatten_to_2d calls below are no-ops. + lhs_flatten_axis = -1 + rhs_flatten_axis = -1 assert not ( lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2 @@ -2172,31 +2222,26 @@ def grouped_gemm( lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim - if group_sizes.size == rhs_shape[0]: + if lhs_group_sizes.size > 0: # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim ) else: rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - # Calling GroupedGEMM Custom Call - K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) - K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - assert K_lhs == K_rhs - M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G + # Reshape inputs to 2D using the already-computed flatten_axes. + lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) + rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) - if is_grouped_dense_wgrad: - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) - else: - assert group_sizes.size == rhs_shape[0] + num_gemms = lhs_group_sizes.size or rhs_group_sizes.size or out_group_sizes.size has_bias = bias is not None if has_bias: + N_dim = rhs_data_2d.shape[0] // num_gemms if rhs_is_trans else rhs_data_2d.shape[1] assert bias.shape == ( - group_sizes.size, - N, - ), f"bias shape {bias.shape} does not match expected shape {(group_sizes.size, N)}" + num_gemms, + N_dim, + ), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}" bias = jnp.empty((), jnp.float32) if bias is None else bias assert group_offset is None, ( @@ -2207,7 +2252,6 @@ def grouped_gemm( use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) if use_v2_ffi: - num_gemms = group_sizes.shape[0] additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta else: @@ -2215,23 +2259,21 @@ def grouped_gemm( additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data, + lhs_data_2d, lhs_scale_inv, - rhs_data, + rhs_data_2d, rhs_scale_inv, bias, - group_sizes, + lhs_group_sizes, + rhs_group_sizes, + out_group_sizes, additional_arg_0, additional_arg_1, - M=M, - N=N, - K=K_lhs, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 4cbec405a4..834a7b9a5f 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -562,11 +562,12 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, // This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, + Buffer_Type lhs_group_sizes, Buffer_Type rhs_group_sizes, + Buffer_Type out_group_sizes, Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type cublas_workspace, - Result_Type setup_workspace, Result_Type int64_workspace, size_t m, - size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, - JAXX_Scaling_Mode scaling_mode, bool is_grouped_dense_wgrad) { + Result_Type setup_workspace, Result_Type int64_workspace, + bool lhs_is_trans, bool rhs_is_trans, + JAXX_Scaling_Mode scaling_mode) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -581,6 +582,40 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. + // out_group_sizes is the sentinel for the output tensor's ragged dimension; unused directly here + // as the output shape is inferred from lhs/rhs dims and passed to nvte_grouped_gemm implicitly. + (void)out_group_sizes; + + // Determine which group_sizes buffer is active (non-empty sentinel = ragged dimension). + bool is_lhs_ragged = lhs_group_sizes.element_count() > 0; + bool is_rhs_ragged = rhs_group_sizes.element_count() > 0; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; + + size_t num_gemms; + if (is_lhs_ragged) + num_gemms = lhs_group_sizes.dimensions()[0]; + else if (is_rhs_ragged) + num_gemms = rhs_group_sizes.dimensions()[0]; + else if (out_group_sizes.element_count() > 0) + num_gemms = out_group_sizes.dimensions()[0]; + else + num_gemms = alpha.element_count(); // batched: no ragged tensor + const Buffer_Type &active_group_sizes = is_lhs_ragged ? lhs_group_sizes : rhs_group_sizes; + + // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. + NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); + NVTE_CHECK(rhs_data.dimensions().size() == 2, "rhs_data must be 2D."); + size_t k = lhs_is_trans ? lhs_data.dimensions()[0] : lhs_data.dimensions()[1]; + size_t m, n; + if (is_rhs_ragged) { + // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M + m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; + n = rhs_data.dimensions()[1]; + } else { + m = lhs_data.dimensions()[0]; // total M (sum of group sizes) + n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; + } + // Inputs auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); @@ -594,14 +629,15 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; - - // Convert int32 group_sizes to int64 into the dedicated output buffer. - NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + // Convert int32 group_sizes to int64 into the dedicated output buffer (ragged tensors only). auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); - nvte_convert_int32_to_int64(reinterpret_cast(group_sizes.untyped_data()), - int64_sizes_ptr, num_gemms, stream); + if (any_ragged) { + NVTE_CHECK(active_group_sizes.element_type() == xla::ffi::DataType::S32, + "group_sizes must be int32."); + nvte_convert_int32_to_int64( + reinterpret_cast(active_group_sizes.untyped_data()), int64_sizes_ptr, + num_gemms, stream); + } NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); @@ -656,14 +692,14 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t expected_rhs_size = is_rhs_ragged ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_rhs_ragged ? (num_gemms * m * n) : (m * n); size_t actual_lhs_size = product(lhs_data.dimensions()); size_t actual_rhs_size = product(rhs_data.dimensions()); size_t actual_out_size = product(output->dimensions()); NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { + if (!is_rhs_ragged) { NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, " = ", expected_rhs_size, ", got ", actual_rhs_size); @@ -703,7 +739,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - if (is_grouped_dense_wgrad) { + if (is_rhs_ragged) { NVTE_CHECK(lhs_is_trans && !rhs_is_trans, "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); @@ -732,7 +768,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty return ffi_with_cuda_error_check(); } - // Nominal case for FWD or DGRAD + // Nominal case for FWD, DGRAD, or batched GEMM //// RHS NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; @@ -748,14 +784,18 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::swap(lhsShape.data[0], lhsShape.data[1]); } auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, - lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); + if (any_ragged) { + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, + lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); + } //// OUTPUT NVTEShape outShape{.data = {m, n}, .ndim = 2}; auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (any_ragged) { + out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + } nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), @@ -769,33 +809,32 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data + .Arg() // lhs_data (2D) .Arg() // lhs_sinv - .Arg() // rhs_data + .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes (int32) + .Arg() // lhs_group_sizes (G,) or empty (0,) + .Arg() // rhs_group_sizes (G,) or empty (0,) + .Arg() // out_group_sizes (G,) or empty (0,) .Arg() // alpha .Arg() // beta .Ret() // output .Ret() // cublas_workspace .Ret() // setup_workspace .Ret() // int64_workspace - .Attr("M") - .Attr("N") - .Attr("K") .Attr("lhs_is_trans") .Attr("rhs_is_trans") - .Attr("scaling_mode") - .Attr("is_grouped_dense_wgrad"), + .Attr("scaling_mode"), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, - Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, + Buffer_Type lhs_group_sizes, Buffer_Type rhs_group_sizes, + Buffer_Type out_group_sizes, Buffer_Type group_offset, + Result_Type output, Result_Type workspace, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { + bool use_async_d2h_group_sizes) { // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -812,6 +851,37 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type int num_streams = nvte_get_num_compute_streams(); + // out_group_sizes is the sentinel for the output tensor's ragged dimension; unused directly here. + (void)out_group_sizes; + + // Determine which group_sizes buffer is active (non-empty sentinel = ragged dimension). + bool is_lhs_ragged = lhs_group_sizes.element_count() > 0; + bool is_rhs_ragged = rhs_group_sizes.element_count() > 0; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; + + size_t num_gemms; + if (is_lhs_ragged) + num_gemms = lhs_group_sizes.dimensions()[0]; + else if (is_rhs_ragged) + num_gemms = rhs_group_sizes.dimensions()[0]; + else + num_gemms = 1; // degenerate batched; legacy batched not a tested use case + const Buffer_Type &active_group_sizes = is_lhs_ragged ? lhs_group_sizes : rhs_group_sizes; + + // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. + NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); + NVTE_CHECK(rhs_data.dimensions().size() == 2, "rhs_data must be 2D."); + size_t k = lhs_is_trans ? lhs_data.dimensions()[0] : lhs_data.dimensions()[1]; + size_t m, n; + if (is_rhs_ragged) { + // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M + m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; + n = rhs_data.dimensions()[1]; + } else { + m = lhs_data.dimensions()[0]; // total M (sum of group sizes) + n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; + } + // Inputs auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); @@ -824,9 +894,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; - // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || @@ -893,14 +960,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t expected_rhs_size = is_rhs_ragged ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_rhs_ragged ? (num_gemms * m * n) : (m * n); size_t actual_lhs_size = product(lhs_data.dimensions()); size_t actual_rhs_size = product(rhs_data.dimensions()); size_t actual_out_size = product(output->dimensions()); NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { + if (!is_rhs_ragged) { NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, " = ", expected_rhs_size, ", got ", actual_rhs_size); @@ -916,25 +983,28 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t dim_list_bytes = sizeof(int32_t) * num_gemms; std::vector dim_list_host(num_gemms); - size_t host_num_gemms = 0; - if (use_async_d2h_group_sizes) { - host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); - NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, - " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); - } else { - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); + if (any_ragged) { + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + auto active_gs_ptr = + reinterpret_cast(active_group_sizes.untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), active_gs_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } + size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + if (!is_rhs_ragged) { + NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + ", got sum(group_sizes)=", sum_group_sizes); + } else { + NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + ", got sum(group_sizes)=", sum_group_sizes); + } } auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -982,7 +1052,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto lhs_shape_i = std::vector{m_i, k}; auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; auto out_shape_i = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { + if (is_rhs_ragged) { size_t k_i = dim_list_host[i]; lhs_shape_i[0] = lhs_is_trans ? k_i : m; lhs_shape_i[1] = lhs_is_trans ? m : k_i; @@ -1172,23 +1242,21 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data + .Arg() // lhs_data (2D) .Arg() // lhs_sinv - .Arg() // rhs_data + .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes + .Arg() // lhs_group_sizes (G,) or empty (0,) + .Arg() // rhs_group_sizes (G,) or empty (0,) + .Arg() // out_group_sizes (G,) or empty (0,) .Arg() // group_offset .Ret() // output .Ret() // workspace - .Attr("M") - .Attr("N") - .Attr("K") .Attr("lhs_is_trans") .Attr("rhs_is_trans") .Attr("scaling_mode") .Attr("has_bias") - .Attr("is_grouped_dense_wgrad") .Attr("use_async_d2h_group_sizes")); } // namespace jax diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 268995281c..e2a79fe9c7 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -523,15 +523,18 @@ def _grouped_dense_fwd_rule( # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout + empty_gs = jnp.empty((0,), jnp.int32) output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, - group_sizes, - contracting_dims, - bias, - precision, - preferred_element_type, - group_offset, + lhs_group_sizes=group_sizes, + rhs_group_sizes=empty_gs, + out_group_sizes=group_sizes, + contracting_dims=contracting_dims, + bias=bias, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, ) ctx = ( @@ -615,11 +618,14 @@ def _grouped_dense_bwd_rule( wgrad_x_T = ctx_x wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) + empty_gs = jnp.empty((0,), jnp.int32) dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, - group_sizes, - dgrad_contracting_dims, + lhs_group_sizes=group_sizes, + rhs_group_sizes=empty_gs, + out_group_sizes=group_sizes, + contracting_dims=dgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset, @@ -628,8 +634,10 @@ def _grouped_dense_bwd_rule( wgrad = tex.grouped_gemm( wgrad_x_T, wgrad_grad, - group_sizes, - wgrad_contracting_dims, + lhs_group_sizes=empty_gs, + rhs_group_sizes=group_sizes, + out_group_sizes=empty_gs, + contracting_dims=wgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset, From 4a57485316db5c5f1c6bdf5c991c11e4d374259e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 09:58:07 -0700 Subject: [PATCH 02/23] Support first_dims and last_dims instead of a single group_sizes per tensor Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 18 ++- transformer_engine/jax/cpp_extensions/gemm.py | 99 ++++++++----- .../jax/csrc/extensions/gemm.cpp | 138 +++++++++++------- transformer_engine/jax/dense.py | 27 ++-- 4 files changed, 177 insertions(+), 105 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 02cc05649a..2f2d5383b2 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1793,9 +1793,12 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): )( lhs, rhs, - lhs_group_sizes=group_sizes, - rhs_group_sizes=empty_gs, - out_group_sizes=group_sizes, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, contracting_dims=contracting_dims, use_async_d2h_group_sizes=True, ) @@ -1832,9 +1835,12 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( lhs, rhs, - lhs_group_sizes=group_sizes, - rhs_group_sizes=empty_gs, - out_group_sizes=group_sizes, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, contracting_dims=contracting_dims, quantizer_set=quantizer_set, ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index c298e19bf0..32500c9676 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1446,12 +1446,13 @@ class GroupedGemmPrimitive(BasePrimitive): Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). """ - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, lhs_group_sizes, rhs_group_sizes, out_group_sizes, group_offset, unused_placeholder name = "te_grouped_gemm_ffi" - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, lhs_group_sizes, rhs_group_sizes, out_group_sizes, alpha, beta + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, + # lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, + # out_first_dims, out_last_dims, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (10, 11, 12, 13, 14, 15, 16) + impl_static_args = (13, 14, 15, 16, 17, 18, 19) inner_primitive = None outer_primitive = None @@ -1462,9 +1463,12 @@ def abstract( rhs_data_aval, rhs_scale_inv_aval, bias_aval, - lhs_group_sizes_aval, - rhs_group_sizes_aval, - out_group_sizes_aval, + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, *additional_args, # group_offset_aval, unused_placeholder OR alpha_aval, beta_aval lhs_is_trans, rhs_is_trans, @@ -1504,11 +1508,11 @@ def abstract( del has_bias, use_async_d2h_group_sizes # Determine mode from which group_sizes buffer is non-empty - is_wgrad = rhs_group_sizes_aval.size > 0 + is_wgrad = rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0 num_groups = ( - lhs_group_sizes_aval.size - or rhs_group_sizes_aval.size - or out_group_sizes_aval.size + lhs_first_dims_aval.size or lhs_last_dims_aval.size + or rhs_first_dims_aval.size or rhs_last_dims_aval.size + or out_first_dims_aval.size or out_last_dims_aval.size or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 ) @@ -1650,9 +1654,12 @@ def impl( rhs_data, rhs_scale_inv, bias, - lhs_group_sizes, - rhs_group_sizes, - out_group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, # group_offset (non-graph-safe) OR alpha (graph-safe) additional_arg_1, # unused placeholder (non-graph-safe) OR beta (graph-safe) lhs_is_trans, @@ -1674,9 +1681,12 @@ def impl( rhs_data, rhs_scale_inv, bias, - lhs_group_sizes, - rhs_group_sizes, - out_group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, *additional_args, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, @@ -2038,9 +2048,12 @@ def _flatten_to_2d(data, flatten_axis): def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - lhs_group_sizes: jnp.ndarray = None, # (G,) int32 if lhs first-dim is ragged, else None/(0,) - rhs_group_sizes: jnp.ndarray = None, # (G,) int32 if rhs first-dim is ragged (wgrad), else None/(0,) - out_group_sizes: jnp.ndarray = None, # (G,) int32 if output first-dim is ragged, else None/(0,) + lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,) + lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,) + rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,) + rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,) + out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,) + out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,) contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -2055,9 +2068,12 @@ def grouped_gemm( Args: lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - lhs_group_sizes: (G,) int32 if lhs first-dim is ragged, else None or empty (0,) sentinel - rhs_group_sizes: (G,) int32 if rhs first-dim is ragged (wgrad mode), else None/(0,) - out_group_sizes: (G,) int32 if output first-dim is ragged, else None/(0,) + lhs_first_dims: (G,) int32 if LHS squashed first dim varies per group, else None/(0,) + lhs_last_dims: (G,) int32 if LHS squashed last dim varies per group, else None/(0,) + rhs_first_dims: (G,) int32 if RHS squashed first dim varies per group (wgrad), else None/(0,) + rhs_last_dims: (G,) int32 if RHS squashed last dim varies per group, else None/(0,) + out_first_dims: (G,) int32 if output first dim varies per group, else None/(0,) + out_last_dims: (G,) int32 if output last dim varies per group, else None/(0,) contracting_dims: Tuple of two sequences representing the contracting dimensions bias: Bias tensor of shape (G, N) precision: JAX precision for the GEMM operation @@ -2079,12 +2095,12 @@ def grouped_gemm( # Replace None sentinels with empty (0,) int32 arrays. empty_gs = jnp.empty((0,), jnp.int32) - if lhs_group_sizes is None: - lhs_group_sizes = empty_gs - if rhs_group_sizes is None: - rhs_group_sizes = empty_gs - if out_group_sizes is None: - out_group_sizes = empty_gs + lhs_first_dims = empty_gs if lhs_first_dims is None else lhs_first_dims + lhs_last_dims = empty_gs if lhs_last_dims is None else lhs_last_dims + rhs_first_dims = empty_gs if rhs_first_dims is None else rhs_first_dims + rhs_last_dims = empty_gs if rhs_last_dims is None else rhs_last_dims + out_first_dims = empty_gs if out_first_dims is None else out_first_dims + out_last_dims = empty_gs if out_last_dims is None else out_last_dims if isinstance(lhs, jnp.ndarray): assert isinstance(rhs, jnp.ndarray) @@ -2128,7 +2144,7 @@ def grouped_gemm( # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this? if ( - rhs_group_sizes.size > 0 # wgrad mode: rhs first-dim is ragged + (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged and not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) ): @@ -2140,7 +2156,7 @@ def grouped_gemm( # For MXFP8 block-scaling wgrad with pre-quantized inputs: rhs is colwise quantized, # so rhs_use_colwise = (is_mxfp8 && !rhs_is_trans) must be True → rhs_is_trans=False. if ( - rhs_group_sizes.size > 0 # wgrad mode: rhs first-dim is ragged + (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged and isinstance(lhs, GroupedScaledTensor1x) and scaling_mode.is_1d_block_scaling() ): @@ -2168,7 +2184,11 @@ def grouped_gemm( quantizer_set.kernel.q_layout = ( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) - active_group_sizes = lhs_group_sizes if lhs_group_sizes.size > 0 else rhs_group_sizes + active_group_sizes = next( + (gs for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] + if gs.size > 0), + empty_gs, + ) lhs_q = grouped_quantize(lhs, quantizer_set.x, active_group_sizes, lhs_flatten_axis) rhs_q = grouped_quantize( rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis @@ -2222,7 +2242,7 @@ def grouped_gemm( lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim - if lhs_group_sizes.size > 0: # fwd/dgrad: rhs has G as first dim + if lhs_first_dims.size > 0 or lhs_last_dims.size > 0: # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim ) @@ -2233,7 +2253,11 @@ def grouped_gemm( lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) - num_gemms = lhs_group_sizes.size or rhs_group_sizes.size or out_group_sizes.size + num_gemms = ( + lhs_first_dims.size or lhs_last_dims.size + or rhs_first_dims.size or rhs_last_dims.size + or out_first_dims.size or out_last_dims.size + ) has_bias = bias is not None if has_bias: @@ -2264,9 +2288,12 @@ def grouped_gemm( rhs_data_2d, rhs_scale_inv, bias, - lhs_group_sizes, - rhs_group_sizes, - out_group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, additional_arg_1, lhs_is_trans=lhs_is_trans, diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 834a7b9a5f..4387354f2a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -562,8 +562,10 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, // This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type lhs_group_sizes, Buffer_Type rhs_group_sizes, - Buffer_Type out_group_sizes, Buffer_Type alpha, Buffer_Type beta, + Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type cublas_workspace, Result_Type setup_workspace, Result_Type int64_workspace, bool lhs_is_trans, bool rhs_is_trans, @@ -582,25 +584,31 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty // C: column-major with size [m, n] --> row-major with size [n, m]. // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - // out_group_sizes is the sentinel for the output tensor's ragged dimension; unused directly here - // as the output shape is inferred from lhs/rhs dims and passed to nvte_grouped_gemm implicitly. - (void)out_group_sizes; - - // Determine which group_sizes buffer is active (non-empty sentinel = ragged dimension). - bool is_lhs_ragged = lhs_group_sizes.element_count() > 0; - bool is_rhs_ragged = rhs_group_sizes.element_count() > 0; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; + // Determine which group_sizes buffers are active (non-empty = ragged dimension). + bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_out_first_ragged = out_first_dims.element_count() > 0; + bool is_out_last_ragged = out_last_dims.element_count() > 0; + bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; + bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; size_t num_gemms; - if (is_lhs_ragged) - num_gemms = lhs_group_sizes.dimensions()[0]; - else if (is_rhs_ragged) - num_gemms = rhs_group_sizes.dimensions()[0]; - else if (out_group_sizes.element_count() > 0) - num_gemms = out_group_sizes.dimensions()[0]; - else - num_gemms = alpha.element_count(); // batched: no ragged tensor - const Buffer_Type &active_group_sizes = is_lhs_ragged ? lhs_group_sizes : rhs_group_sizes; + if (is_lhs_first_ragged) num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; + else if (is_out_first_ragged) num_gemms = out_first_dims.dimensions()[0]; + else if (is_out_last_ragged) num_gemms = out_last_dims.dimensions()[0]; + else num_gemms = alpha.element_count(); // batched: no ragged tensor + + const Buffer_Type *active_gs_ptr = nullptr; + if (is_lhs_first_ragged) active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); @@ -632,10 +640,11 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty // Convert int32 group_sizes to int64 into the dedicated output buffer (ragged tensors only). auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); if (any_ragged) { - NVTE_CHECK(active_group_sizes.element_type() == xla::ffi::DataType::S32, + NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); + NVTE_CHECK(active_gs_ptr->element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); nvte_convert_int32_to_int64( - reinterpret_cast(active_group_sizes.untyped_data()), int64_sizes_ptr, + reinterpret_cast(active_gs_ptr->untyped_data()), int64_sizes_ptr, num_gemms, stream); } @@ -746,13 +755,19 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty //// RHS NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_rhs_first_ragged) + rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_rhs_last_ragged) + rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); //// LHS NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; lhs_is_trans = true; auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_lhs_first_ragged) + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_lhs_last_ragged) + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); //// OUTPUT NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; @@ -784,18 +799,19 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::swap(lhsShape.data[0], lhsShape.data[1]); } auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - if (any_ragged) { - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, - lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); - } + if (is_lhs_first_ragged) + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + if (is_lhs_last_ragged) + lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); //// OUTPUT NVTEShape outShape{.data = {m, n}, .ndim = 2}; auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, num_gemms, outShape); - if (any_ragged) { + if (is_out_first_ragged) out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - } + if (is_out_last_ragged) + out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), @@ -814,9 +830,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // lhs_group_sizes (G,) or empty (0,) - .Arg() // rhs_group_sizes (G,) or empty (0,) - .Arg() // out_group_sizes (G,) or empty (0,) + .Arg() // lhs_first_dims (G,) or empty (0,) + .Arg() // lhs_last_dims (G,) or empty (0,) + .Arg() // rhs_first_dims (G,) or empty (0,) + .Arg() // rhs_last_dims (G,) or empty (0,) + .Arg() // out_first_dims (G,) or empty (0,) + .Arg() // out_last_dims (G,) or empty (0,) .Arg() // alpha .Arg() // beta .Ret() // output @@ -830,8 +849,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type lhs_group_sizes, Buffer_Type rhs_group_sizes, - Buffer_Type out_group_sizes, Buffer_Type group_offset, + Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type group_offset, Result_Type output, Result_Type workspace, bool lhs_is_trans, bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, bool use_async_d2h_group_sizes) { @@ -851,22 +872,27 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type int num_streams = nvte_get_num_compute_streams(); - // out_group_sizes is the sentinel for the output tensor's ragged dimension; unused directly here. - (void)out_group_sizes; - - // Determine which group_sizes buffer is active (non-empty sentinel = ragged dimension). - bool is_lhs_ragged = lhs_group_sizes.element_count() > 0; - bool is_rhs_ragged = rhs_group_sizes.element_count() > 0; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; + // Determine which group_sizes buffers are active (non-empty = ragged dimension). + bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; + bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; size_t num_gemms; - if (is_lhs_ragged) - num_gemms = lhs_group_sizes.dimensions()[0]; - else if (is_rhs_ragged) - num_gemms = rhs_group_sizes.dimensions()[0]; - else - num_gemms = 1; // degenerate batched; legacy batched not a tested use case - const Buffer_Type &active_group_sizes = is_lhs_ragged ? lhs_group_sizes : rhs_group_sizes; + if (is_lhs_first_ragged) num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; + else num_gemms = 1; // degenerate batched; legacy batched not a tested use case + + const Buffer_Type *active_gs_ptr = nullptr; + if (is_lhs_first_ragged) active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); @@ -990,9 +1016,10 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); } else { - auto active_gs_ptr = - reinterpret_cast(active_group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), active_gs_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); + auto gs_data_ptr = + reinterpret_cast(active_gs_ptr->untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, stream); // Note: This may break cudaGraph. cudaStreamSynchronize(stream); @@ -1247,9 +1274,12 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // lhs_group_sizes (G,) or empty (0,) - .Arg() // rhs_group_sizes (G,) or empty (0,) - .Arg() // out_group_sizes (G,) or empty (0,) + .Arg() // lhs_first_dims (G,) or empty (0,) + .Arg() // lhs_last_dims (G,) or empty (0,) + .Arg() // rhs_first_dims (G,) or empty (0,) + .Arg() // rhs_last_dims (G,) or empty (0,) + .Arg() // out_first_dims (G,) or empty (0,) + .Arg() // out_last_dims (G,) or empty (0,) .Arg() // group_offset .Ret() // output .Ret() // workspace diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index e2a79fe9c7..ed4d0aa082 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -527,9 +527,12 @@ def _grouped_dense_fwd_rule( output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, - lhs_group_sizes=group_sizes, - rhs_group_sizes=empty_gs, - out_group_sizes=group_sizes, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, contracting_dims=contracting_dims, bias=bias, precision=precision, @@ -622,9 +625,12 @@ def _grouped_dense_bwd_rule( dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, - lhs_group_sizes=group_sizes, - rhs_group_sizes=empty_gs, - out_group_sizes=group_sizes, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=empty_gs, + rhs_last_dims=empty_gs, + out_first_dims=group_sizes, + out_last_dims=empty_gs, contracting_dims=dgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, @@ -634,9 +640,12 @@ def _grouped_dense_bwd_rule( wgrad = tex.grouped_gemm( wgrad_x_T, wgrad_grad, - lhs_group_sizes=empty_gs, - rhs_group_sizes=group_sizes, - out_group_sizes=empty_gs, + lhs_first_dims=group_sizes, + lhs_last_dims=empty_gs, + rhs_first_dims=group_sizes, + rhs_last_dims=empty_gs, + out_first_dims=empty_gs, + out_last_dims=empty_gs, contracting_dims=wgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, From 345d940869181f9c2e57820cffcc69b277d78903 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 10:22:31 -0700 Subject: [PATCH 03/23] Refactor GMM FFIs to store static attrs as structs Signed-off-by: Jeremy Berchtold --- .../jax/csrc/extensions/gemm.cpp | 131 ++++++++++++++++-- 1 file changed, 118 insertions(+), 13 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 4387354f2a..24026b4ad9 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -559,6 +559,117 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } +// Config structs for grouped GEMM FFI static attributes. +// Consolidating all static attributes into a single dict attribute makes it easy to add new +// attributes in the future with backwards-compatible defaults: if old HLO was generated without a +// newer attribute, DecodeAttrOrDefault leaves the field at its struct default value. +struct GroupedGemmV2Config { + bool lhs_is_trans = false; + bool rhs_is_trans = false; + JAXX_Scaling_Mode scaling_mode = JAXX_Scaling_Mode::NO_SCALING; +}; + +struct GroupedGemmConfig { + bool lhs_is_trans = false; + bool rhs_is_trans = false; + JAXX_Scaling_Mode scaling_mode = JAXX_Scaling_Mode::NO_SCALING; + bool has_bias = false; + bool use_async_d2h_group_sizes = false; +}; + +} // namespace jax +} // namespace transformer_engine + +// Register AttrsBinding and AttrDecoding for grouped GEMM config structs. +// Uses a custom AttrDecoding (instead of XLA_FFI_REGISTER_STRUCT_ATTR_DECODING) that supports +// optional struct fields with default values, so old HLO without newer attributes still decodes. +namespace xla::ffi { + +namespace { + +// Finds an attribute by name. Returns its index or std::nullopt if absent. +std::optional FindAttrByName(const XLA_FFI_Attrs* attrs, std::string_view name) { + for (int64_t i = 0; i < attrs->size; ++i) { + if (std::string_view{attrs->names[i]->ptr, attrs->names[i]->len} == name) return i; + } + return std::nullopt; +} + +// Decodes a named attribute into `field` if present; leaves `field` at its default if absent. +// Returns false only when the attribute is present but fails to decode. +template +bool DecodeAttrOrDefault(const XLA_FFI_Attrs* attrs, std::string_view name, T& field, + DiagnosticEngine& diagnostic) { + auto idx = FindAttrByName(attrs, name); + if (!idx.has_value()) return true; // absent → keep default + auto decoded = AttrDecoding::Decode(attrs->types[*idx], attrs->attrs[*idx], diagnostic); + if (!decoded.has_value()) return false; + field = *decoded; + return true; +} + +} // namespace + +template <> +struct AttrsBinding { + using Attrs = transformer_engine::jax::GroupedGemmV2Config; +}; + +template <> +struct AttrDecoding { + using Type = transformer_engine::jax::GroupedGemmV2Config; + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { + return diagnostic.Emit("Expected dictionary attribute for GroupedGemmV2Config"); + } + auto* attrs = reinterpret_cast(attr); + Type config; + if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "rhs_is_trans", config.rhs_is_trans, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) + return std::nullopt; + return config; + } +}; + +template <> +struct AttrsBinding { + using Attrs = transformer_engine::jax::GroupedGemmConfig; +}; + +template <> +struct AttrDecoding { + using Type = transformer_engine::jax::GroupedGemmConfig; + static std::optional Decode(XLA_FFI_AttrType type, void* attr, + DiagnosticEngine& diagnostic) { + if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { + return diagnostic.Emit("Expected dictionary attribute for GroupedGemmConfig"); + } + auto* attrs = reinterpret_cast(attr); + Type config; + if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "rhs_is_trans", config.rhs_is_trans, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "has_bias", config.has_bias, diagnostic)) + return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "use_async_d2h_group_sizes", + config.use_async_d2h_group_sizes, diagnostic)) + return std::nullopt; + return config; + } +}; + +} // namespace xla::ffi + +namespace transformer_engine { +namespace jax { + // This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, @@ -568,8 +679,8 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type cublas_workspace, Result_Type setup_workspace, Result_Type int64_workspace, - bool lhs_is_trans, bool rhs_is_trans, - JAXX_Scaling_Mode scaling_mode) { + GroupedGemmV2Config config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -842,9 +953,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, .Ret() // cublas_workspace .Ret() // setup_workspace .Ret() // int64_workspace - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") - .Attr("scaling_mode"), + .Attrs(), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, @@ -853,9 +962,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, Buffer_Type out_first_dims, Buffer_Type out_last_dims, Buffer_Type group_offset, - Result_Type output, Result_Type workspace, bool lhs_is_trans, - bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool use_async_d2h_group_sizes) { + Result_Type output, Result_Type workspace, + GroupedGemmConfig config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -1283,11 +1392,7 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, .Arg() // group_offset .Ret() // output .Ret() // workspace - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") - .Attr("scaling_mode") - .Attr("has_bias") - .Attr("use_async_d2h_group_sizes")); + .Attrs()); } // namespace jax } // namespace transformer_engine From ed9c8e4e275b563a1a35a5eeafc5cf46052c58c7 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 10 Mar 2026 17:25:19 +0000 Subject: [PATCH 04/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 59 ++++----- .../jax/csrc/extensions/gemm.cpp | 117 ++++++++++-------- 2 files changed, 97 insertions(+), 79 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 32500c9676..d79be04983 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1510,9 +1510,12 @@ def abstract( # Determine mode from which group_sizes buffer is non-empty is_wgrad = rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0 num_groups = ( - lhs_first_dims_aval.size or lhs_last_dims_aval.size - or rhs_first_dims_aval.size or rhs_last_dims_aval.size - or out_first_dims_aval.size or out_last_dims_aval.size + lhs_first_dims_aval.size + or lhs_last_dims_aval.size + or rhs_first_dims_aval.size + or rhs_last_dims_aval.size + or out_first_dims_aval.size + or out_last_dims_aval.size or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 ) @@ -1527,11 +1530,7 @@ def abstract( # lhs shape [M_total, K] (lhs_is_trans=False) or [K, M_total] (lhs_is_trans=True) # dim[0] is always total M for fwd/dgrad M = lhs_data_aval.shape[0] - N = ( - rhs_data_aval.shape[1] - if not rhs_is_trans - else rhs_data_aval.shape[0] // num_groups - ) + N = rhs_data_aval.shape[1] if not rhs_is_trans else rhs_data_aval.shape[0] // num_groups out_shape = (M, N) cublas_workspace_aval = jax.core.ShapedArray( @@ -2048,12 +2047,12 @@ def _flatten_to_2d(data, flatten_axis): def grouped_gemm( lhs: Union[jnp.ndarray, GroupedScaledTensor1x], rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,) - lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,) - rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,) - rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,) - out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,) - out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,) + lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,) + lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,) + rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,) + rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,) + out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,) + out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,) contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -2118,12 +2117,8 @@ def grouped_gemm( rhs_shape = rhs.original_shape lhs_fa = lhs.flatten_axis rhs_fa = rhs.flatten_axis - lhs_data = lhs.data.reshape( - math.prod(lhs_shape[:lhs_fa]), math.prod(lhs_shape[lhs_fa:]) - ) - rhs_data = rhs.data.reshape( - math.prod(rhs_shape[:rhs_fa]), math.prod(rhs_shape[rhs_fa:]) - ) + lhs_data = lhs.data.reshape(math.prod(lhs_shape[:lhs_fa]), math.prod(lhs_shape[lhs_fa:])) + rhs_data = rhs.data.reshape(math.prod(rhs_shape[:rhs_fa]), math.prod(rhs_shape[rhs_fa:])) lhs_scale_inv = lhs.scale_inv rhs_scale_inv = rhs.scale_inv assert lhs.scaling_mode == rhs.scaling_mode @@ -2144,7 +2139,7 @@ def grouped_gemm( # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this? if ( - (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged + (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged and not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) ): @@ -2156,7 +2151,7 @@ def grouped_gemm( # For MXFP8 block-scaling wgrad with pre-quantized inputs: rhs is colwise quantized, # so rhs_use_colwise = (is_mxfp8 && !rhs_is_trans) must be True → rhs_is_trans=False. if ( - (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged + (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged and isinstance(lhs, GroupedScaledTensor1x) and scaling_mode.is_1d_block_scaling() ): @@ -2185,8 +2180,11 @@ def grouped_gemm( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) active_group_sizes = next( - (gs for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] - if gs.size > 0), + ( + gs + for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] + if gs.size > 0 + ), empty_gs, ) lhs_q = grouped_quantize(lhs, quantizer_set.x, active_group_sizes, lhs_flatten_axis) @@ -2242,7 +2240,9 @@ def grouped_gemm( lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim - if lhs_first_dims.size > 0 or lhs_last_dims.size > 0: # fwd/dgrad: rhs has G as first dim + if ( + lhs_first_dims.size > 0 or lhs_last_dims.size > 0 + ): # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim ) @@ -2254,9 +2254,12 @@ def grouped_gemm( rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) num_gemms = ( - lhs_first_dims.size or lhs_last_dims.size - or rhs_first_dims.size or rhs_last_dims.size - or out_first_dims.size or out_last_dims.size + lhs_first_dims.size + or lhs_last_dims.size + or rhs_first_dims.size + or rhs_last_dims.size + or out_first_dims.size + or out_last_dims.size ) has_bias = bias is not None diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 24026b4ad9..e85c520916 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -588,7 +588,7 @@ namespace xla::ffi { namespace { // Finds an attribute by name. Returns its index or std::nullopt if absent. -std::optional FindAttrByName(const XLA_FFI_Attrs* attrs, std::string_view name) { +std::optional FindAttrByName(const XLA_FFI_Attrs *attrs, std::string_view name) { for (int64_t i = 0; i < attrs->size; ++i) { if (std::string_view{attrs->names[i]->ptr, attrs->names[i]->len} == name) return i; } @@ -598,8 +598,8 @@ std::optional FindAttrByName(const XLA_FFI_Attrs* attrs, std::string_vi // Decodes a named attribute into `field` if present; leaves `field` at its default if absent. // Returns false only when the attribute is present but fails to decode. template -bool DecodeAttrOrDefault(const XLA_FFI_Attrs* attrs, std::string_view name, T& field, - DiagnosticEngine& diagnostic) { +bool DecodeAttrOrDefault(const XLA_FFI_Attrs *attrs, std::string_view name, T &field, + DiagnosticEngine &diagnostic) { auto idx = FindAttrByName(attrs, name); if (!idx.has_value()) return true; // absent → keep default auto decoded = AttrDecoding::Decode(attrs->types[*idx], attrs->attrs[*idx], diagnostic); @@ -618,12 +618,12 @@ struct AttrsBinding { template <> struct AttrDecoding { using Type = transformer_engine::jax::GroupedGemmV2Config; - static std::optional Decode(XLA_FFI_AttrType type, void* attr, - DiagnosticEngine& diagnostic) { + static std::optional Decode(XLA_FFI_AttrType type, void *attr, + DiagnosticEngine &diagnostic) { if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { return diagnostic.Emit("Expected dictionary attribute for GroupedGemmV2Config"); } - auto* attrs = reinterpret_cast(attr); + auto *attrs = reinterpret_cast(attr); Type config; if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) return std::nullopt; @@ -643,12 +643,12 @@ struct AttrsBinding { template <> struct AttrDecoding { using Type = transformer_engine::jax::GroupedGemmConfig; - static std::optional Decode(XLA_FFI_AttrType type, void* attr, - DiagnosticEngine& diagnostic) { + static std::optional Decode(XLA_FFI_AttrType type, void *attr, + DiagnosticEngine &diagnostic) { if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { return diagnostic.Emit("Expected dictionary attribute for GroupedGemmConfig"); } - auto* attrs = reinterpret_cast(attr); + auto *attrs = reinterpret_cast(attr); Type config; if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) return std::nullopt; @@ -656,10 +656,9 @@ struct AttrDecoding { return std::nullopt; if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "has_bias", config.has_bias, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "use_async_d2h_group_sizes", - config.use_async_d2h_group_sizes, diagnostic)) + if (!DecodeAttrOrDefault(attrs, "has_bias", config.has_bias, diagnostic)) return std::nullopt; + if (!DecodeAttrOrDefault(attrs, "use_async_d2h_group_sizes", config.use_async_d2h_group_sizes, + diagnostic)) return std::nullopt; return config; } @@ -676,10 +675,9 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, Buffer_Type out_first_dims, Buffer_Type out_last_dims, - Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type cublas_workspace, - Result_Type setup_workspace, Result_Type int64_workspace, - GroupedGemmV2Config config) { + Buffer_Type alpha, Buffer_Type beta, Result_Type output, + Result_Type cublas_workspace, Result_Type setup_workspace, + Result_Type int64_workspace, GroupedGemmV2Config config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: @@ -697,29 +695,40 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty // Determine which group_sizes buffers are active (non-empty = ragged dimension). bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; - bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; - bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; bool is_out_first_ragged = out_first_dims.element_count() > 0; - bool is_out_last_ragged = out_last_dims.element_count() > 0; + bool is_out_last_ragged = out_last_dims.element_count() > 0; bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; size_t num_gemms; - if (is_lhs_first_ragged) num_gemms = lhs_first_dims.dimensions()[0]; - else if (is_lhs_last_ragged) num_gemms = lhs_last_dims.dimensions()[0]; - else if (is_rhs_first_ragged) num_gemms = rhs_first_dims.dimensions()[0]; - else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; - else if (is_out_first_ragged) num_gemms = out_first_dims.dimensions()[0]; - else if (is_out_last_ragged) num_gemms = out_last_dims.dimensions()[0]; - else num_gemms = alpha.element_count(); // batched: no ragged tensor + if (is_lhs_first_ragged) + num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) + num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) + num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) + num_gemms = rhs_last_dims.dimensions()[0]; + else if (is_out_first_ragged) + num_gemms = out_first_dims.dimensions()[0]; + else if (is_out_last_ragged) + num_gemms = out_last_dims.dimensions()[0]; + else + num_gemms = alpha.element_count(); // batched: no ragged tensor const Buffer_Type *active_gs_ptr = nullptr; - if (is_lhs_first_ragged) active_gs_ptr = &lhs_first_dims; - else if (is_lhs_last_ragged) active_gs_ptr = &lhs_last_dims; - else if (is_rhs_first_ragged) active_gs_ptr = &rhs_first_dims; - else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; + if (is_lhs_first_ragged) + active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) + active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) + active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) + active_gs_ptr = &rhs_last_dims; // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); @@ -754,9 +763,8 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); NVTE_CHECK(active_gs_ptr->element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); - nvte_convert_int32_to_int64( - reinterpret_cast(active_gs_ptr->untyped_data()), int64_sizes_ptr, - num_gemms, stream); + nvte_convert_int32_to_int64(reinterpret_cast(active_gs_ptr->untyped_data()), + int64_sizes_ptr, num_gemms, stream); } NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, @@ -961,8 +969,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, Buffer_Type out_first_dims, Buffer_Type out_last_dims, - Buffer_Type group_offset, - Result_Type output, Result_Type workspace, + Buffer_Type group_offset, Result_Type output, Result_Type workspace, GroupedGemmConfig config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes] = config; // Notes on matrix layouts and transpose: @@ -983,25 +990,34 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // Determine which group_sizes buffers are active (non-empty = ragged dimension). bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; - bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; - bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; size_t num_gemms; - if (is_lhs_first_ragged) num_gemms = lhs_first_dims.dimensions()[0]; - else if (is_lhs_last_ragged) num_gemms = lhs_last_dims.dimensions()[0]; - else if (is_rhs_first_ragged) num_gemms = rhs_first_dims.dimensions()[0]; - else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; - else num_gemms = 1; // degenerate batched; legacy batched not a tested use case + if (is_lhs_first_ragged) + num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) + num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) + num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) + num_gemms = rhs_last_dims.dimensions()[0]; + else + num_gemms = 1; // degenerate batched; legacy batched not a tested use case const Buffer_Type *active_gs_ptr = nullptr; - if (is_lhs_first_ragged) active_gs_ptr = &lhs_first_dims; - else if (is_lhs_last_ragged) active_gs_ptr = &lhs_last_dims; - else if (is_rhs_first_ragged) active_gs_ptr = &rhs_first_dims; - else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; + if (is_lhs_first_ragged) + active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) + active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) + active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) + active_gs_ptr = &rhs_last_dims; // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); @@ -1126,8 +1142,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); } else { NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); - auto gs_data_ptr = - reinterpret_cast(active_gs_ptr->untyped_data()); + auto gs_data_ptr = reinterpret_cast(active_gs_ptr->untyped_data()); cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, stream); // Note: This may break cudaGraph. From ed0deaf08a01da3357fe09cbe92ca31fcd6beac0 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 11:35:29 -0700 Subject: [PATCH 05/23] Cleanup C++ v2 FFI Signed-off-by: Jeremy Berchtold --- .../jax/csrc/extensions/gemm.cpp | 308 +++++------------- 1 file changed, 77 insertions(+), 231 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e85c520916..288770281a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -559,6 +559,70 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } +// V2 variant: derives data shape from the 2D XLA buffer directly, converts group_sizes +// int32→int64 per-tensor into int64_workspace, and wires first_dims/last_dims. +// Only NO_SCALING is supported. +JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, + Buffer_Type const &first_dims, + Buffer_Type const &last_dims, + Result_Type int64_workspace, + size_t num_gemms, + cudaStream_t stream) { + auto dims = data.dimensions(); + NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); + // Flatten all leading dimensions into the first axis to produce a 2D NVTE shape. + // Input buffers (lhs, rhs) are already 2D from the Python side. Output buffers may be ND + // (e.g. [G, K, N] for wgrad), so we collapse dims[0..N-2] → rows and keep dims[N-1] → cols. + NVTEShape dataShape{.data = {product(dims, 0, dims.size() - 1), dims[dims.size() - 1]}, + .ndim = 2}; + JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); + wrapper.set_rowwise(data, std::nullopt); + auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); + if (first_dims.element_count() > 0) { + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, + "group_sizes must be int32."); + nvte_convert_int32_to_int64( + reinterpret_cast(first_dims.untyped_data()), + int64_sizes_ptr, num_gemms, stream); + wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + } + if (last_dims.element_count() > 0) { + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, + "group_sizes must be int32."); + nvte_convert_int32_to_int64( + reinterpret_cast(last_dims.untyped_data()), + int64_sizes_ptr, num_gemms, stream); + wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); + } + return wrapper; +} + +// Returns num_gemms from the first non-empty per-tensor group_sizes buffer, +// falling back to the element count of alpha for the uniform-batch case. +size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, + Buffer_Type const &lhs_last_dims, + Buffer_Type const &rhs_first_dims, + Buffer_Type const &rhs_last_dims, + Buffer_Type const &out_first_dims, + Buffer_Type const &out_last_dims, + Buffer_Type const &alpha) { + if (lhs_first_dims.element_count() > 0) { + return lhs_first_dims.dimensions()[0]; + } else if (lhs_last_dims.element_count() > 0) { + return lhs_last_dims.dimensions()[0]; + } else if (rhs_first_dims.element_count() > 0) { + return rhs_first_dims.dimensions()[0]; + } else if (rhs_last_dims.element_count() > 0) { + return rhs_last_dims.dimensions()[0]; + } else if (out_first_dims.element_count() > 0) { + return out_first_dims.dimensions()[0]; + } else if (out_last_dims.element_count() > 0) { + return out_last_dims.dimensions()[0]; + } else { + return alpha.element_count(); // uniform batch: no ragged tensor + } +} + // Config structs for grouped GEMM FFI static attributes. // Consolidating all static attributes into a single dict attribute makes it easy to add new // attributes in the future with backwards-compatible defaults: if old HLO was generated without a @@ -679,181 +743,19 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Result_Type cublas_workspace, Result_Type setup_workspace, Result_Type int64_workspace, GroupedGemmV2Config config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode] = config; - // Notes on matrix layouts and transpose: - // Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major [m, k] for N - [k, m] for T - // B: row-major [k, n] for N - [n, k] for T - // on exiting this function, JAX expect: - // C: row-major with size [m, n]. - // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m] for T - [m, k] for N - // B: column-major with size [n, k] for T - [k, n] for N - // - // If we call cuBLAS GEMM for A * B, the output will be: - // C: column-major with size [m, n] --> row-major with size [n, m]. - // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - - // Determine which group_sizes buffers are active (non-empty = ragged dimension). - bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; - bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; - bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; - bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; - bool is_out_first_ragged = out_first_dims.element_count() > 0; - bool is_out_last_ragged = out_last_dims.element_count() > 0; - bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; - bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; - bool any_ragged = is_lhs_ragged || is_rhs_ragged; - - size_t num_gemms; - if (is_lhs_first_ragged) - num_gemms = lhs_first_dims.dimensions()[0]; - else if (is_lhs_last_ragged) - num_gemms = lhs_last_dims.dimensions()[0]; - else if (is_rhs_first_ragged) - num_gemms = rhs_first_dims.dimensions()[0]; - else if (is_rhs_last_ragged) - num_gemms = rhs_last_dims.dimensions()[0]; - else if (is_out_first_ragged) - num_gemms = out_first_dims.dimensions()[0]; - else if (is_out_last_ragged) - num_gemms = out_last_dims.dimensions()[0]; - else - num_gemms = alpha.element_count(); // batched: no ragged tensor - - const Buffer_Type *active_gs_ptr = nullptr; - if (is_lhs_first_ragged) - active_gs_ptr = &lhs_first_dims; - else if (is_lhs_last_ragged) - active_gs_ptr = &lhs_last_dims; - else if (is_rhs_first_ragged) - active_gs_ptr = &rhs_first_dims; - else if (is_rhs_last_ragged) - active_gs_ptr = &rhs_last_dims; - - // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. - NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); - NVTE_CHECK(rhs_data.dimensions().size() == 2, "rhs_data must be 2D."); - size_t k = lhs_is_trans ? lhs_data.dimensions()[0] : lhs_data.dimensions()[1]; - size_t m, n; - if (is_rhs_ragged) { - // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M - m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; - n = rhs_data.dimensions()[1]; - } else { - m = lhs_data.dimensions()[0]; // total M (sum of group sizes) - n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; - } - - // Inputs - auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); - auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); - auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); - auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); - auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); - auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); - auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); - auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); - bool has_bias = product(bias.dimensions()) > 0; - auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; - auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - - // Convert int32 group_sizes to int64 into the dedicated output buffer (ragged tensors only). - auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); - if (any_ragged) { - NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); - NVTE_CHECK(active_gs_ptr->element_type() == xla::ffi::DataType::S32, - "group_sizes must be int32."); - nvte_convert_int32_to_int64(reinterpret_cast(active_gs_ptr->untyped_data()), - int64_sizes_ptr, num_gemms, stream); - } NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); - // It is weird that TE/Common GEMM only use colwise for MXFP8 - const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); - const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || - scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; - const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; - const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; - const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; + size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, + rhs_first_dims, rhs_last_dims, + out_first_dims, out_last_dims, alpha); - // Outputs - auto out_ptr = reinterpret_cast(output->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); + // Workspaces. auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); - // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); - auto workspace_total_size = product(cublas_workspace->dimensions()); - - auto lhs_sinv_size = product(lhs_sinv.dimensions()); - auto rhs_sinv_size = product(rhs_sinv.dimensions()); - const size_t workspace_alignment_padding = 256; - const size_t tensor_scaling_sinv_aligment = 16; - const size_t mxfp8_scaling_sinv_alignment_padding = 256; - auto workspace_size = workspace_total_size - workspace_alignment_padding; - if (is_mxfp8_scaling) { - // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. - workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); - } else if (is_tensor_scaling) { - // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned - // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. - workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); - } - auto swizzled_lhs_sinv_ptr = cublas_workspace_ptr + workspace_size; - swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; - swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); - auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned - auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; - - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); - - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); - - size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_rhs_ragged ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_rhs_ragged ? (num_gemms * m * n) : (m * n); - size_t actual_lhs_size = product(lhs_data.dimensions()); - size_t actual_rhs_size = product(rhs_data.dimensions()); - size_t actual_out_size = product(output->dimensions()); - NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", - expected_lhs_size, ", got ", actual_lhs_size); - if (!is_rhs_ragged) { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, - "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, - " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, - " * ", n, " = ", expected_out_size, ", got ", actual_out_size); - } else { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, - " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, - "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, - " = ", expected_out_size, ", got ", actual_out_size); - } - - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - bool grad = false; - bool accumulate = false; - bool use_split_accumulator = false; - auto bias_shape = std::vector{has_bias ? n : 0}; - const int arch = cuda::sm_arch(); - - if (arch < 100 && is_fp8_gemm) { - NVTE_CHECK(!lhs_is_trans && rhs_is_trans, - "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", - "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); - } - + auto workspace_size = product(cublas_workspace->dimensions()) - 256; TensorWrapper workspace_setup(setup_workspace_ptr, std::vector{product(setup_workspace->dimensions())}, DType::kByte); @@ -867,70 +769,14 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - if (is_rhs_ragged) { - NVTE_CHECK(lhs_is_trans && !rhs_is_trans, - "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); - - //// RHS - NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - if (is_rhs_first_ragged) - rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - if (is_rhs_last_ragged) - rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); - - //// LHS - NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; - lhs_is_trans = true; - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - if (is_lhs_first_ragged) - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - if (is_lhs_last_ragged) - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); - - //// OUTPUT - NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - - nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, - alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), - workspace_cublas.data(), - nullptr, // config (use defaults) - stream); - - return ffi_with_cuda_error_check(); - } - - // Nominal case for FWD, DGRAD, or batched GEMM - - //// RHS - NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; - if (rhs_is_trans) { - rhsShape.data[0] = num_gemms * n; - rhsShape.data[1] = k; - } - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - - //// LHS - NVTEShape lhsShape{.data = {m, k}, .ndim = 2}; - if (lhs_is_trans) { - std::swap(lhsShape.data[0], lhsShape.data[1]); - } - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - if (is_lhs_first_ragged) - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - if (is_lhs_last_ragged) - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); - - //// OUTPUT - NVTEShape outShape{.data = {m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - if (is_out_first_ragged) - out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - if (is_out_last_ragged) - out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); + // Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed. + // int32→int64 conversion for group_sizes is handled per-tensor inside make_grouped_tensor. + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, + int64_workspace, num_gemms, stream); + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, + int64_workspace, num_gemms, stream); + auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, + int64_workspace, num_gemms, stream); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), From 88bb7daaa6ff1d0877b2227c2cd29efcf29cd555 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 11:55:59 -0700 Subject: [PATCH 06/23] Fix int64 workspace usage Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 19 +++++++- .../jax/csrc/extensions/gemm.cpp | 47 +++++++++++++------ 2 files changed, 50 insertions(+), 16 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index d79be04983..979bed9577 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1549,7 +1549,24 @@ def abstract( shape=(get_grouped_gemm_setup_workspace_size(num_groups),), dtype=jnp.uint8 ) # Temporary buffer for int32 -> int64 conversion of group_sizes on device. - int64_workspace_size = num_groups * jnp.dtype(jnp.int64).itemsize + # Each non-empty *_dims buffer needs its own slot of num_groups int64 elements so that + # make_grouped_tensor can write to a distinct region per ragged dimension. Allocate + # exactly as many slots as there are non-empty buffers (minimum 1 to avoid zero-size). + num_ragged_dim_buffers = sum( + 1 + for aval in [ + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + ] + if aval.size > 0 + ) + int64_workspace_size = ( + max(num_ragged_dim_buffers, 1) * num_groups * jnp.dtype(jnp.int64).itemsize + ) int64_workspace_aval = jax.core.ShapedArray( shape=(int64_workspace_size,), dtype=jnp.uint8 ) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 288770281a..e122e8f909 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -559,13 +559,17 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -// V2 variant: derives data shape from the 2D XLA buffer directly, converts group_sizes -// int32→int64 per-tensor into int64_workspace, and wires first_dims/last_dims. -// Only NO_SCALING is supported. +// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes +// int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims. +// int64_offset (in int64 elements) is updated on return to the next available slot so callers can +// thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked +// before each slot is used. Only NO_SCALING is supported. JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, - Result_Type int64_workspace, + int64_t *int64_workspace_base, + size_t int64_workspace_capacity, + size_t &int64_offset, size_t num_gemms, cudaStream_t stream) { auto dims = data.dimensions(); @@ -577,22 +581,27 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); - auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); if (first_dims.element_count() > 0) { NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for first_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; nvte_convert_int32_to_int64( - reinterpret_cast(first_dims.untyped_data()), - int64_sizes_ptr, num_gemms, stream); - wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + reinterpret_cast(first_dims.untyped_data()), slot, num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); + int64_offset += num_gemms; } if (last_dims.element_count() > 0) { NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for last_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; nvte_convert_int32_to_int64( - reinterpret_cast(last_dims.untyped_data()), - int64_sizes_ptr, num_gemms, stream); - wrapper.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedLastDims); + reinterpret_cast(last_dims.untyped_data()), slot, num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); + int64_offset += num_gemms; } return wrapper; } @@ -770,13 +779,21 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty convert_ffi_datatype_to_te_dtype(beta.element_type())); // Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed. - // int32→int64 conversion for group_sizes is handled per-tensor inside make_grouped_tensor. + // int64_workspace is partitioned into per-ragged-buffer slots of num_gemms int64 elements each. + // int64_offset is threaded through the three make_grouped_tensor calls so each non-empty *_dims + // buffer gets its own non-aliasing slot; bounds are checked inside make_grouped_tensor. + auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); + size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); + size_t int64_offset = 0; auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, - int64_workspace, num_gemms, stream); + int64_base, int64_capacity, int64_offset, num_gemms, + stream); auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, - int64_workspace, num_gemms, stream); + int64_base, int64_capacity, int64_offset, num_gemms, + stream); auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, - int64_workspace, num_gemms, stream); + int64_base, int64_capacity, int64_offset, num_gemms, + stream); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), From 60312c85374c2ce4cf1d51e09a0e18aa59b157ee Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 15:59:12 -0700 Subject: [PATCH 07/23] Address greptile comments Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 979bed9577..de0ef1c522 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1528,8 +1528,8 @@ def abstract( out_shape = (num_groups, M, N) else: # lhs shape [M_total, K] (lhs_is_trans=False) or [K, M_total] (lhs_is_trans=True) - # dim[0] is always total M for fwd/dgrad - M = lhs_data_aval.shape[0] + # M is the non-contracting (output) dim + M = lhs_data_aval.shape[1] if lhs_is_trans else lhs_data_aval.shape[0] N = rhs_data_aval.shape[1] if not rhs_is_trans else rhs_data_aval.shape[0] // num_groups out_shape = (M, N) @@ -2270,6 +2270,13 @@ def grouped_gemm( lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) + # Validate contracting dim size + k_lhs = lhs_data_2d.shape[0] if lhs_is_trans else lhs_data_2d.shape[1] + k_rhs = rhs_data_2d.shape[1] if rhs_is_trans else rhs_data_2d.shape[0] // num_gemms + assert k_lhs == k_rhs, ( + f"Contracting dimension mismatch: LHS K={k_lhs}, RHS K={k_rhs}" + ) + num_gemms = ( lhs_first_dims.size or lhs_last_dims.size @@ -2278,6 +2285,12 @@ def grouped_gemm( or out_first_dims.size or out_last_dims.size ) + if num_gemms == 0: + raise ValueError( + "grouped_gemm requires at least one non-empty dimension array " + "(lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, " + "out_first_dims, or out_last_dims)." + ) has_bias = bias is not None if has_bias: From 025f598ab65fd5c06d9d0f167504fde748514a89 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 16:25:20 -0700 Subject: [PATCH 08/23] Refactor wgrad-specific checks to be generic for GMM in gemm.py Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 104 +++++++++++------- .../jax/csrc/extensions/gemm.cpp | 2 +- 2 files changed, 64 insertions(+), 42 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index de0ef1c522..b97b66066f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1441,6 +1441,45 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) +def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int: + """Non-contracting output size M from the 2-D LHS buffer.""" + return lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0] + + +def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int: + """Non-contracting output size N from the 2-D RHS buffer.""" + return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1] + + +def _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups: int, +) -> None: + """Assert that all non-empty *_dims arrays have exactly num_groups elements. + + rhs_first_dims / rhs_last_dims describe the ragged contracting K dimension. + K totals need not fill the entire buffer (padding is allowed), so only the + array length is checked, not the per-group sum. + """ + for name, aval in [ + ("lhs_first_dims", lhs_first_dims_aval), + ("lhs_last_dims", lhs_last_dims_aval), + ("out_first_dims", out_first_dims_aval), + ("out_last_dims", out_last_dims_aval), + ("rhs_first_dims", rhs_first_dims_aval), + ("rhs_last_dims", rhs_last_dims_aval), + ]: + if aval.size > 0: + assert aval.size == num_groups, ( + f"grouped GEMM {name} has size {aval.size}, expected num_groups={num_groups}" + ) + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). @@ -1507,8 +1546,6 @@ def abstract( del bias_aval del has_bias, use_async_d2h_group_sizes - # Determine mode from which group_sizes buffer is non-empty - is_wgrad = rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0 num_groups = ( lhs_first_dims_aval.size or lhs_last_dims_aval.size @@ -1519,18 +1556,28 @@ def abstract( or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 ) - # lhs_data_aval and rhs_data_aval are now 2D; derive output shape from buffer dims - if is_wgrad: - # lhs shape [K_lhs, M] (lhs_is_trans=True) or [M, K_lhs] (lhs_is_trans=False) - # M is the non-contracting (output) dim - M = lhs_data_aval.shape[1] if lhs_is_trans else lhs_data_aval.shape[0] - N = rhs_data_aval.shape[1] + _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups, + ) + + # lhs_data_aval and rhs_data_aval are 2D; derive output shape from buffer dims. + # lhs shape: [M, K] (lhs_is_trans=False) or [K, M] (lhs_is_trans=True) + # rhs shape: [G*K, N] or [K, N] (rhs_is_trans=False) or [G*N, K] (rhs_is_trans=True) + M = _grouped_gemm_lhs_M(lhs_data_aval.shape, lhs_is_trans) + N = _grouped_gemm_rhs_N(rhs_data_aval.shape, rhs_is_trans, num_groups) + # When rhs has a ragged (contracting) K dimension, M and N are fixed per group + # and the output has a leading group axis. + # K validation is intentionally skipped: per-group K values may not fill the + # entire buffer (padding is allowed), so sum(rhs_*_dims) != buffer K is acceptable. + if rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0: out_shape = (num_groups, M, N) else: - # lhs shape [M_total, K] (lhs_is_trans=False) or [K, M_total] (lhs_is_trans=True) - # M is the non-contracting (output) dim - M = lhs_data_aval.shape[1] if lhs_is_trans else lhs_data_aval.shape[0] - N = rhs_data_aval.shape[1] if not rhs_is_trans else rhs_data_aval.shape[0] // num_groups out_shape = (M, N) cublas_workspace_aval = jax.core.ShapedArray( @@ -2150,30 +2197,12 @@ def grouped_gemm( lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) - # rhs_shape [G, K, N] - rhs_is_trans = rhs_contract_dim[0] != 1 + # rhs_is_trans: K is the last dim of rhs (i.e., rhs is in "T" layout). + # This formula handles both standard rhs [G, K, N] (G-prefixed) and wgrad + # rhs [K_total, N] (no G prefix) without needing a separate wgrad override. + rhs_is_trans = rhs_contract_dim[-1] == len(rhs_shape) - 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) - # TODO(Hua): these are for fp16 dense wgrad, any better way to handle this? - if ( - (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged - and not isinstance(lhs, ScaledTensor) - and not isinstance(rhs, ScaledTensor) - ): - lhs_is_trans = True - rhs_is_trans = False - lhs_flatten_axis = 1 - rhs_flatten_axis = 1 - - # For MXFP8 block-scaling wgrad with pre-quantized inputs: rhs is colwise quantized, - # so rhs_use_colwise = (is_mxfp8 && !rhs_is_trans) must be True → rhs_is_trans=False. - if ( - (rhs_first_dims.size > 0 or rhs_last_dims.size > 0) # wgrad mode: rhs dim is ragged - and isinstance(lhs, GroupedScaledTensor1x) - and scaling_mode.is_1d_block_scaling() - ): - rhs_is_trans = False - if ( not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) @@ -2270,13 +2299,6 @@ def grouped_gemm( lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) - # Validate contracting dim size - k_lhs = lhs_data_2d.shape[0] if lhs_is_trans else lhs_data_2d.shape[1] - k_rhs = rhs_data_2d.shape[1] if rhs_is_trans else rhs_data_2d.shape[0] // num_gemms - assert k_lhs == k_rhs, ( - f"Contracting dimension mismatch: LHS K={k_lhs}, RHS K={k_rhs}" - ) - num_gemms = ( lhs_first_dims.size or lhs_last_dims.size diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index e122e8f909..9e996c8f3a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -892,7 +892,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; n = rhs_data.dimensions()[1]; } else { - m = lhs_data.dimensions()[0]; // total M (sum of group sizes) + m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; // total M (sum of group sizes) n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; } From 089e530d2f5b9a3732f1d2bff08bd9650ff62a73 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 16:58:35 -0700 Subject: [PATCH 09/23] Refactor XLA FFI struct setup Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions.h | 28 +++++ .../jax/csrc/extensions/gemm.cpp | 104 ------------------ 2 files changed, 28 insertions(+), 104 deletions(-) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 93c85aaacc..98a97084a1 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -45,6 +45,20 @@ struct ActivationConfig { ClampedSwigluConfig clamped_swiglu; }; +struct GroupedGemmV2Config { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; +}; + +struct GroupedGemmConfig { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; + bool has_bias; + bool use_async_d2h_group_sizes; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -170,6 +184,20 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::ActivationConfig, ::xla::ffi::StructMember("clamped_swiglu")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmV2Config, + ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmConfig, + ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("has_bias"), + ::xla::ffi::StructMember("use_async_d2h_group_sizes")); + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Score_Function); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 9e996c8f3a..f28deaaea7 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -632,113 +632,9 @@ size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, } } -// Config structs for grouped GEMM FFI static attributes. -// Consolidating all static attributes into a single dict attribute makes it easy to add new -// attributes in the future with backwards-compatible defaults: if old HLO was generated without a -// newer attribute, DecodeAttrOrDefault leaves the field at its struct default value. -struct GroupedGemmV2Config { - bool lhs_is_trans = false; - bool rhs_is_trans = false; - JAXX_Scaling_Mode scaling_mode = JAXX_Scaling_Mode::NO_SCALING; -}; - -struct GroupedGemmConfig { - bool lhs_is_trans = false; - bool rhs_is_trans = false; - JAXX_Scaling_Mode scaling_mode = JAXX_Scaling_Mode::NO_SCALING; - bool has_bias = false; - bool use_async_d2h_group_sizes = false; -}; - } // namespace jax } // namespace transformer_engine -// Register AttrsBinding and AttrDecoding for grouped GEMM config structs. -// Uses a custom AttrDecoding (instead of XLA_FFI_REGISTER_STRUCT_ATTR_DECODING) that supports -// optional struct fields with default values, so old HLO without newer attributes still decodes. -namespace xla::ffi { - -namespace { - -// Finds an attribute by name. Returns its index or std::nullopt if absent. -std::optional FindAttrByName(const XLA_FFI_Attrs *attrs, std::string_view name) { - for (int64_t i = 0; i < attrs->size; ++i) { - if (std::string_view{attrs->names[i]->ptr, attrs->names[i]->len} == name) return i; - } - return std::nullopt; -} - -// Decodes a named attribute into `field` if present; leaves `field` at its default if absent. -// Returns false only when the attribute is present but fails to decode. -template -bool DecodeAttrOrDefault(const XLA_FFI_Attrs *attrs, std::string_view name, T &field, - DiagnosticEngine &diagnostic) { - auto idx = FindAttrByName(attrs, name); - if (!idx.has_value()) return true; // absent → keep default - auto decoded = AttrDecoding::Decode(attrs->types[*idx], attrs->attrs[*idx], diagnostic); - if (!decoded.has_value()) return false; - field = *decoded; - return true; -} - -} // namespace - -template <> -struct AttrsBinding { - using Attrs = transformer_engine::jax::GroupedGemmV2Config; -}; - -template <> -struct AttrDecoding { - using Type = transformer_engine::jax::GroupedGemmV2Config; - static std::optional Decode(XLA_FFI_AttrType type, void *attr, - DiagnosticEngine &diagnostic) { - if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { - return diagnostic.Emit("Expected dictionary attribute for GroupedGemmV2Config"); - } - auto *attrs = reinterpret_cast(attr); - Type config; - if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "rhs_is_trans", config.rhs_is_trans, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) - return std::nullopt; - return config; - } -}; - -template <> -struct AttrsBinding { - using Attrs = transformer_engine::jax::GroupedGemmConfig; -}; - -template <> -struct AttrDecoding { - using Type = transformer_engine::jax::GroupedGemmConfig; - static std::optional Decode(XLA_FFI_AttrType type, void *attr, - DiagnosticEngine &diagnostic) { - if (XLA_FFI_PREDICT_FALSE(type != XLA_FFI_AttrType_DICTIONARY)) { - return diagnostic.Emit("Expected dictionary attribute for GroupedGemmConfig"); - } - auto *attrs = reinterpret_cast(attr); - Type config; - if (!DecodeAttrOrDefault(attrs, "lhs_is_trans", config.lhs_is_trans, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "rhs_is_trans", config.rhs_is_trans, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "scaling_mode", config.scaling_mode, diagnostic)) - return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "has_bias", config.has_bias, diagnostic)) return std::nullopt; - if (!DecodeAttrOrDefault(attrs, "use_async_d2h_group_sizes", config.use_async_d2h_group_sizes, - diagnostic)) - return std::nullopt; - return config; - } -}; - -} // namespace xla::ffi - namespace transformer_engine { namespace jax { From 8ad229483323db437bb3426d600a42c019aa3f67 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Tue, 10 Mar 2026 17:04:21 -0700 Subject: [PATCH 10/23] Fix edge case in TE v1 GMM Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/csrc/extensions/gemm.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index f28deaaea7..7795d4c18a 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -766,7 +766,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; else - num_gemms = 1; // degenerate batched; legacy batched not a tested use case + NVTE_CHECK(false, "GroupedGemmFFI (v1): At least one of the group size buffers must be non-empty to determine num_gemms."); const Buffer_Type *active_gs_ptr = nullptr; if (is_lhs_first_ragged) From 4ff5d1d9ffd9e7c43410de5456f42d0c1b43921f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Wed, 11 Mar 2026 00:26:02 +0000 Subject: [PATCH 11/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 6 +- transformer_engine/jax/csrc/extensions.h | 6 +- .../jax/csrc/extensions/gemm.cpp | 56 ++++++++----------- 3 files changed, 29 insertions(+), 39 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1ef7da9cbc..1f15c27d97 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1313,9 +1313,9 @@ def _assert_grouped_gemm_dims_shapes( ("rhs_last_dims", rhs_last_dims_aval), ]: if aval.size > 0: - assert aval.size == num_groups, ( - f"grouped GEMM {name} has size {aval.size}, expected num_groups={num_groups}" - ) + assert ( + aval.size == num_groups + ), f"grouped GEMM {name} has size {aval.size}, expected num_groups={num_groups}" class GroupedGemmPrimitive(BasePrimitive): diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 9a5647f7c7..bd429a7db6 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -207,14 +207,12 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( ::xla::ffi::StructMember("use_split_accumulator")); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::GroupedGemmV2Config, - ::xla::ffi::StructMember("lhs_is_trans"), + transformer_engine::jax::GroupedGemmV2Config, ::xla::ffi::StructMember("lhs_is_trans"), ::xla::ffi::StructMember("rhs_is_trans"), ::xla::ffi::StructMember("scaling_mode")); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( - transformer_engine::jax::GroupedGemmConfig, - ::xla::ffi::StructMember("lhs_is_trans"), + transformer_engine::jax::GroupedGemmConfig, ::xla::ffi::StructMember("lhs_is_trans"), ::xla::ffi::StructMember("rhs_is_trans"), ::xla::ffi::StructMember("scaling_mode"), ::xla::ffi::StructMember("has_bias"), diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index dec72f809e..dd6c0a59f2 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -626,10 +626,8 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, int64_t *int64_workspace_base, - size_t int64_workspace_capacity, - size_t &int64_offset, - size_t num_gemms, - cudaStream_t stream) { + size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream) { auto dims = data.dimensions(); NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); // Flatten all leading dimensions into the first axis to produce a 2D NVTE shape. @@ -640,24 +638,22 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); if (first_dims.element_count() > 0) { - NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, - "group_sizes must be int32."); + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, "int64_workspace overflow: not enough space for first_dims conversion."); auto *slot = int64_workspace_base + int64_offset; - nvte_convert_int32_to_int64( - reinterpret_cast(first_dims.untyped_data()), slot, num_gemms, stream); + nvte_convert_int32_to_int64(reinterpret_cast(first_dims.untyped_data()), slot, + num_gemms, stream); wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); int64_offset += num_gemms; } if (last_dims.element_count() > 0) { - NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, - "group_sizes must be int32."); + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, "int64_workspace overflow: not enough space for last_dims conversion."); auto *slot = int64_workspace_base + int64_offset; - nvte_convert_int32_to_int64( - reinterpret_cast(last_dims.untyped_data()), slot, num_gemms, stream); + nvte_convert_int32_to_int64(reinterpret_cast(last_dims.untyped_data()), slot, + num_gemms, stream); wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); int64_offset += num_gemms; } @@ -666,12 +662,9 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, // Returns num_gemms from the first non-empty per-tensor group_sizes buffer, // falling back to the element count of alpha for the uniform-batch case. -size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, - Buffer_Type const &lhs_last_dims, - Buffer_Type const &rhs_first_dims, - Buffer_Type const &rhs_last_dims, - Buffer_Type const &out_first_dims, - Buffer_Type const &out_last_dims, +size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type const &lhs_last_dims, + Buffer_Type const &rhs_first_dims, Buffer_Type const &rhs_last_dims, + Buffer_Type const &out_first_dims, Buffer_Type const &out_last_dims, Buffer_Type const &alpha) { if (lhs_first_dims.element_count() > 0) { return lhs_first_dims.dimensions()[0]; @@ -710,9 +703,8 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); - size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, - rhs_first_dims, rhs_last_dims, - out_first_dims, out_last_dims, alpha); + size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, rhs_first_dims, + rhs_last_dims, out_first_dims, out_last_dims, alpha); // Workspaces. auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); @@ -739,15 +731,12 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, - int64_base, int64_capacity, int64_offset, num_gemms, - stream); - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, - int64_base, int64_capacity, int64_offset, num_gemms, - stream); - auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, - int64_base, int64_capacity, int64_offset, num_gemms, - stream); + auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); + auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); + auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), @@ -824,7 +813,9 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else if (is_rhs_last_ragged) num_gemms = rhs_last_dims.dimensions()[0]; else - NVTE_CHECK(false, "GroupedGemmFFI (v1): At least one of the group size buffers must be non-empty to determine num_gemms."); + NVTE_CHECK(false, + "GroupedGemmFFI (v1): At least one of the group size buffers must be non-empty to " + "determine num_gemms."); const Buffer_Type *active_gs_ptr = nullptr; if (is_lhs_first_ragged) @@ -846,7 +837,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; n = rhs_data.dimensions()[1]; } else { - m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; // total M (sum of group sizes) + m = lhs_is_trans ? lhs_data.dimensions()[1] + : lhs_data.dimensions()[0]; // total M (sum of group sizes) n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; } From 0cb7289643a8349817916be637dc5f247df6696b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Wed, 11 Mar 2026 10:13:34 -0700 Subject: [PATCH 12/23] Fix issues on Hopper Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/gemm.py | 13 ++++++++++--- transformer_engine/jax/csrc/extensions/gemm.cpp | 3 ++- 2 files changed, 12 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1f15c27d97..06af064c9f 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1408,14 +1408,16 @@ def abstract( # lhs shape: [M, K] (lhs_is_trans=False) or [K, M] (lhs_is_trans=True) # rhs shape: [G*K, N] or [K, N] (rhs_is_trans=False) or [G*N, K] (rhs_is_trans=True) M = _grouped_gemm_lhs_M(lhs_data_aval.shape, lhs_is_trans) - N = _grouped_gemm_rhs_N(rhs_data_aval.shape, rhs_is_trans, num_groups) - # When rhs has a ragged (contracting) K dimension, M and N are fixed per group - # and the output has a leading group axis. # K validation is intentionally skipped: per-group K values may not fill the # entire buffer (padding is allowed), so sum(rhs_*_dims) != buffer K is acceptable. if rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0: + # Wgrad case: rhs has ragged contracting K dimension with no G-prefix. + # T-layout rhs shape is (N, K_total); N-layout rhs shape is (K_total, N). + N = rhs_data_aval.shape[0] if rhs_is_trans else rhs_data_aval.shape[1] out_shape = (num_groups, M, N) else: + # When rhs has a leading group axis, _grouped_gemm_rhs_N divides by num_groups. + N = _grouped_gemm_rhs_N(rhs_data_aval.shape, rhs_is_trans, num_groups) out_shape = (M, N) cublas_workspace_aval = jax.core.ShapedArray( @@ -1889,6 +1891,11 @@ def _can_use_v2_grouped_gemm( if not _v2_grouped_gemm_available: return False + # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). + # Fall back to the v1 path on SM90 (Hopper) and older architectures. + if get_device_compute_capability(0) < 100: + return False + return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index dd6c0a59f2..2d73390d33 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -835,7 +835,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type if (is_rhs_ragged) { // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; - n = rhs_data.dimensions()[1]; + // T-layout rhs: (N, K_total) -> n = dim[0]; N-layout rhs: (K_total, N) -> n = dim[1] + n = rhs_is_trans ? rhs_data.dimensions()[0] : rhs_data.dimensions()[1]; } else { m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; // total M (sum of group sizes) From cc236ad10ccea4edc7b740cf236776c3586c3381 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Thu, 12 Mar 2026 14:50:43 -0700 Subject: [PATCH 13/23] Refactor Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 35 ++- transformer_engine/jax/cpp_extensions/gemm.py | 254 +++++++++--------- .../jax/cpp_extensions/quantization.py | 3 +- transformer_engine/jax/csrc/extensions.h | 12 +- .../jax/csrc/extensions/gemm.cpp | 51 ++-- transformer_engine/jax/dense.py | 71 +++-- .../jax/quantize/dequantizer.py | 11 +- transformer_engine/jax/quantize/quantizer.py | 2 +- transformer_engine/jax/quantize/tensor.py | 159 +++++++++-- 9 files changed, 389 insertions(+), 209 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 2f2d5383b2..9fddbc435c 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -36,6 +36,7 @@ ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, + GroupedNoScaleTensor, ScalingMode, QuantizerFactory, QuantizeLayout, @@ -1787,18 +1788,17 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm - empty_gs = jnp.empty((0,), jnp.int32) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) prim_out = jax.jit( tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") )( - lhs, - rhs, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=empty_gs, - rhs_last_dims=empty_gs, - out_first_dims=group_sizes, - out_last_dims=empty_gs, + lhs_tensor, + rhs_tensor, contracting_dims=contracting_dims, use_async_d2h_group_sizes=True, ) @@ -1831,16 +1831,15 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) - empty_gs = jnp.empty((0,), jnp.int32) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs, - rhs, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=empty_gs, - rhs_last_dims=empty_gs, - out_first_dims=group_sizes, - out_last_dims=empty_gs, + lhs_tensor, + rhs_tensor, contracting_dims=contracting_dims, quantizer_set=quantizer_set, ) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 6a41cfc94e..ff9194bdd9 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -37,6 +37,7 @@ ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, + GroupedNoScaleTensor, ScalingMode, Quantizer, GroupedQuantizer, @@ -1331,15 +1332,6 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) -def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int: - """Non-contracting output size M from the 2-D LHS buffer.""" - return lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0] - - -def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int: - """Non-contracting output size N from the 2-D RHS buffer.""" - return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1] - def _assert_grouped_gemm_dims_shapes( lhs_first_dims_aval, @@ -1381,7 +1373,7 @@ class GroupedGemmPrimitive(BasePrimitive): # out_first_dims, out_last_dims, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (13, 14, 15, 16, 17, 18, 19) + impl_static_args = (13, 14, 15, 16, 17, 18, 19, 20, 21, 22) inner_primitive = None outer_primitive = None @@ -1406,19 +1398,22 @@ def abstract( has_bias, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + rhs_group_axis, ): """ Grouped GEMM operation. Args: - lhs_data: Left-hand side input matrix data, 2D array [rows, cols] + lhs_data: Left-hand side input matrix data, N-D array lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array - rhs_data: Right-hand side input matrix data, 2D array [rows, cols] + rhs_data: Right-hand side input matrix data, N-D array rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) - lhs_group_sizes: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel - rhs_group_sizes: (G,) int32 if rhs first-dim is ragged (wgrad), else empty (0,) sentinel - out_group_sizes: (G,) int32 if output first-dim is ragged, else empty (0,) sentinel + lhs_first_dims: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel + rhs_first_dims: (G,) int32 if rhs first-dim is ragged (wgrad), else empty (0,) sentinel + out_first_dims: (G,) int32 if output first-dim is ragged, else empty (0,) sentinel additional_args: Either * group_offsets: 1D array containing offsets for each group (not yet implemented) OR @@ -1429,6 +1424,9 @@ def abstract( scaling_mode: Scaling mode for the GEMM operations out_dtype: Data type of the output tensors has_bias: Boolean indicating if bias tensors are provided + lhs_axis_boundary: Axis split point for lhs N-D → 2D flattening + rhs_axis_boundary: Axis split point for rhs N-D → 2D flattening + rhs_group_axis: Batch-group axis of rhs to exclude from output non-contracting dims Returns: A jnp.ndarray containing the result of the grouped GEMM operation @@ -1456,21 +1454,33 @@ def abstract( num_groups, ) - # lhs_data_aval and rhs_data_aval are 2D; derive output shape from buffer dims. - # lhs shape: [M, K] (lhs_is_trans=False) or [K, M] (lhs_is_trans=True) - # rhs shape: [G*K, N] or [K, N] (rhs_is_trans=False) or [G*N, K] (rhs_is_trans=True) - M = _grouped_gemm_lhs_M(lhs_data_aval.shape, lhs_is_trans) + # Derive output shape from N-D buffer shapes using axis_boundary. + lhs_shape = lhs_data_aval.shape + rhs_shape = rhs_data_aval.shape + + # Non-contracting dims for lhs + if lhs_is_trans: + lhs_non_contracting = lhs_shape[lhs_axis_boundary:] + else: + lhs_non_contracting = lhs_shape[:lhs_axis_boundary] + + # Non-contracting dims for rhs (excluding batch-group axis where applicable) + if rhs_is_trans: + rhs_non_contracting = tuple( + rhs_shape[d] + for d in range(rhs_axis_boundary) + if rhs_group_axis is None or d != rhs_group_axis + ) + else: + rhs_non_contracting = rhs_shape[rhs_axis_boundary:] + # K validation is intentionally skipped: per-group K values may not fill the # entire buffer (padding is allowed), so sum(rhs_*_dims) != buffer K is acceptable. if rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0: - # Wgrad case: rhs has ragged contracting K dimension with no G-prefix. - # T-layout rhs shape is (N, K_total); N-layout rhs shape is (K_total, N). - N = rhs_data_aval.shape[0] if rhs_is_trans else rhs_data_aval.shape[1] - out_shape = (num_groups, M, N) + # Wgrad case: rhs has ragged contracting K dimension → output gets G prefix. + out_shape = (num_groups, *lhs_non_contracting, *rhs_non_contracting) else: - # When rhs has a leading group axis, _grouped_gemm_rhs_N divides by num_groups. - N = _grouped_gemm_rhs_N(rhs_data_aval.shape, rhs_is_trans, num_groups) - out_shape = (M, N) + out_shape = (*lhs_non_contracting, *rhs_non_contracting) cublas_workspace_aval = jax.core.ShapedArray( shape=( @@ -1577,8 +1587,11 @@ def lowering( has_bias, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + rhs_group_axis, ): - del out_dtype + del out_dtype, rhs_group_axis # Python-only; not forwarded to C++ if use_v2_ffi: ffi_name = GroupedGemmPrimitive.name_graph_safe return jax.ffi.ffi_lowering(ffi_name)( @@ -1587,6 +1600,8 @@ def lowering( lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, ) ffi_name = GroupedGemmPrimitive.name return jax.ffi.ffi_lowering(ffi_name)( @@ -1597,6 +1612,8 @@ def lowering( scaling_mode=scaling_mode.value, has_bias=has_bias, use_async_d2h_group_sizes=use_async_d2h_group_sizes, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, ) @staticmethod @@ -1621,6 +1638,9 @@ def impl( has_bias, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + rhs_group_axis, ): if GroupedGemmPrimitive.inner_primitive is None: raise RuntimeError("GroupedGemmPrimitive.inner_primitive has not been registered") @@ -1648,6 +1668,9 @@ def impl( has_bias=has_bias, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + rhs_group_axis=rhs_group_axis, ) return (out,) @@ -1959,27 +1982,9 @@ def _can_use_v2_grouped_gemm( return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias -def _flatten_to_2d(data, flatten_axis): - """Reshape *data* to 2D by splitting at *flatten_axis*. - - Positive flatten_axis: split before that axis index. - Negative flatten_axis: split before (ndim + flatten_axis). - """ - if data.ndim == 2: - return data # Already 2D, no reshape needed - fa = flatten_axis if flatten_axis >= 0 else data.ndim + flatten_axis - return data.reshape(math.prod(data.shape[:fa]), math.prod(data.shape[fa:])) - - def grouped_gemm( - lhs: Union[jnp.ndarray, GroupedScaledTensor1x], - rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - lhs_first_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed first dim varies, else None/(0,) - lhs_last_dims: jnp.ndarray = None, # (G,) int32 if LHS squashed last dim varies, else None/(0,) - rhs_first_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed first dim varies, else None/(0,) - rhs_last_dims: jnp.ndarray = None, # (G,) int32 if RHS squashed last dim varies, else None/(0,) - out_first_dims: jnp.ndarray = None, # (G,) int32 if output first dim varies, else None/(0,) - out_last_dims: jnp.ndarray = None, # (G,) int32 if output last dim varies, else None/(0,) + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -1992,14 +1997,8 @@ def grouped_gemm( Grouped GEMM operation. Args: - lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - lhs_first_dims: (G,) int32 if LHS squashed first dim varies per group, else None/(0,) - lhs_last_dims: (G,) int32 if LHS squashed last dim varies per group, else None/(0,) - rhs_first_dims: (G,) int32 if RHS squashed first dim varies per group (wgrad), else None/(0,) - rhs_last_dims: (G,) int32 if RHS squashed last dim varies per group, else None/(0,) - out_first_dims: (G,) int32 if output first dim varies per group, else None/(0,) - out_last_dims: (G,) int32 if output last dim varies per group, else None/(0,) + lhs: Left-hand side input matrix, GroupedNoScaleTensor or GroupedScaledTensor1x + rhs: Right-hand side input matrix, GroupedNoScaleTensor or GroupedScaledTensor1x contracting_dims: Tuple of two sequences representing the contracting dimensions bias: Bias tensor of shape (G, N) precision: JAX precision for the GEMM operation @@ -2009,60 +2008,76 @@ def grouped_gemm( Returns: A jnp.ndarray containing the result of the grouped GEMM operation - - Note: - Tested shapes: - lhs: [M, K] or [K, N] - rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ # TODO(Phuong): implement the precision del precision - # Replace None sentinels with empty (0,) int32 arrays. empty_gs = jnp.empty((0,), jnp.int32) - lhs_first_dims = empty_gs if lhs_first_dims is None else lhs_first_dims - lhs_last_dims = empty_gs if lhs_last_dims is None else lhs_last_dims - rhs_first_dims = empty_gs if rhs_first_dims is None else rhs_first_dims - rhs_last_dims = empty_gs if rhs_last_dims is None else rhs_last_dims - out_first_dims = empty_gs if out_first_dims is None else out_first_dims - out_last_dims = empty_gs if out_last_dims is None else out_last_dims - - if isinstance(lhs, jnp.ndarray): - if not isinstance(rhs, jnp.ndarray): - raise TypeError( - f"Expected rhs to be jnp.ndarray when lhs is jnp.ndarray, but got type={type(rhs)}" - ) - out_dtype = lhs.dtype - lhs_shape = lhs.shape - rhs_shape = rhs.shape - lhs_data = lhs - rhs_data = rhs - lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) + + # Extract data, dims, and metadata from tensor objects. + if isinstance(lhs, GroupedNoScaleTensor): + lhs_data = lhs.data + lhs_shape = lhs.original_shape + lhs_scale_inv = jnp.empty((0,), jnp.float32) scaling_mode = ScalingMode.NO_SCALING + out_dtype = lhs.data.dtype + lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs + lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs + rhs_group_axis = getattr(rhs, "group_axis", 0) elif isinstance(lhs, GroupedScaledTensor1x): - if not isinstance(rhs, GroupedScaledTensor1x): - raise TypeError( - "Expected rhs to be GroupedScaledTensor1x when lhs is GroupedScaledTensor1x, but" - f" got type={type(rhs)}" - ) - out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape - rhs_shape = rhs.original_shape - lhs_fa = lhs.flatten_axis - rhs_fa = rhs.flatten_axis - lhs_data = lhs.data.reshape(math.prod(lhs_shape[:lhs_fa]), math.prod(lhs_shape[lhs_fa:])) - rhs_data = rhs.data.reshape(math.prod(rhs_shape[:rhs_fa]), math.prod(rhs_shape[rhs_fa:])) + lhs_data = lhs.data.reshape(lhs_shape) lhs_scale_inv = lhs.scale_inv + scaling_mode = lhs.scaling_mode + out_dtype = lhs.dq_dtype + lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs + lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs + rhs_group_axis = getattr(rhs, "group_axis", 0) + else: + raise TypeError( + "lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " + f"got type={type(lhs)}" + ) + + if isinstance(rhs, GroupedNoScaleTensor): + rhs_data = rhs.data + rhs_shape = rhs.original_shape + rhs_scale_inv = jnp.empty((0,), jnp.float32) + rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs + rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs + elif isinstance(rhs, GroupedScaledTensor1x): + rhs_shape = rhs.original_shape + rhs_data = rhs.data.reshape(rhs_shape) rhs_scale_inv = rhs.scale_inv - if lhs.scaling_mode != rhs.scaling_mode: + rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs + rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs + if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode: raise ValueError( f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}," f" rhs.scaling_mode={rhs.scaling_mode}" ) - scaling_mode = lhs.scaling_mode + if isinstance(lhs, GroupedScaledTensor1x): + scaling_mode = lhs.scaling_mode else: - raise TypeError("Unsupported lhs type object!") + raise TypeError( + "rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " + f"got type={type(rhs)}" + ) + + # Infer output dims from which operand has the ragged non-contracting dim. + if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + # Wgrad: rhs contracting dim is ragged → output is uniform (G prefix from num_groups) + out_first_dims = empty_gs + out_last_dims = empty_gs + elif lhs_first_dims.size > 0: + out_first_dims = lhs_first_dims + out_last_dims = empty_gs + elif lhs_last_dims.size > 0: + out_first_dims = empty_gs + out_last_dims = lhs_last_dims + else: + out_first_dims = out_last_dims = empty_gs out_dtype = preferred_element_type or out_dtype @@ -2072,8 +2087,6 @@ def grouped_gemm( lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) # rhs_is_trans: K is the last dim of rhs (i.e., rhs is in "T" layout). - # This formula handles both standard rhs [G, K, N] (G-prefixed) and wgrad - # rhs [K_total, N] (no G prefix) without needing a separate wgrad override. rhs_is_trans = rhs_contract_dim[-1] == len(rhs_shape) - 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) @@ -2115,29 +2128,18 @@ def grouped_gemm( ), empty_gs, ) - lhs_q = grouped_quantize(lhs, quantizer_set.x, active_group_sizes, lhs_flatten_axis) + lhs_input_data = lhs.data if isinstance(lhs, GroupedNoScaleTensor) else lhs_data + rhs_input_data = rhs.data if isinstance(rhs, GroupedNoScaleTensor) else rhs_data + lhs_q = grouped_quantize(lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis) rhs_q = grouped_quantize( - rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis - ) - # grouped_quantize returns a 1D flat buffer; reshape to 2D using the - # original_shape and flatten_axis stored in each quantized tensor. - lhs_fa = lhs_q.flatten_axis # positive index (adjusted in create_1x) - rhs_fa = rhs_q.flatten_axis - lhs_data = lhs_q.data.reshape( - math.prod(lhs_q.original_shape[:lhs_fa]), - math.prod(lhs_q.original_shape[lhs_fa:]), - ) - rhs_data = rhs_q.data.reshape( - math.prod(rhs_q.original_shape[:rhs_fa]), - math.prod(rhs_q.original_shape[rhs_fa:]), + rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) + lhs_data = lhs_q.data.reshape(lhs_q.original_shape) + rhs_data = rhs_q.data.reshape(rhs_q.original_shape) lhs_scale_inv = lhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv lhs_shape = lhs_q.original_shape rhs_shape = rhs_q.original_shape - # Data is already 2D; reset flatten axes so _flatten_to_2d calls below are no-ops. - lhs_flatten_axis = -1 - rhs_flatten_axis = -1 if lhs_data.dtype == jnp.float8_e5m2 and rhs_data.dtype == jnp.float8_e5m2: raise ValueError("FP8 GEMM does not support E5M2 * E5M2") @@ -2174,9 +2176,9 @@ def grouped_gemm( else: rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - # Reshape inputs to 2D using the already-computed flatten_axes. - lhs_data_2d = _flatten_to_2d(lhs_data, lhs_flatten_axis) - rhs_data_2d = _flatten_to_2d(rhs_data, rhs_flatten_axis) + # Compute N-D axis boundaries from final (post-adjustment) contracting dims. + lhs_axis_boundary = get_lhs_axis_boundary(lhs_contract_dim, lhs_is_trans) + rhs_axis_boundary = get_rhs_axis_boundary(rhs_contract_dim, rhs_is_trans) num_gemms = ( lhs_first_dims.size @@ -2188,14 +2190,21 @@ def grouped_gemm( ) if num_gemms == 0: raise ValueError( - "grouped_gemm requires at least one non-empty dimension array " - "(lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, " - "out_first_dims, or out_last_dims)." + "grouped_gemm requires at least one non-empty dimension array. " + "Ensure lhs or rhs tensor objects carry first_dims or last_dims." ) has_bias = bias is not None if has_bias: - N_dim = rhs_data_2d.shape[0] // num_gemms if rhs_is_trans else rhs_data_2d.shape[1] + # Compute N from rhs non-contracting dims. + if rhs_is_trans: + N_dim = math.prod( + rhs_data.shape[d] + for d in range(rhs_axis_boundary) + if rhs_group_axis is None or d != rhs_group_axis + ) + else: + N_dim = math.prod(rhs_data.shape[rhs_axis_boundary:]) assert bias.shape == ( num_gemms, N_dim, @@ -2218,9 +2227,9 @@ def grouped_gemm( additional_arg_1 = jnp.zeros((0,), jnp.int32) # unused placeholder (out,) = GroupedGemmPrimitive.outer_primitive.bind( - lhs_data_2d, + lhs_data, lhs_scale_inv, - rhs_data_2d, + rhs_data, rhs_scale_inv, bias, lhs_first_dims, @@ -2238,5 +2247,8 @@ def grouped_gemm( has_bias=has_bias, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + rhs_group_axis=rhs_group_axis, ) return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index bf4e833c89..c8578d48b8 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1203,6 +1203,7 @@ def grouped_quantize( ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}" group_axis = 0 + ragged_first_dims = group_sizes # None if no explicit group_sizes (kernel case) if group_sizes is None: group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32) @@ -1280,7 +1281,7 @@ def grouped_quantize( q_layout=quantizer.q_layout, data_layout=quantizer.get_data_layout(), flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=ragged_first_dims, original_shape=original_shape, group_axis=group_axis, ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index bd429a7db6..616209709b 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -59,6 +59,8 @@ struct GroupedGemmV2Config { bool lhs_is_trans; bool rhs_is_trans; JAXX_Scaling_Mode scaling_mode; + int64_t lhs_axis_boundary; + int64_t rhs_axis_boundary; }; struct GroupedGemmConfig { @@ -67,6 +69,8 @@ struct GroupedGemmConfig { JAXX_Scaling_Mode scaling_mode; bool has_bias; bool use_async_d2h_group_sizes; + int64_t lhs_axis_boundary; + int64_t rhs_axis_boundary; }; inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } @@ -209,14 +213,18 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::GroupedGemmV2Config, ::xla::ffi::StructMember("lhs_is_trans"), ::xla::ffi::StructMember("rhs_is_trans"), - ::xla::ffi::StructMember("scaling_mode")); + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("lhs_axis_boundary"), + ::xla::ffi::StructMember("rhs_axis_boundary")); XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( transformer_engine::jax::GroupedGemmConfig, ::xla::ffi::StructMember("lhs_is_trans"), ::xla::ffi::StructMember("rhs_is_trans"), ::xla::ffi::StructMember("scaling_mode"), ::xla::ffi::StructMember("has_bias"), - ::xla::ffi::StructMember("use_async_d2h_group_sizes")); + ::xla::ffi::StructMember("use_async_d2h_group_sizes"), + ::xla::ffi::StructMember("lhs_axis_boundary"), + ::xla::ffi::StructMember("rhs_axis_boundary")); // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 2d73390d33..50bf43f349 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -627,13 +627,15 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, Buffer_Type const &last_dims, int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, - size_t num_gemms, cudaStream_t stream) { + size_t num_gemms, cudaStream_t stream, + int64_t axis_boundary = -1) { auto dims = data.dimensions(); NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); - // Flatten all leading dimensions into the first axis to produce a 2D NVTE shape. - // Input buffers (lhs, rhs) are already 2D from the Python side. Output buffers may be ND - // (e.g. [G, K, N] for wgrad), so we collapse dims[0..N-2] → rows and keep dims[N-1] → cols. - NVTEShape dataShape{.data = {product(dims, 0, dims.size() - 1), dims[dims.size() - 1]}, + // Flatten dims at axis_boundary to produce a 2D NVTE shape. + // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols, + // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad). + size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); + NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); @@ -698,7 +700,7 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Buffer_Type alpha, Buffer_Type beta, Result_Type output, Result_Type cublas_workspace, Result_Type setup_workspace, Result_Type int64_workspace, GroupedGemmV2Config config) { - auto [lhs_is_trans, rhs_is_trans, scaling_mode] = config; + auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary] = config; NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, "Only non-quantized grouped GEMM is supported in current implementation."); @@ -732,9 +734,11 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream); + int64_capacity, int64_offset, num_gemms, stream, + rhs_axis_boundary); auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream); + int64_capacity, int64_offset, num_gemms, stream, + lhs_axis_boundary); auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); @@ -777,7 +781,8 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type Buffer_Type out_first_dims, Buffer_Type out_last_dims, Buffer_Type group_offset, Result_Type output, Result_Type workspace, GroupedGemmConfig config) { - auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes] = config; + auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes, + lhs_axis_boundary, rhs_axis_boundary] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -827,20 +832,26 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type else if (is_rhs_last_ragged) active_gs_ptr = &rhs_last_dims; - // lhs_data and rhs_data are 2D; derive m, n, k from buffer dimensions. - NVTE_CHECK(lhs_data.dimensions().size() == 2, "lhs_data must be 2D."); - NVTE_CHECK(rhs_data.dimensions().size() == 2, "rhs_data must be 2D."); - size_t k = lhs_is_trans ? lhs_data.dimensions()[0] : lhs_data.dimensions()[1]; + // Derive m, n, k from N-D buffer dimensions using axis_boundary. + // axis_boundary splits contracting dims from non-contracting dims. + auto lhs_dims = lhs_data.dimensions(); + auto rhs_dims = rhs_data.dimensions(); + NVTE_CHECK(lhs_dims.size() >= 2, "lhs_data must be at least 2D."); + NVTE_CHECK(rhs_dims.size() >= 2, "rhs_data must be at least 2D."); + size_t lab = static_cast(lhs_axis_boundary); + size_t rab = static_cast(rhs_axis_boundary); + // k = product of contracting dims of lhs + size_t k = lhs_is_trans ? product(lhs_dims, 0, lab) : product(lhs_dims, lab, lhs_dims.size()); size_t m, n; if (is_rhs_ragged) { - // wgrad: lhs shape [K_lhs, M]: lhs_is_trans=True, contracting is dim[0]=K_lhs, output is dim[1]=M - m = lhs_is_trans ? lhs_data.dimensions()[1] : lhs_data.dimensions()[0]; - // T-layout rhs: (N, K_total) -> n = dim[0]; N-layout rhs: (K_total, N) -> n = dim[1] - n = rhs_is_trans ? rhs_data.dimensions()[0] : rhs_data.dimensions()[1]; + // wgrad: non-contracting lhs dims form M; non-contracting rhs dims form N + m = lhs_is_trans ? product(lhs_dims, lab, lhs_dims.size()) : product(lhs_dims, 0, lab); + n = rhs_is_trans ? product(rhs_dims, 0, rab) : product(rhs_dims, rab, rhs_dims.size()); } else { - m = lhs_is_trans ? lhs_data.dimensions()[1] - : lhs_data.dimensions()[0]; // total M (sum of group sizes) - n = rhs_is_trans ? rhs_data.dimensions()[0] / num_gemms : rhs_data.dimensions()[1]; + m = lhs_is_trans ? product(lhs_dims, lab, lhs_dims.size()) + : product(lhs_dims, 0, lab); // total M (sum of group sizes) + n = rhs_is_trans ? product(rhs_dims, 0, rab) / num_gemms + : product(rhs_dims, rab, rhs_dims.size()); } // Inputs diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index 8b397520f2..76c984486f 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -27,6 +27,7 @@ is_fp8_gemm_with_all_layouts_supported, TensorUsage, QuantizeLayout, + GroupedNoScaleTensor, ) @@ -490,7 +491,8 @@ def _grouped_dense_fwd_rule( is_colwise=False, data_layout="N", flatten_axis=ctx_kernel.flatten_axis, - group_sizes=ctx_kernel.group_sizes, + first_dims=ctx_kernel.first_dims, + last_dims=ctx_kernel.last_dims, original_shape=kernel_shape, group_axis=ctx_kernel.group_axis, ) @@ -507,7 +509,8 @@ def _grouped_dense_fwd_rule( is_colwise=True, data_layout="T", flatten_axis=ctx_kernel.flatten_axis, - group_sizes=ctx_kernel.group_sizes, + first_dims=ctx_kernel.first_dims, + last_dims=ctx_kernel.last_dims, original_shape=kernel_shape, group_axis=ctx_kernel.group_axis, ) @@ -518,16 +521,24 @@ def _grouped_dense_fwd_rule( # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout - empty_gs = jnp.empty((0,), jnp.int32) + if is_noop_quantizer_set: + grouped_gemm_x = GroupedNoScaleTensor( + data=grouped_gemm_x, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=grouped_gemm_x.shape, + ) + grouped_gemm_kernel = GroupedNoScaleTensor( + data=grouped_gemm_kernel, + first_dims=None, + last_dims=None, + group_axis=0, + original_shape=grouped_gemm_kernel.shape, + ) output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=empty_gs, - rhs_last_dims=empty_gs, - out_first_dims=group_sizes, - out_last_dims=empty_gs, contracting_dims=contracting_dims, bias=bias, precision=precision, @@ -616,16 +627,38 @@ def _grouped_dense_bwd_rule( wgrad_x_T = ctx_x wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) - empty_gs = jnp.empty((0,), jnp.int32) + if is_noop_quantizer_set: + dgrad_grad = GroupedNoScaleTensor( + data=dgrad_grad, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=dgrad_grad.shape, + ) + dgrad_kernel_T = GroupedNoScaleTensor( + data=dgrad_kernel_T, + first_dims=None, + last_dims=None, + group_axis=0, + original_shape=dgrad_kernel_T.shape, + ) + wgrad_x_T = GroupedNoScaleTensor( + data=wgrad_x_T, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=wgrad_x_T.shape, + ) + wgrad_grad = GroupedNoScaleTensor( + data=wgrad_grad, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=wgrad_grad.shape, + ) dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=empty_gs, - rhs_last_dims=empty_gs, - out_first_dims=group_sizes, - out_last_dims=empty_gs, contracting_dims=dgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, @@ -635,12 +668,6 @@ def _grouped_dense_bwd_rule( wgrad = tex.grouped_gemm( wgrad_x_T, wgrad_grad, - lhs_first_dims=group_sizes, - lhs_last_dims=empty_gs, - rhs_first_dims=group_sizes, - rhs_last_dims=empty_gs, - out_first_dims=empty_gs, - out_last_dims=empty_gs, contracting_dims=wgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 74787b9308..8fd54a3a63 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -275,7 +275,16 @@ def _grouped_dequantize(grouped_scaled_tensor): """ data = grouped_scaled_tensor.data scale_inv = grouped_scaled_tensor.scale_inv - group_sizes = grouped_scaled_tensor.group_sizes + group_sizes = ( + grouped_scaled_tensor.first_dims + if grouped_scaled_tensor.first_dims is not None and grouped_scaled_tensor.first_dims.size > 0 + else grouped_scaled_tensor.last_dims + ) + # For non-ragged groups (kernel case), group_sizes is not stored; derive from original_shape + if group_sizes is None: + group_sizes = jnp.ones( + grouped_scaled_tensor.original_shape[grouped_scaled_tensor.group_axis], dtype=jnp.int32 + ) flatten_axis = grouped_scaled_tensor.flatten_axis scaling_mode = grouped_scaled_tensor.scaling_mode original_shape = grouped_scaled_tensor.original_shape diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index f5ca6aeaed..55dd7f5618 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -948,7 +948,7 @@ def _create_grouped_tensor_from_tensor_list( is_colwise=tensor_list[0].is_colwise, data_layout=tensor_list[0].data_layout, flatten_axis=tensor_list[0].flatten_axis, - group_sizes=group_sizes, + first_dims=group_sizes, original_shape=original_shape, group_axis=group_axis, ) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c26cb8a531..38433e95ae 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -9,7 +9,7 @@ rowwise and colwise quantization modes with proper scaling and dequantization. """ from dataclasses import dataclass -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple from abc import ABC, abstractmethod import jax.numpy as jnp @@ -32,6 +32,7 @@ "ScaledTensor1x", "ScaledTensor2x", "GroupedScaledTensor1x", + "GroupedNoScaleTensor", "ScaledTensorFactory", "with_sharding_constraint_by_logical_axes", ] @@ -365,12 +366,14 @@ class GroupedScaledTensor1x(ScaledTensor1x): where elements are grouped along a specified axis. Attributes: - group_sizes: Array containing the size of each group + first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged + last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged original_shape: The original shape of the tensor before grouping group_axis: The axis along which grouping is performed (default: 0) """ - group_sizes: jnp.ndarray + first_dims: Optional[jnp.ndarray] + last_dims: Optional[jnp.ndarray] original_shape: Tuple group_axis: int @@ -379,7 +382,7 @@ def __init__( data, scale_inv, amax, - group_sizes, + first_dims, scaling_mode, dq_dtype, _dq_func, @@ -388,9 +391,11 @@ def __init__( flatten_axis, original_shape, group_axis=0, + last_dims=None, ): self.flatten_axis = flatten_axis - self.group_sizes = group_sizes + self.first_dims = first_dims + self.last_dims = last_dims self.original_shape = original_shape self.group_axis = group_axis # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 @@ -422,9 +427,19 @@ def __post_init__(self): 0 <= self.group_axis < data_ndim ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}" + active_dims = ( + self.first_dims + if self.first_dims is not None and self.first_dims.size > 0 + else self.last_dims + ) + if active_dims is not None: + num_groups = active_dims.size + else: + num_groups = self.original_shape[self.group_axis] + expected_scale_shape = self.scaling_mode.get_grouped_scale_shape( self.original_shape, - self.group_sizes.size, + num_groups, self.group_axis, self.is_colwise, is_padded=True, @@ -442,7 +457,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv, self.amax, self.group_sizes) + children = (self.data, self.scale_inv, self.amax, self.first_dims, self.last_dims) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -455,6 +470,36 @@ def tree_flatten(self): ) return (children, aux_data) + @classmethod + def tree_unflatten(cls, aux_data, children): + """Reconstructs the tensor from its flattened representation.""" + data, scale_inv, amax, first_dims, last_dims = children + ( + scaling_mode, + dq_dtype, + _dq_func, + is_colwise, + data_layout, + flatten_axis, + original_shape, + group_axis, + ) = aux_data + return cls( + data=data, + scale_inv=scale_inv, + amax=amax, + first_dims=first_dims, + last_dims=last_dims, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=_dq_func, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, + original_shape=original_shape, + group_axis=group_axis, + ) + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): raise NotImplementedError @@ -473,6 +518,52 @@ def checkpoint(self, quantizer): return jax_checkpoint_name(self, name=quantizer.checkpoint_name) +@register_pytree_node_class +@dataclass +class GroupedNoScaleTensor: + """Unquantized grouped tensor. + + Stores N-D data with per-group dimension sizes so that grouped_gemm() + can extract first/last dims automatically without explicit parameters. + + Attributes: + data: The raw (unquantized) tensor data in N-D layout + first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged + last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged + group_axis: Which axis of original_shape is the group batch prefix + original_shape: Shape of data (same as data.shape for N-D unquantized) + """ + + data: jnp.ndarray + first_dims: Optional[jnp.ndarray] + last_dims: Optional[jnp.ndarray] + group_axis: int + original_shape: Tuple + + def tree_flatten(self): + """Flattens the tensor for JAX tree operations.""" + children = (self.data, self.first_dims, self.last_dims) + aux_data = (self.group_axis, self.original_shape) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Reconstructs the tensor from its flattened representation.""" + group_axis, original_shape = aux_data + data, first_dims, last_dims = children + return cls( + data=data, + first_dims=first_dims, + last_dims=last_dims, + group_axis=group_axis, + original_shape=original_shape, + ) + + def dequantize(self): + """No-op dequantization — returns the raw data.""" + return self.data + + @register_pytree_node_class @dataclass class ScaledTensor2x(AbstractBaseTensor, ScaledTensor): @@ -570,7 +661,8 @@ def create_1x( is_colwise=False, data_layout="N", flatten_axis=-1, - group_sizes=None, + first_dims=None, + last_dims=None, original_shape=None, group_axis=0, has_rht_applied=False, @@ -586,29 +678,40 @@ def create_1x( is_colwise: Whether to use column-wise quantization (default: False) data_layout: The data_layout specification (default: "N") flatten_axis: The quantization axis for the tensor - group_sizes: Array of ints containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False) Returns: - A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided + A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether first_dims or last_dims is provided """ if amax is None: amax = jnp.empty((1,), dtype=jnp.float32) dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if group_sizes is not None: - flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) + if first_dims is not None or last_dims is not None or ( + original_shape is not None and group_axis is not None + ): assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" + flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) + + # Determine num_groups from whichever dims array is provided, or from original_shape + active_dims = first_dims if first_dims is not None and first_dims.size > 0 else last_dims + if active_dims is not None: + num_groups = active_dims.size + else: + norm_group_axis = (len(original_shape) + group_axis) % len(original_shape) + num_groups = original_shape[norm_group_axis] # Handling attrs of transposed tensors group_axis = (len(original_shape) + group_axis) % len(original_shape) if data_layout == "T": - if original_shape[0] == group_sizes.size: + if original_shape[0] == num_groups: original_shape = ( original_shape[0], *original_shape[flatten_axis:], @@ -633,7 +736,8 @@ def create_1x( is_colwise=is_colwise, data_layout=data_layout, flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, ) @@ -668,7 +772,8 @@ def create_2x( dq_dtype=jnp.bfloat16, data_layout="NN", flatten_axis=-1, - group_sizes=None, + first_dims=None, + last_dims=None, original_shape=None, group_axis=0, rowwise_has_rht_applied=False, @@ -686,7 +791,8 @@ def create_2x( dq_dtype: The data type for dequantized values (default: bfloat16) data_layout: The data_layout specification (default: "NN") flatten_axis: The quantization axis for the tensor - group_sizes: Array containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -710,7 +816,8 @@ def create_2x( is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied, @@ -724,7 +831,8 @@ def create_2x( is_colwise=True, data_layout=data_layout[1], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=colwise_has_rht_applied, @@ -744,7 +852,8 @@ def create( data_layout: str = "NN", q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, flatten_axis: int = -1, - group_sizes: jnp.ndarray = None, + first_dims: jnp.ndarray = None, + last_dims: jnp.ndarray = None, original_shape: Tuple[int] = None, group_axis: int = 0, rowwise_has_rht_applied: bool = False, @@ -762,7 +871,8 @@ def create( data_layout: The data_layout specification (default: "NN") q_layout: The quantization axis (default: ROWWISE) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) - group_sizes: Array containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -785,7 +895,8 @@ def create( dq_dtype, data_layout=data_layout, flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, rowwise_has_rht_applied=rowwise_has_rht_applied, @@ -802,7 +913,8 @@ def create( is_colwise=True, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=colwise_has_rht_applied, @@ -817,7 +929,8 @@ def create( is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied, From 1d1fec90cfd3e05e195010512652a8c78e72c19e Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 13 Mar 2026 08:42:34 -0700 Subject: [PATCH 14/23] MXFP8 grouped quantize V2 Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/quantization.py | 46 ++++++++++++++++--- transformer_engine/jax/csrc/extensions.h | 2 + .../jax/csrc/extensions/pybind.cpp | 1 + 3 files changed, 43 insertions(+), 6 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index c8578d48b8..f7c0e796cc 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -993,7 +993,8 @@ class GroupedQuantizePrimitive(BasePrimitive): Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ - name = "te_grouped_quantize_ffi" + name = "te_grouped_quantize_ffi" # V1: non-MXFP8 + name_v2 = "te_grouped_quantize_v2_ffi" # V2: MXFP8, CUDA-graph safe multiple_results = True impl_static_args = ( 3, @@ -1050,7 +1051,17 @@ def abstract( rowwise_scale_inv_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING + if is_mxfp8: + # V2 path: 5th output is int64_workspace (n_groups * sizeof(int64_t) bytes as uint8) + fifth_out_aval = jax.core.ShapedArray( + shape=(group_sizes_aval.size * 8,), dtype=jnp.uint8 + ) + else: + # V1 path: 5th output is amax + fifth_out_aval = jax.core.ShapedArray( + shape=(group_sizes_aval.size,), dtype=jnp.float32 + ) if q_layout.has_colwise: colwise_out_shape = out_shape @@ -1070,7 +1081,7 @@ def abstract( colwise_out_aval, rowwise_scale_inv_aval, colwise_scale_inv_aval, - amax_aval, + fifth_out_aval, ) @staticmethod @@ -1084,9 +1095,15 @@ def outer_abstract(*args, **kwargs): colwise_out, scale_inv, colwise_scale_inv, - updated_amax, + fifth_out, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) - return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax + # For MXFP8, the inner abstract returns int64_workspace as the 5th output. + # The outer interface always presents amax (float32, n_groups) for a consistent API. + scaling_mode = kwargs.get("scaling_mode") + group_sizes_aval = args[2] + if ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING: + fifth_out = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, fifth_out @staticmethod def lowering( @@ -1111,6 +1128,17 @@ def lowering( assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 assert group_axis == 0 + is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING + if is_mxfp8: + # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler + return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2)( + ctx, + x, + scale, + group_sizes, + q_layout=q_layout.value.value, + flatten_axis=flatten_axis, + ) return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( ctx, x, @@ -1142,7 +1170,7 @@ def impl( colwise_out, rowwise_scale_inv, colwise_scale_inv, - updated_amax, + fifth, ) = GroupedQuantizePrimitive.inner_primitive.bind( x, scale, @@ -1154,6 +1182,12 @@ def impl( group_axis=group_axis, scale_dtype=scale_dtype, ) + is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING + if is_mxfp8: + # fifth is int64_workspace; return a dummy zero amax for interface compatibility + updated_amax = jnp.zeros((group_sizes.size,), jnp.float32) + else: + updated_amax = fifth return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 616209709b..c832b4ebb2 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -111,6 +111,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeV2Handler); + XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 28cb39b5d1..e3bc122403 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -33,6 +33,7 @@ pybind11::dict Registrations() { // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler); + dict["te_grouped_quantize_v2_ffi"] = EncapsulateFFI(GroupedQuantizeV2Handler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax From 269a5186715b3f63aa21df7912080da3ca0284d2 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Fri, 13 Mar 2026 15:46:34 +0000 Subject: [PATCH 15/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- transformer_engine/jax/cpp_extensions/gemm.py | 11 ++++---- .../jax/cpp_extensions/quantization.py | 6 ++--- .../jax/csrc/extensions/gemm.cpp | 26 ++++++++----------- .../jax/quantize/dequantizer.py | 3 ++- transformer_engine/jax/quantize/tensor.py | 10 ++++--- 5 files changed, 27 insertions(+), 29 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index ff9194bdd9..c86cb1db55 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -1332,7 +1332,6 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) - def _assert_grouped_gemm_dims_shapes( lhs_first_dims_aval, lhs_last_dims_aval, @@ -2036,8 +2035,7 @@ def grouped_gemm( rhs_group_axis = getattr(rhs, "group_axis", 0) else: raise TypeError( - "lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " - f"got type={type(lhs)}" + f"lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" ) if isinstance(rhs, GroupedNoScaleTensor): @@ -2061,8 +2059,7 @@ def grouped_gemm( scaling_mode = lhs.scaling_mode else: raise TypeError( - "rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, " - f"got type={type(rhs)}" + f"rhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" ) # Infer output dims from which operand has the ragged non-contracting dim. @@ -2130,7 +2127,9 @@ def grouped_gemm( ) lhs_input_data = lhs.data if isinstance(lhs, GroupedNoScaleTensor) else lhs_data rhs_input_data = rhs.data if isinstance(rhs, GroupedNoScaleTensor) else rhs_data - lhs_q = grouped_quantize(lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis) + lhs_q = grouped_quantize( + lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis + ) rhs_q = grouped_quantize( rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index f7c0e796cc..cb506160bf 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -993,7 +993,7 @@ class GroupedQuantizePrimitive(BasePrimitive): Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ - name = "te_grouped_quantize_ffi" # V1: non-MXFP8 + name = "te_grouped_quantize_ffi" # V1: non-MXFP8 name_v2 = "te_grouped_quantize_v2_ffi" # V2: MXFP8, CUDA-graph safe multiple_results = True impl_static_args = ( @@ -1059,9 +1059,7 @@ def abstract( ) else: # V1 path: 5th output is amax - fifth_out_aval = jax.core.ShapedArray( - shape=(group_sizes_aval.size,), dtype=jnp.float32 - ) + fifth_out_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) if q_layout.has_colwise: colwise_out_shape = out_shape diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 50bf43f349..07adf55577 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -622,21 +622,17 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, // int64_offset (in int64 elements) is updated on return to the next available slot so callers can // thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked // before each slot is used. Only NO_SCALING is supported. -JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, - Buffer_Type const &first_dims, - Buffer_Type const &last_dims, - int64_t *int64_workspace_base, - size_t int64_workspace_capacity, size_t &int64_offset, - size_t num_gemms, cudaStream_t stream, - int64_t axis_boundary = -1) { +JAXX_GroupedTensorWrapper make_grouped_tensor( + Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, + int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { auto dims = data.dimensions(); NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); // Flatten dims at axis_boundary to produce a 2D NVTE shape. // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols, // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad). size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); - NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, - .ndim = 2}; + NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); wrapper.set_rowwise(data, std::nullopt); if (first_dims.element_count() > 0) { @@ -733,12 +729,12 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream, - rhs_axis_boundary); - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, - int64_capacity, int64_offset, num_gemms, stream, - lhs_axis_boundary); + auto rhs_tensor = + make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary); + auto lhs_tensor = + make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary); auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 8fd54a3a63..5075f1a664 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -277,7 +277,8 @@ def _grouped_dequantize(grouped_scaled_tensor): scale_inv = grouped_scaled_tensor.scale_inv group_sizes = ( grouped_scaled_tensor.first_dims - if grouped_scaled_tensor.first_dims is not None and grouped_scaled_tensor.first_dims.size > 0 + if grouped_scaled_tensor.first_dims is not None + and grouped_scaled_tensor.first_dims.size > 0 else grouped_scaled_tensor.last_dims ) # For non-ragged groups (kernel case), group_sizes is not stored; derive from original_shape diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 38433e95ae..316e4f3139 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -692,8 +692,10 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if first_dims is not None or last_dims is not None or ( - original_shape is not None and group_axis is not None + if ( + first_dims is not None + or last_dims is not None + or (original_shape is not None and group_axis is not None) ): assert ( original_shape is not None @@ -701,7 +703,9 @@ def create_1x( flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) # Determine num_groups from whichever dims array is provided, or from original_shape - active_dims = first_dims if first_dims is not None and first_dims.size > 0 else last_dims + active_dims = ( + first_dims if first_dims is not None and first_dims.size > 0 else last_dims + ) if active_dims is not None: num_groups = active_dims.size else: From b2b3216fcb5f188b01bc36ca89e9c4be381f6093 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Fri, 13 Mar 2026 19:02:35 -0700 Subject: [PATCH 16/23] MXFP8 quantization working Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 148 +++++++++++++++- .../cast/mxfp8/group_quantize_mxfp8.cuh | 4 + .../common/gemm/cublaslt_grouped_gemm.cu | 39 +++++ .../common/include/transformer_engine/gemm.h | 30 ++++ .../jax/cpp_extensions/quantization.py | 65 +++++-- .../jax/csrc/extensions/gemm.cpp | 16 +- .../jax/csrc/extensions/quantization.cpp | 164 ++++++++++++++++++ transformer_engine/jax/flax/module.py | 9 +- transformer_engine/jax/quantize/tensor.py | 22 ++- 9 files changed, 464 insertions(+), 33 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 9fddbc435c..db44621d80 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1737,7 +1737,9 @@ def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): ref_out.append(jnp.squeeze(out_i)) return ref_out - def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): + def _generate_grouped_dense_input( + self, dtype, input_shape, data_layout="NN", with_bias=False, group_size_multiplier=32 + ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 4) n_groups, m, n, k = input_shape @@ -1750,9 +1752,12 @@ def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi group_sizes = group_sizes.at[1].set(0) assert group_sizes.sum() == m - # *32 to make sure that input shape works for MXFP8 - group_sizes = group_sizes * 32 - m = m * 32 + # Scale group sizes by the multiplier. + # Use group_size_multiplier=128 for MXFP8 V2 tests so that each group's row count + # is divisible by 128, satisfying the V2 kernel's per-group alignment requirement. + # Use group_size_multiplier=32 for V1 tests or non-MXFP8 tests. + group_sizes = group_sizes * group_size_multiplier + m = m * group_size_multiplier lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) @@ -1826,8 +1831,10 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout quantizer.q_dtype = bwd_dtype out_dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. + is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( - out_dtype, input_shape, layout + out_dtype, input_shape, layout, group_size_multiplier=128 if is_mxfp8 else 32 ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) @@ -1901,10 +1908,13 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. + is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( dtype, input_shape, with_bias=True, + group_size_multiplier=128 if is_mxfp8 else 32, ) quantizer_set = QuantizerFactory.create_set( @@ -1938,6 +1948,134 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_dbias, ref_dbias, dtype=dtype) +# MXFP8 V1 shapes: lhs total_rows = m * 32 and rhs total_rows = n_groups * k are +# NOT divisible by 128, forcing the V1 (non-CUDA-graph-safe) kernel. +GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES = [ + # (n_groups, m, n, k) + # lhs total_rows = m * 32; rhs total_rows = n_groups * k + (5, 6, 128, 64), # lhs: 6*32=192 (not 128-aligned); rhs: 5*64=320 (not 128-aligned) +] + +# MXFP8 V2 shapes: lhs total_rows = m * 128 and rhs total_rows = n_groups * k are +# divisible by 128, allowing the V2 (CUDA-graph-safe) kernel to be used. +# These shapes must be paired with group_size_multiplier=128 so that each group's +# row count is also divisible by 128 (the V2 per-group alignment requirement). +GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES = [ + # (n_groups, m, n, k) + # lhs total_rows = m * 128; rhs total_rows = n_groups * k + (8, 8, 128, 128), # lhs: 8*128=1024 (128-aligned); rhs: 8*128=1024 (128-aligned) + (4, 4, 64, 256), # lhs: 4*128=512 (128-aligned); rhs: 4*256=1024 (128-aligned) +] + + +@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) +class TestGroupedDenseMXFP8KernelSelection: + """Tests that explicitly verify V1 and V2 MXFP8 grouped quantize kernel selection. + + V2 is the CUDA-graph-safe kernel and requires: + - total_first_dim (= product of input shape up to flatten_axis) % 128 == 0 + - each individual group_size % 128 == 0 (enforced by the kernel at runtime) + V1 is the fallback that supports arbitrary shapes but performs a D2H copy of + group_sizes (not CUDA-graph safe). + """ + + def _generate_mxfp8_input(self, input_shape, group_size_multiplier): + """Generate inputs with the given group_size_multiplier for MXFP8 tests.""" + key = jax.random.PRNGKey(42) + subkeys = jax.random.split(key, 3) + n_groups, m, n, k = input_shape + + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) + group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) + group_sizes = group_sizes.at[1].set(0) + group_sizes = group_sizes * group_size_multiplier + m_total = m * group_size_multiplier + + lhs = jax.random.uniform(subkeys[1], (m_total, k), dtype=jnp.bfloat16) + rhs = jax.random.uniform(subkeys[2], (n_groups, k, n), dtype=jnp.bfloat16) + return lhs, rhs, group_sizes + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES, + ids=[f"v1_{s}" for s in GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES], + ) + def test_grouped_gemm_mxfp8_v1_shapes(self, input_shape): + """MXFP8 grouped GEMM with V1-only shapes (total_first_dim not 128-aligned).""" + lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=32) + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=False, + n_groups=input_shape[0], + ) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) + # Reference: unquantized grouped GEMM + n_groups = input_shape[0] + lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(rhs, n_groups, axis=0) + ref_out = jnp.concatenate( + [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], + axis=0, + ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs_tensor, + rhs_tensor, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + # Check output has correct shape and dtype; numerical precision is expected to be lower + # due to FP8 quantization but the result should be finite. + assert prim_out.shape == ref_out.shape + assert jnp.all(jnp.isfinite(prim_out)) + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES, + ids=[f"v2_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES], + ) + def test_grouped_gemm_mxfp8_v2_shapes(self, input_shape): + """MXFP8 grouped GEMM with V2-eligible shapes (total_first_dim 128-aligned, + group_sizes also 128-aligned).""" + lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128) + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=False, + n_groups=input_shape[0], + ) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) + n_groups = input_shape[0] + lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(rhs, n_groups, axis=0) + ref_out = jnp.concatenate( + [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], + axis=0, + ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs_tensor, + rhs_tensor, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + assert prim_out.shape == ref_out.shape + assert jnp.all(jnp.isfinite(prim_out)) + + class TestDebugInspectFFI: @pytest_parametrize_wrapper("shape", [(256, 128)]) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 129d6724ac..d7eaf028e0 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh @@ -192,6 +192,10 @@ __global__ void update_tma_descriptors( const size_t offset_elts = offsets_ptr[tensor_id]; if (leading_thread && (tensor_id < num_tensors)) { + // Zero-sized groups: skip TMA descriptor update. The main kernel already returns + // early for rows==0 or cols==0, but creating a TMA descriptor with a zero dimension + // is invalid and causes CUDA_ERROR_ILLEGAL_ADDRESS. + if (rows == 0 || cols == 0) return; { const uintptr_t global_data_ptr = reinterpret_cast(input_data_ptr + offset_elts); modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id], diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index dc4757ab90..01f5361481 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -682,6 +682,24 @@ __global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, if (idx < n) dst[idx] = static_cast(src[idx]); } +// Like convert_int32_to_int64_kernel but scales each element by multiplier. +// Used to convert per-expert slice counts to per-expert row counts for multi-dim tensors. +__global__ void convert_int32_to_int64_with_multiplier_kernel(const int32_t *src, int64_t *dst, + size_t n, int64_t multiplier) { + size_t idx = blockIdx.x * blockDim.x + threadIdx.x; + if (idx < n) dst[idx] = static_cast(src[idx]) * multiplier; +} + +// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim). +// Produces n_groups+1 values. Single-threaded sequential scan; n_groups is typically small. +__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim) { + offsets[0] = 0; + for (size_t i = 0; i < n_groups; i++) { + offsets[i + 1] = offsets[i] + first_dims[i] * last_dim; + } +} + } // namespace void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) { @@ -692,3 +710,24 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud convert_int32_to_int64_kernel<<>>(src, dst, n); NVTE_CHECK_CUDA(cudaGetLastError()); } + +void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, + int64_t multiplier, cudaStream_t stream) { + NVTE_API_CALL(nvte_convert_int32_to_int64_with_multiplier); + if (n == 0) return; + const int threads = 256; + const int blocks = static_cast((n + threads - 1) / threads); + convert_int32_to_int64_with_multiplier_kernel<<>>(src, dst, n, + multiplier); + NVTE_CHECK_CUDA(cudaGetLastError()); +} + +void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim, + cudaStream_t stream) { + NVTE_API_CALL(nvte_compute_grouped_tensor_offsets); + // Always write at least offsets[0]=0 (needed even for n_groups==0). + compute_grouped_tensor_offsets_kernel<<<1, 1, 0, stream>>>(first_dims, offsets, n_groups, + last_dim); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 0f3b0ebd6b..28490653bc 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -354,6 +354,36 @@ size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors); */ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream); +/*! \brief Convert int32 array to int64 while scaling each element by a multiplier. + * + * Computes dst[i] = (int64_t)src[i] * multiplier for each i in [0, n). + * CUDA-graph safe (no host-device synchronization). + * + * \param[in] src Device pointer to source int32 array. + * \param[out] dst Device pointer to destination int64 array. + * \param[in] n Number of elements. + * \param[in] multiplier Scale factor applied to each element. + * \param[in] stream CUDA stream. + */ +void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, + int64_t multiplier, cudaStream_t stream); + +/*! \brief Compute exclusive prefix-sum offsets from per-group first-dimension sizes. + * + * Writes n_groups+1 values to offsets: offsets[0]=0, + * offsets[i] = sum(first_dims[0..i-1] * last_dim) for i in [1, n_groups]. + * This is CUDA-graph safe (no host-device synchronization). + * + * \param[in] first_dims Device pointer to int64 array of length n_groups. + * \param[out] offsets Device pointer to int64 array of length n_groups+1. + * \param[in] n_groups Number of groups. + * \param[in] last_dim Common last dimension (number of columns). + * \param[in] stream CUDA stream. + */ +void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim, + cudaStream_t stream); + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index cb506160bf..02639cfa2d 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -993,7 +993,7 @@ class GroupedQuantizePrimitive(BasePrimitive): Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ - name = "te_grouped_quantize_ffi" # V1: non-MXFP8 + name = "te_grouped_quantize_ffi" # V1: fallback path (supports all shapes, not CUDA-graph safe) name_v2 = "te_grouped_quantize_v2_ffi" # V2: MXFP8, CUDA-graph safe multiple_results = True impl_static_args = ( @@ -1007,6 +1007,38 @@ class GroupedQuantizePrimitive(BasePrimitive): inner_primitive = None outer_primitive = None + @staticmethod + def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): + """Return True when the V2 (CUDA-graph-safe) MXFP8 kernel can be used. + + V2 requires: + 1. The total first logical dimension (product of x_shape up to flatten_axis) + is divisible by 128. + 2. For multi-dim group tensors (eff > 1, e.g., kernel shape G×K×N), the + per-group row count non_group_m = prod(x_shape[1:eff]) must also be + divisible by 128 (because group_sizes[i] counts slices, not rows, and + actual rows per group = group_sizes[i] * non_group_m). + 3. For lhs-style tensors (eff == 1, shape M×K), individual group sizes must + be 128-aligned -- this is a dynamic constraint assumed by the caller. + + Falls back to V1 when constraints are not met. V1 supports arbitrary shapes + but performs a D2H copy of group_sizes (not CUDA-graph safe). + """ + if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: + return False + ndim = len(x_shape) + eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim + total_first_dim = math.prod(x_shape[:eff]) + if total_first_dim % 128 != 0: + return False + # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), + # non_group_m = K must also be 128-aligned. + if eff > 1: + non_group_m = math.prod(x_shape[1:eff]) + if non_group_m % 128 != 0: + return False + return True + @staticmethod def abstract( x_aval, @@ -1051,11 +1083,14 @@ def abstract( rowwise_scale_inv_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING - if is_mxfp8: - # V2 path: 5th output is int64_workspace (n_groups * sizeof(int64_t) bytes as uint8) + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2 path: 5th output is int64_workspace laid out as: + # [n_groups int64 group_sizes | n_groups+1 int64 offsets] + # = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8. + n_groups = group_sizes_aval.size fifth_out_aval = jax.core.ShapedArray( - shape=(group_sizes_aval.size * 8,), dtype=jnp.uint8 + shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8 ) else: # V1 path: 5th output is amax @@ -1095,11 +1130,13 @@ def outer_abstract(*args, **kwargs): colwise_scale_inv, fifth_out, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) - # For MXFP8, the inner abstract returns int64_workspace as the 5th output. + # When V2 is used, the inner abstract returns int64_workspace as the 5th output. # The outer interface always presents amax (float32, n_groups) for a consistent API. scaling_mode = kwargs.get("scaling_mode") + x_aval = args[0] group_sizes_aval = args[2] - if ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING: + flatten_axis = kwargs.get("flatten_axis") + if GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis): fifth_out = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, fifth_out @@ -1126,9 +1163,11 @@ def lowering( assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 assert group_axis == 0 - is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING - if is_mxfp8: - # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler. + # Requires total_first_dim % 128 == 0 (checked above) and all individual + # group sizes % 128 == 0 (dynamic constraint, enforced by the kernel). return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2)( ctx, x, @@ -1137,6 +1176,8 @@ def lowering( q_layout=q_layout.value.value, flatten_axis=flatten_axis, ) + # V1: supports arbitrary shapes but not CUDA-graph safe (performs D2H copy of group_sizes). + # Used for non-MXFP8 scaling modes and for MXFP8 when total_first_dim % 128 != 0. return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( ctx, x, @@ -1180,8 +1221,8 @@ def impl( group_axis=group_axis, scale_dtype=scale_dtype, ) - is_mxfp8 = ScalingMode(scaling_mode) == ScalingMode.MXFP8_1D_SCALING - if is_mxfp8: + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x.shape, flatten_axis) + if use_v2: # fifth is int64_workspace; return a dummy zero amax for interface compatibility updated_amax = jnp.zeros((group_sizes.size,), jnp.float32) else: diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 07adf55577..495d4cc4bb 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -965,14 +965,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type // Note: This may break cudaGraph. cudaStreamSynchronize(stream); } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_rhs_ragged) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); - } + // size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + // if (!is_rhs_ragged) { + // NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + // ", got sum(group_sizes)=", sum_group_sizes); + // } else { + // NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + // ", got sum(group_sizes)=", sum_group_sizes); + // } } auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index c5a766f7f2..fb083a958f 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -9,6 +9,7 @@ #include "../extensions.h" #include "transformer_engine/cast.h" +#include "transformer_engine/gemm.h" #include "transformer_engine/hadamard_transform.h" #include "transformer_engine/recipe.h" #include "transformer_engine/transformer_engine.h" @@ -494,5 +495,168 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Attr("q_layout") .Attr("flatten_axis")); +Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, + Buffer_Type scale_unused, Buffer_Type group_sizes, + Result_Type rowwise_out, Result_Type colwise_out, + Result_Type rowwise_sinv, Result_Type colwise_sinv, + Result_Type int64_workspace, + JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { + (void)scale_unused; // scale is unused for MXFP8; accepted to match V1 input arity + auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(rowwise_out->element_type()); + auto sinv_dtype = convert_ffi_datatype_to_te_dtype(rowwise_sinv->element_type()); + + NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for GroupedQuantizeV2."); + NVTE_CHECK(sinv_dtype == DType::kFloat8E8M0, + "scale_inv must be E8M0 for MXFP8 grouped quantize."); + + auto input_dims = inputs.dimensions(); + int64_t input_ndim = input_dims.size(); + if (flatten_axis < 0) flatten_axis += input_ndim; + NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!"); + + auto m = product(input_dims, 0, flatten_axis); + auto n = product(input_dims, flatten_axis, input_ndim); + size_t n_groups = group_sizes.dimensions()[0]; + + // Workspace layout (CUDA-graph safe, all device-side): + // int64_ptr[0 .. n_groups-1] : per-group ROW counts (int64) + // int64_ptr[n_groups .. 2*n_groups] : exclusive prefix-sum offsets (n_groups+1 values) + auto *int64_ptr = reinterpret_cast(int64_workspace->untyped_data()); + auto *offsets_ptr_out = int64_ptr + n_groups; // n_groups+1 values follow group_sizes + + // non_group_m handles multi-dim tensors (e.g., kernel shape G×K×N with flatten_axis=2): + // group_sizes[i] counts "slices" along the outermost group axis (e.g., 1 per expert), + // while the kernel expects actual ROW counts (e.g., K rows per expert). + // non_group_m = product(input_dims[1..flatten_axis)) converts slice→row count. + // For the lhs case (shape M×K, flatten_axis=1), non_group_m=1 (no-op). + int64_t non_group_m = + (flatten_axis > 1) ? product(input_dims, 1, static_cast(flatten_axis)) : 1; + + // Convert int32 group_sizes to int64 row counts on device (CUDA-graph safe, no D2H). + nvte_convert_int32_to_int64_with_multiplier( + reinterpret_cast(group_sizes.untyped_data()), int64_ptr, n_groups, + non_group_m, stream); + + // Compute exclusive prefix-sum offsets on device (CUDA-graph safe, no D2H). + nvte_compute_grouped_tensor_offsets(int64_ptr, offsets_ptr_out, n_groups, + static_cast(n), stream); + + NVTEShape data_shape{}; + data_shape.data[0] = m; + data_shape.data[1] = n; + data_shape.ndim = 2; + + NVTEShape sz_shape{}; + sz_shape.ndim = 1; + sz_shape.data[0] = n_groups; + + // Offsets tensor has n_groups+1 elements (exclusive prefix sums with sentinel). + NVTEShape offsets_shape{}; + offsets_shape.ndim = 1; + offsets_shape.data[0] = n_groups + 1; + + // Build input grouped tensor (plain float data, no quantization on the input side). + NVTEGroupedTensor in_grouped = + nvte_create_grouped_tensor(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING), + n_groups, data_shape); + { + NVTEBasicTensor in_data{reinterpret_cast(inputs.untyped_data()), + static_cast(in_dtype), data_shape}; + nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedRowwiseData, &in_data, sizeof(in_data)); + NVTEBasicTensor sz_tensor{reinterpret_cast(int64_ptr), NVTEDType::kNVTEInt64, + sz_shape}; + nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedFirstDims, &sz_tensor, sizeof(sz_tensor)); + NVTEBasicTensor offsets_tensor{reinterpret_cast(offsets_ptr_out), + NVTEDType::kNVTEInt64, offsets_shape}; + nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedTensorOffsets, &offsets_tensor, + sizeof(offsets_tensor)); + } + + // Build output grouped tensor. + NVTEGroupedTensor out_grouped = + nvte_create_grouped_tensor(get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING), + n_groups, data_shape); + + // Set group sizes and offsets on output tensor (same device pointers). + { + NVTEBasicTensor sz_tensor{reinterpret_cast(int64_ptr), NVTEDType::kNVTEInt64, + sz_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedFirstDims, &sz_tensor, + sizeof(sz_tensor)); + NVTEBasicTensor offsets_tensor{reinterpret_cast(offsets_ptr_out), + NVTEDType::kNVTEInt64, offsets_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedTensorOffsets, &offsets_tensor, + sizeof(offsets_tensor)); + } + + // Rowwise output data + scale_inv. + if (is_quantize_rowwise(quantize_layout)) { + NVTEBasicTensor rw_data{reinterpret_cast(rowwise_out->untyped_data()), + static_cast(out_dtype), data_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseData, &rw_data, sizeof(rw_data)); + + auto sinv_dims = rowwise_sinv->dimensions(); + NVTEShape rw_sinv_shape{}; + rw_sinv_shape.ndim = 2; + rw_sinv_shape.data[0] = product(sinv_dims, 0, sinv_dims.size() - 1); + rw_sinv_shape.data[1] = sinv_dims.back(); + NVTEBasicTensor rw_sinv{reinterpret_cast(rowwise_sinv->untyped_data()), + static_cast(sinv_dtype), rw_sinv_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseScaleInv, &rw_sinv, + sizeof(rw_sinv)); + } + + // Colwise output data + scale_inv. + if (is_quantize_colwise(quantize_layout)) { + NVTEBasicTensor cw_data{reinterpret_cast(colwise_out->untyped_data()), + static_cast(out_dtype), data_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseData, &cw_data, + sizeof(cw_data)); + + auto cw_sinv_dims = colwise_sinv->dimensions(); + NVTEShape cw_sinv_shape{}; + cw_sinv_shape.ndim = 2; + cw_sinv_shape.data[0] = product(cw_sinv_dims, 0, cw_sinv_dims.size() - 1); + cw_sinv_shape.data[1] = cw_sinv_dims.back(); + NVTEBasicTensor cw_sinv{reinterpret_cast(colwise_sinv->untyped_data()), + static_cast(sinv_dtype), cw_sinv_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseScaleInv, &cw_sinv, + sizeof(cw_sinv)); + } + + // Zero-initialize scale_inv buffers (mirrors V1 behaviour for MXFP8). + size_t total_rowwise_sinv_size = + is_quantize_rowwise(quantize_layout) ? product(rowwise_sinv->dimensions()) : 0; + size_t total_colwise_sinv_size = + is_quantize_colwise(quantize_layout) ? product(colwise_sinv->dimensions()) : 0; + if (total_rowwise_sinv_size > 0) + nvte_memset(rowwise_sinv->untyped_data(), 0, total_rowwise_sinv_size, stream); + if (total_colwise_sinv_size > 0) + nvte_memset(colwise_sinv->untyped_data(), 0, total_colwise_sinv_size, stream); + + nvte_group_quantize(in_grouped, out_grouped, stream); + + nvte_destroy_grouped_tensor(in_grouped); + nvte_destroy_grouped_tensor(out_grouped); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeV2Handler, GroupedQuantizeV2FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // inputs + .Arg() // scale (unused, for input arity match) + .Arg() // group_sizes (int32) + .Ret() // rowwise_out + .Ret() // colwise_out + .Ret() // rowwise_sinv + .Ret() // colwise_sinv + .Ret() // int64_workspace + .Attr("q_layout") + .Attr("flatten_axis"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 31ce6e72e9..4f19f449bc 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -16,6 +16,9 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name +from transformer_engine.common.recipe import ( + MXFP8BlockScaling, +) from ..dense import dense, grouped_dense @@ -1446,7 +1449,11 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): def make_grouped_dense_cls(quantization_recipe): """Creates a grouped dense (grouped GEMM) instance for use with TE state module.""" if quantization_recipe is not None: - raise ValueError("Ragged dot grouped GEMM does not support quantization yet") + allowed_grouped_gemm_recipes = [MXFP8BlockScaling] + assert any(isinstance(quantization_recipe, r) for r in allowed_grouped_gemm_recipes), ( + f"Only the following quantization recipes are supported for grouped GEMM or `None` for BF16 without quantization: {allowed_grouped_gemm_recipes}. " + f"Got {type(quantization_recipe)}." + ) def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): del kwargs # Unused diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 316e4f3139..4511b63a64 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -412,6 +412,18 @@ def __init__( has_rht_applied=False, ) + @property + def group_sizes(self) -> jnp.ndarray: + """Per-group sizes along the group axis. + + When first_dims is set (ragged groups), returns first_dims. + When first_dims is None (equal-sized groups), returns an array of ones with + length equal to the number of groups. + """ + if self.first_dims is not None and self.first_dims.size > 0: + return self.first_dims + return jnp.ones((self.original_shape[self.group_axis],), dtype=jnp.int32) + def __post_init__(self): assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" assert self.data.ndim == 1, "Only support flattened data" @@ -692,10 +704,8 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if ( - first_dims is not None - or last_dims is not None - or (original_shape is not None and group_axis is not None) + if first_dims is not None or last_dims is not None or ( + original_shape is not None and group_axis is not None ): assert ( original_shape is not None @@ -703,9 +713,7 @@ def create_1x( flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) # Determine num_groups from whichever dims array is provided, or from original_shape - active_dims = ( - first_dims if first_dims is not None and first_dims.size > 0 else last_dims - ) + active_dims = first_dims if first_dims is not None and first_dims.size > 0 else last_dims if active_dims is not None: num_groups = active_dims.size else: From 611526ff545aadb532c714c0f713d0ab6a9ccd6d Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 00:00:02 -0700 Subject: [PATCH 17/23] mxfp8 grouped gemm Signed-off-by: Jeremy Berchtold --- tests/jax/test_custom_call_compute.py | 52 +++++++ transformer_engine/jax/cpp_extensions/gemm.py | 107 ++++++++++++-- .../jax/csrc/extensions/gemm.cpp | 132 ++++++++++++++++-- 3 files changed, 273 insertions(+), 18 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index db44621d80..1f1286a399 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -2074,6 +2074,58 @@ def test_grouped_gemm_mxfp8_v2_shapes(self, input_shape): ) assert prim_out.shape == ref_out.shape assert jnp.all(jnp.isfinite(prim_out)) + # Numerical check within FP8 tolerance + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES, + ids=[f"v2_grad_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES], + ) + def test_grouped_dense_grad_mxfp8_v2(self, input_shape): + """MXFP8 V2 grouped GEMM gradient test (fwd + dgrad + wgrad).""" + lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128) + n_groups = input_shape[0] + fwd_dtype = jnp.float8_e4m3fn + bwd_dtype = jnp.float8_e4m3fn + + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=True, + n_groups=n_groups, + ) + + contracting_dims = ((1,), (1,)) + + def _ref_sum(x, kernel, group_sizes): + lhs_splits = jnp.split(x, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(kernel, n_groups, axis=0) + out = jnp.concatenate( + [jnp.squeeze(li @ ri, axis=0) for li, ri in zip(lhs_splits, rhs_splits)], axis=0 + ) + return jnp.sum(out) / jnp.sqrt(x.size) + + def _prim_sum(x, kernel, group_sizes): + out = grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias=None, + quantizer_set=quantizer_set, + ) + return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) + + ref_val, (ref_dx, ref_dk) = value_and_grad(_ref_sum, (0, 1))(lhs, rhs, group_sizes) + prim_val, (prim_dx, prim_dk) = jit( + value_and_grad(_prim_sum, (0, 1)), static_argnums=() + )(lhs, rhs, group_sizes) + + assert_allclose(prim_val, ref_val, dtype=fwd_dtype) + assert_allclose(prim_dx, ref_dx, dtype=bwd_dtype) + assert_allclose(prim_dk, ref_dk, dtype=bwd_dtype) class TestDebugInspectFFI: diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 1f69535b9b..e4fd9c99a6 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -379,6 +379,25 @@ def swizzled_scale(scale_inv, flatten_axis, is_colwise): return swizzled.reshape(original_shape) +def _swizzle_grouped_scale(scale_inv, scale_2d_shape, is_colwise): + """Swizzle a 1D grouped scale_inv buffer using full-tensor swizzle. + + The grouped scale_inv is 1D (worst-case padded). The meaningful prefix has size + equal to prod(scale_2d_shape). We reshape that prefix to 2D, swizzle it, and + write it back, leaving any trailing padding untouched. + """ + useful_size = math.prod(scale_2d_shape) + if useful_size == scale_inv.shape[0]: + # No trailing padding — reshape, swizzle, flatten. + return swizzled_scale(scale_inv.reshape(scale_2d_shape), 1, is_colwise).reshape( + scale_inv.shape + ) + # Split meaningful prefix from trailing padding, swizzle prefix only. + prefix = scale_inv[:useful_size].reshape(scale_2d_shape) + swizzled = swizzled_scale(prefix, 1, is_colwise).reshape((useful_size,)) + return jnp.concatenate([swizzled, scale_inv[useful_size:]]) + + def get_lhs_axis_boundary(lhs_cdims, is_transposed): """Get the axis boundary for the LHS operand.""" return max(lhs_cdims) + 1 if is_transposed else min(lhs_cdims) @@ -1626,9 +1645,11 @@ def _compute_cublas_workspace_size( workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - # We also pad scale_inv swizzle buffers size for 256 bytes alignment. - workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + if not use_v2_ffi: + # V1 needs workspace for per-group swizzle output buffers. + # V2: scales are pre-swizzled in JAX, no extra workspace needed. + workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding return workspace_size @staticmethod @@ -2024,11 +2045,14 @@ def _can_use_v2_grouped_gemm( scaling_mode: ScalingMode, dtype: jnp.dtype, has_bias: bool, + lhs_shape=None, + rhs_shape=None, + lhs_axis_boundary=None, + rhs_axis_boundary=None, ) -> bool: """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" - # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy - # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay - # feature-compatible with the main branch. + # Use the cuda-graphable path for plain BF16 non-quantized inputs and MXFP8; fall back to + # the legacy nvte_multi_tensor_gemm path for all other cases (tensor-scaled FP8, etc.). # Bias can be supported in a kernel or in pure-JAX in the future. if not _v2_grouped_gemm_available: @@ -2039,7 +2063,42 @@ def _can_use_v2_grouped_gemm( if get_device_compute_capability(0) < 100: return False - return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + if has_bias: + return False + + if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16: + return True + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # V2 MXFP8 requires that the total first dimension of both operands (up to + # axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement. + # Individual group sizes must also be 128-aligned (dynamic constraint). + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_first_dim = math.prod(lhs_shape[:lhs_axis_boundary]) + if lhs_first_dim % 128 != 0: + return False + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_first_dim = math.prod(rhs_shape[:rhs_axis_boundary]) + if rhs_first_dim % 128 != 0: + return False + # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both + # operands is a multiple of 128. The V2 GEMM setup kernel computes per-group + # scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``. + # The quantize kernel, however, pads the colwise scale stride to + # ``ceil(last_dim / 128) * 128``, making per-group padded scale larger than + # ``K_blocks * last_dim`` when ``last_dim`` is not 128-aligned. This causes + # adjacent groups' scales to overlap in the flat buffer. Fall back to V1 (which + # swizzles per-group scales individually) when the condition is not met. + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:]) + if lhs_last_dim % 128 != 0: + return False + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_last_dim = math.prod(rhs_shape[rhs_axis_boundary:]) + if rhs_last_dim % 128 != 0: + return False + return True + + return False def grouped_gemm( @@ -2278,7 +2337,39 @@ def grouped_gemm( " and padded with zeros to not affect the result of the MoE block." ) - use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) + use_v2_ffi = _can_use_v2_grouped_gemm( + scaling_mode, lhs_data.dtype, has_bias, + lhs_shape=lhs_data.shape, rhs_shape=rhs_data.shape, + lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, + ) + if use_v2_ffi and scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # Pre-swizzle full scale tensors in JAX (CUDA-graph safe). + # Grouped scale_inv is 1D (flat, worst-case padded). When all group sizes are + # multiples of 128 (V2 requirement), the per-group scales are contiguous with no + # inter-group padding gaps. We reshape the meaningful prefix to 2D, swizzle, and + # write it back into the original 1D buffer (extra trailing zeros stay untouched). + lhs_is_colwise = lhs_is_trans + rhs_is_colwise = not rhs_is_trans + lhs_scale_shape = scaling_mode.get_scale_shape( + lhs_data.shape, is_colwise=lhs_is_colwise, is_padded=True, + flatten_axis=lhs_axis_boundary, + ) + rhs_scale_shape = scaling_mode.get_scale_shape( + rhs_data.shape, is_colwise=rhs_is_colwise, is_padded=True, + flatten_axis=rhs_axis_boundary, + ) + # get_scale_shape may return a multi-dim shape (e.g. (8, 4, 128) for a 3D + # input), but _swizzle_grouped_scale needs a flat 2D shape (rows, cols) where + # cols = n_block_y (last dim) and rows = prod(all other dims). This correctly + # flattens the group/K-block axes into a single row dimension so the swizzle + # pattern operates on the full (K-blocks-across-groups × N-blocks) matrix. + lhs_n_block_y = lhs_scale_shape[-1] + rhs_n_block_y = rhs_scale_shape[-1] + lhs_scale_2d = (math.prod(lhs_scale_shape) // lhs_n_block_y, lhs_n_block_y) + rhs_scale_2d = (math.prod(rhs_scale_shape) // rhs_n_block_y, rhs_n_block_y) + lhs_scale_inv = _swizzle_grouped_scale(lhs_scale_inv, lhs_scale_2d, lhs_is_colwise) + rhs_scale_inv = _swizzle_grouped_scale(rhs_scale_inv, rhs_scale_2d, rhs_is_colwise) + if use_v2_ffi: additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index b26b261996..786fac5dcd 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -481,6 +481,8 @@ class JAXX_GroupedTensorWrapper { m_grouped_tensor(other.m_grouped_tensor), m_data_tensor(other.m_data_tensor), m_scale_inv_tensor(other.m_scale_inv_tensor), + m_colwise_data_tensor(other.m_colwise_data_tensor), + m_colwise_scale_inv_tensor(other.m_colwise_scale_inv_tensor), m_sizes_tensor(other.m_sizes_tensor), m_offsets_tensor(other.m_offsets_tensor) { other.m_grouped_tensor = nullptr; @@ -489,6 +491,8 @@ class JAXX_GroupedTensorWrapper { ~JAXX_GroupedTensorWrapper(); void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_columnwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_with_gemm_swizzled_scales(bool val); void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name); // Set only group sizes (no offsets); the setup kernel will compute offsets from sizes. @@ -505,6 +509,8 @@ class JAXX_GroupedTensorWrapper { // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. NVTEBasicTensor m_data_tensor{}; NVTEBasicTensor m_scale_inv_tensor{}; + NVTEBasicTensor m_colwise_data_tensor{}; + NVTEBasicTensor m_colwise_scale_inv_tensor{}; NVTEBasicTensor m_sizes_tensor{}; NVTEBasicTensor m_offsets_tensor{}; @@ -556,6 +562,46 @@ void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const &data, } } +void JAXX_GroupedTensorWrapper::set_columnwise(Buffer_Type const &data, + std::optional const &scale_inv) { + NVTEDType data_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_colwise_data_tensor = + NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, m_data_shape}; + + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseData, + &m_colwise_data_tensor, sizeof(m_colwise_data_tensor)); + + if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM columnwise scale_inv but received ndim=", + scale_inv->dimensions().size()); + } + m_colwise_scale_inv_tensor = + NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), scale_inv_dtype, + logical_scale_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, + &m_colwise_scale_inv_tensor, + sizeof(m_colwise_scale_inv_tensor)); + } +} + +void JAXX_GroupedTensorWrapper::set_with_gemm_swizzled_scales(bool val) { + auto v = static_cast(val); + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedWithGEMMSwizzledScales, &v, + sizeof(v)); +} + void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name) { @@ -619,11 +665,11 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -// V2 variant: derives data shape from the XLA buffer directly, converts group_sizes +// V2 variant (NO_SCALING): derives data shape from the XLA buffer directly, converts group_sizes // int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims. // int64_offset (in int64 elements) is updated on return to the next available slot so callers can // thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked -// before each slot is used. Only NO_SCALING is supported. +// before each slot is used. Only NO_SCALING is supported by this overload. JAXX_GroupedTensorWrapper make_grouped_tensor( Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, @@ -660,6 +706,56 @@ JAXX_GroupedTensorWrapper make_grouped_tensor( return wrapper; } +// V2 variant with scaling support (MXFP8 or NO_SCALING). Accepts scale_inv buffer and +// use_colwise flag to wire rowwise or columnwise data+scales for the grouped tensor. +// Pre-swizzled scales are indicated via set_with_gemm_swizzled_scales(true). +JAXX_GroupedTensorWrapper make_grouped_tensor( + Buffer_Type const &data, Buffer_Type const &scale_inv, JAXX_Scaling_Mode scaling_mode, + bool use_colwise, Buffer_Type const &first_dims, Buffer_Type const &last_dims, + int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { + auto dims = data.dimensions(); + NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); + size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); + NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; + JAXX_GroupedTensorWrapper wrapper(scaling_mode, num_gemms, dataShape); + + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + if (is_mxfp8 && use_colwise) { + wrapper.set_columnwise(data, scale_inv); + } else if (is_mxfp8) { + wrapper.set_rowwise(data, scale_inv); + } else { + // NO_SCALING: no scale_inv needed + wrapper.set_rowwise(data, std::nullopt); + } + if (is_mxfp8) { + wrapper.set_with_gemm_swizzled_scales(true); + } + + if (first_dims.element_count() > 0) { + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for first_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(first_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); + int64_offset += num_gemms; + } + if (last_dims.element_count() > 0) { + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for last_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(last_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); + int64_offset += num_gemms; + } + return wrapper; +} + // Returns num_gemms from the first non-empty per-tensor group_sizes buffer, // falling back to the element count of alpha for the uniform-batch case. size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type const &lhs_last_dims, @@ -700,8 +796,11 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty Result_Type int64_workspace, GroupedGemmV2Config config) { auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary] = config; - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Only non-quantized grouped GEMM is supported in current implementation."); + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING, + "Only NO_SCALING and MXFP8_1D_SCALING are supported in the V2 grouped GEMM."); + + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims, alpha); @@ -731,12 +830,25 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); size_t int64_offset = 0; - auto rhs_tensor = - make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, rhs_axis_boundary); - auto lhs_tensor = - make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, lhs_axis_boundary); + + // For MXFP8: in JAX, rhs=cuBLAS_A, lhs=cuBLAS_B (swapped). + // Colwise is needed when the operand's contracting dim is NOT the last dim in its layout. + const bool rhs_use_colwise = is_mxfp8 && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8 && lhs_is_trans; + + auto rhs_tensor = is_mxfp8 + ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, + rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary) + : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary); + auto lhs_tensor = is_mxfp8 + ? make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, + lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary) + : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary); + // Output stays NO_SCALING auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); From c97b0b70548378da17f9657293d1871ea0d293fe Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 10:21:23 -0700 Subject: [PATCH 18/23] te_permutation NaN issue fix Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/permutation.py | 10 ++++++---- 1 file changed, 6 insertions(+), 4 deletions(-) diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 6a0a3229d9..2732c4acc5 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -497,8 +497,9 @@ def _token_combine_bwd_rule( hidden_size, ) # The backward kernel only writes to positions that tokens map to. - # Padded positions may contain uninitialized (NaN) values - replace with zeros. - inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + # Padded positions may contain uninitialized values (NaN, inf, or garbage). + # Replace any non-finite values with zeros. + inp_grad = jnp.where(jnp.isfinite(inp_grad), inp_grad, 0.0) else: inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs( output_grad, @@ -527,8 +528,9 @@ def _token_combine_bwd_rule( align_size=128, # Default, sizes already computed in forward ) # The permute kernel only writes to positions that tokens map to. - # Padded positions may contain uninitialized (NaN) values - replace with zeros. - inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + # Padded positions may contain uninitialized values (NaN, inf, or garbage). + # Replace any non-finite values with zeros. + inp_grad = jnp.where(jnp.isfinite(inp_grad), inp_grad, 0.0) else: inp_grad, _ = permute_with_mask_map( output_grad, From 0b9a7637d7cf5f4bd73c78286ba330a1037c470b Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 10:22:22 -0700 Subject: [PATCH 19/23] Support GroupedDense quantization checkpointing Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/flax/module.py | 8 +++++--- 1 file changed, 5 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 4f19f449bc..dbd6fb1fef 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1361,7 +1361,7 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): return out, ln_output # Output, layer_norm_output -def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None): +def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None, quantization_checkpoint_name: Optional[str] = None): """Wraps the given function `f` to support TransformerEngine quantization. This method does a couple things: @@ -1389,6 +1389,7 @@ def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, + quantization_checkpoint_name=quantization_checkpoint_name, fp8_recipe=quantization_recipe, n_groups=n_groups, ) @@ -1446,7 +1447,7 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") -def make_grouped_dense_cls(quantization_recipe): +def make_grouped_dense_cls(quantization_recipe, quantization_checkpoint_name: Optional[str] = None): """Creates a grouped dense (grouped GEMM) instance for use with TE state module.""" if quantization_recipe is not None: allowed_grouped_gemm_recipes = [MXFP8BlockScaling] @@ -1470,5 +1471,6 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa return out return wrap_function_in_te_state_module( - te_grouped_dot_general, quantization_recipe, "ragged_dot" + te_grouped_dot_general, quantization_recipe, "ragged_dot", + quantization_checkpoint_name=quantization_checkpoint_name, )() From 6b64cea01c9a31993d83c9a7bf9a0cd97f0bb024 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 10:23:01 -0700 Subject: [PATCH 20/23] Temporary commit to assert if V1 grouped quantize is used Signed-off-by: Jeremy Berchtold --- transformer_engine/jax/cpp_extensions/quantization.py | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 02639cfa2d..936b48cae0 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -1025,17 +1025,26 @@ def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): but performs a D2H copy of group_sizes (not CUDA-graph safe). """ if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: + assert False, "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got scaling_mode {}".format( + scaling_mode + ) return False ndim = len(x_shape) eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim total_first_dim = math.prod(x_shape[:eff]) if total_first_dim % 128 != 0: + assert False, "V2 grouped quantize kernel requires total first logical dimension (product of x_shape up to flatten_axis) to be divisible by 128, but got shape {} and flatten_axis {} with total_first_dim {}".format( + x_shape, flatten_axis, total_first_dim + ) return False # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), # non_group_m = K must also be 128-aligned. if eff > 1: non_group_m = math.prod(x_shape[1:eff]) if non_group_m % 128 != 0: + assert False, "V2 grouped quantize kernel requires non-group dimension (product of x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors, but got shape {} and flatten_axis {} with non_group_m {}".format( + x_shape, flatten_axis, non_group_m + ) return False return True From 2dd69d4fa4a4ae5c88d9e79404951b0e528f52e4 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 12:08:05 -0700 Subject: [PATCH 21/23] Fix scale shapes for MXFP8 Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/quantization.py | 46 ++++++++++++++++++- 1 file changed, 44 insertions(+), 2 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 936b48cae0..054bf79431 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -51,6 +51,36 @@ __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] +def _build_scale_spec(x_spec, scale_shape, mesh): + """Build a PartitionSpec for the MXFP8 scale tensor compatible with its shape. + + The scale tensor has smaller dimensions than the data tensor (each dimension + divided by the MXFP8 block size). This function ensures that we only shard a + scale dimension by a mesh axis if scale_shape[i] is divisible by that axis's + size. If not, a ValueError is raised with a helpful diagnostic message. + """ + result = [] + for axis, scale_dim in zip(x_spec, scale_shape): + if axis is None: + result.append(None) + elif isinstance(axis, str): + axis_size = mesh.shape.get(axis, 1) + if scale_dim % axis_size == 0: + result.append(axis) + else: + raise ValueError( + f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " + f"by mesh axis '{axis}' of size {axis_size}: " + f"scale dim {scale_dim} is not divisible by {axis_size}. " + f"The data tensor's sharding is incompatible with the MXFP8 block " + f"size along this axis. Try reducing expert parallelism (EP) so that " + f"EP divides the scale dimension, or increase the tensor size." + ) + else: + result.append(None) # tuple axes: conservatively leave unsharded + return tuple(result) + + class BaseDBiasQuantizePrimitive(BasePrimitive): """ Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias @@ -446,7 +476,13 @@ def infer_sharding_from_operands( scale_inv_spec = colwise_scale_inv_spec = (None,) if ScalingMode(scaling_mode).is_block_scaling: - scale_inv_spec = x_spec + rowwise_scale_shape, _ = ScalingMode(scaling_mode).get_scale_shape_2x( + arg_infos[0].shape, + is_padded=False, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + scale_inv_spec = _build_scale_spec(x_spec, rowwise_scale_shape, mesh) if q_layout.has_colwise: if ( @@ -528,7 +564,13 @@ def partition( scale_inv_spec = colwise_scale_inv_spec = (None,) if ScalingMode(scaling_mode).is_block_scaling: - scale_inv_spec = x_spec + rowwise_scale_shape, _ = ScalingMode(scaling_mode).get_scale_shape_2x( + arg_infos[0].shape, + is_padded=False, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + scale_inv_spec = _build_scale_spec(x_spec, rowwise_scale_shape, mesh) if q_layout.has_colwise: if ( From 204b3260da8fdd44304992d993a9726d1a8dece6 Mon Sep 17 00:00:00 2001 From: Jeremy Berchtold Date: Sat, 14 Mar 2026 12:35:35 -0700 Subject: [PATCH 22/23] Fix MXFP8 scale sharding when FSDP+EP on same axis Signed-off-by: Jeremy Berchtold --- .../jax/cpp_extensions/quantization.py | 23 ++++++++++++++++--- 1 file changed, 20 insertions(+), 3 deletions(-) diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 054bf79431..29a0a047cf 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -56,8 +56,9 @@ def _build_scale_spec(x_spec, scale_shape, mesh): The scale tensor has smaller dimensions than the data tensor (each dimension divided by the MXFP8 block size). This function ensures that we only shard a - scale dimension by a mesh axis if scale_shape[i] is divisible by that axis's - size. If not, a ValueError is raised with a helpful diagnostic message. + scale dimension by a mesh axis (or tuple of axes) if scale_shape[i] is + divisible by the total axis size. If not, a ValueError is raised with a + helpful diagnostic message. """ result = [] for axis, scale_dim in zip(x_spec, scale_shape): @@ -76,8 +77,24 @@ def _build_scale_spec(x_spec, scale_shape, mesh): f"size along this axis. Try reducing expert parallelism (EP) so that " f"EP divides the scale dimension, or increase the tensor size." ) + elif isinstance(axis, (tuple, list)): + # Multi-axis sharding (e.g. ('fsdp', 'expert')): check total combined size. + total_size = 1 + for a in axis: + total_size *= mesh.shape.get(a, 1) + if scale_dim % total_size == 0: + result.append(axis) + else: + raise ValueError( + f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " + f"by mesh axes {tuple(axis)} of combined size {total_size}: " + f"scale dim {scale_dim} is not divisible by {total_size}. " + f"The data tensor's sharding is incompatible with the MXFP8 block " + f"size along this axis. Try reducing parallelism or increasing the " + f"tensor size." + ) else: - result.append(None) # tuple axes: conservatively leave unsharded + result.append(None) return tuple(result) From 5fb585f907487034fae73629c9913c1b19bbd534 Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Sat, 14 Mar 2026 19:36:48 +0000 Subject: [PATCH 23/23] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/jax/test_custom_call_compute.py | 8 ++-- .../common/gemm/cublaslt_grouped_gemm.cu | 11 +++--- .../common/include/transformer_engine/gemm.h | 5 +-- transformer_engine/jax/cpp_extensions/gemm.py | 18 ++++++--- .../jax/cpp_extensions/quantization.py | 37 +++++++++++-------- .../jax/csrc/extensions/gemm.cpp | 29 ++++++++------- .../jax/csrc/extensions/quantization.cpp | 23 +++++------- transformer_engine/jax/flax/module.py | 16 ++++++-- transformer_engine/jax/quantize/tensor.py | 10 +++-- 9 files changed, 90 insertions(+), 67 deletions(-) diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 1f1286a399..0a76794003 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -1964,7 +1964,7 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): # (n_groups, m, n, k) # lhs total_rows = m * 128; rhs total_rows = n_groups * k (8, 8, 128, 128), # lhs: 8*128=1024 (128-aligned); rhs: 8*128=1024 (128-aligned) - (4, 4, 64, 256), # lhs: 4*128=512 (128-aligned); rhs: 4*256=1024 (128-aligned) + (4, 4, 64, 256), # lhs: 4*128=512 (128-aligned); rhs: 4*256=1024 (128-aligned) ] @@ -2119,9 +2119,9 @@ def _prim_sum(x, kernel, group_sizes): return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) ref_val, (ref_dx, ref_dk) = value_and_grad(_ref_sum, (0, 1))(lhs, rhs, group_sizes) - prim_val, (prim_dx, prim_dk) = jit( - value_and_grad(_prim_sum, (0, 1)), static_argnums=() - )(lhs, rhs, group_sizes) + prim_val, (prim_dx, prim_dk) = jit(value_and_grad(_prim_sum, (0, 1)), static_argnums=())( + lhs, rhs, group_sizes + ) assert_allclose(prim_val, ref_val, dtype=fwd_dtype) assert_allclose(prim_dx, ref_dx, dtype=bwd_dtype) diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu index c57a662073..5dd1fe9c06 100644 --- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu +++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu @@ -823,7 +823,7 @@ __global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst, // Like convert_int32_to_int64_kernel but scales each element by multiplier. // Used to convert per-expert slice counts to per-expert row counts for multi-dim tensors. __global__ void convert_int32_to_int64_with_multiplier_kernel(const int32_t *src, int64_t *dst, - size_t n, int64_t multiplier) { + size_t n, int64_t multiplier) { size_t idx = blockIdx.x * blockDim.x + threadIdx.x; if (idx < n) dst[idx] = static_cast(src[idx]) * multiplier; } @@ -850,22 +850,21 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud } void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, - int64_t multiplier, cudaStream_t stream) { + int64_t multiplier, cudaStream_t stream) { NVTE_API_CALL(nvte_convert_int32_to_int64_with_multiplier); if (n == 0) return; const int threads = 256; const int blocks = static_cast((n + threads - 1) / threads); convert_int32_to_int64_with_multiplier_kernel<<>>(src, dst, n, - multiplier); + multiplier); NVTE_CHECK_CUDA(cudaGetLastError()); } void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, - size_t n_groups, int64_t last_dim, - cudaStream_t stream) { + size_t n_groups, int64_t last_dim, cudaStream_t stream) { NVTE_API_CALL(nvte_compute_grouped_tensor_offsets); // Always write at least offsets[0]=0 (needed even for n_groups==0). compute_grouped_tensor_offsets_kernel<<<1, 1, 0, stream>>>(first_dims, offsets, n_groups, - last_dim); + last_dim); NVTE_CHECK_CUDA(cudaGetLastError()); } diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index dc1a104abb..5ee15613b0 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -368,7 +368,7 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud * \param[in] stream CUDA stream. */ void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, - int64_t multiplier, cudaStream_t stream); + int64_t multiplier, cudaStream_t stream); /*! \brief Compute exclusive prefix-sum offsets from per-group first-dimension sizes. * @@ -383,8 +383,7 @@ void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *ds * \param[in] stream CUDA stream. */ void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, - size_t n_groups, int64_t last_dim, - cudaStream_t stream); + size_t n_groups, int64_t last_dim, cudaStream_t stream); void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index e4fd9c99a6..1029600389 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -2338,9 +2338,13 @@ def grouped_gemm( ) use_v2_ffi = _can_use_v2_grouped_gemm( - scaling_mode, lhs_data.dtype, has_bias, - lhs_shape=lhs_data.shape, rhs_shape=rhs_data.shape, - lhs_axis_boundary=lhs_axis_boundary, rhs_axis_boundary=rhs_axis_boundary, + scaling_mode, + lhs_data.dtype, + has_bias, + lhs_shape=lhs_data.shape, + rhs_shape=rhs_data.shape, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, ) if use_v2_ffi and scaling_mode == ScalingMode.MXFP8_1D_SCALING: # Pre-swizzle full scale tensors in JAX (CUDA-graph safe). @@ -2351,11 +2355,15 @@ def grouped_gemm( lhs_is_colwise = lhs_is_trans rhs_is_colwise = not rhs_is_trans lhs_scale_shape = scaling_mode.get_scale_shape( - lhs_data.shape, is_colwise=lhs_is_colwise, is_padded=True, + lhs_data.shape, + is_colwise=lhs_is_colwise, + is_padded=True, flatten_axis=lhs_axis_boundary, ) rhs_scale_shape = scaling_mode.get_scale_shape( - rhs_data.shape, is_colwise=rhs_is_colwise, is_padded=True, + rhs_data.shape, + is_colwise=rhs_is_colwise, + is_padded=True, flatten_axis=rhs_axis_boundary, ) # get_scale_shape may return a multi-dim shape (e.g. (8, 4, 128) for a 3D diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index 29a0a047cf..a2da6b8830 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -73,9 +73,9 @@ def _build_scale_spec(x_spec, scale_shape, mesh): f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " f"by mesh axis '{axis}' of size {axis_size}: " f"scale dim {scale_dim} is not divisible by {axis_size}. " - f"The data tensor's sharding is incompatible with the MXFP8 block " - f"size along this axis. Try reducing expert parallelism (EP) so that " - f"EP divides the scale dimension, or increase the tensor size." + "The data tensor's sharding is incompatible with the MXFP8 block " + "size along this axis. Try reducing expert parallelism (EP) so that " + "EP divides the scale dimension, or increase the tensor size." ) elif isinstance(axis, (tuple, list)): # Multi-axis sharding (e.g. ('fsdp', 'expert')): check total combined size. @@ -89,9 +89,9 @@ def _build_scale_spec(x_spec, scale_shape, mesh): f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " f"by mesh axes {tuple(axis)} of combined size {total_size}: " f"scale dim {scale_dim} is not divisible by {total_size}. " - f"The data tensor's sharding is incompatible with the MXFP8 block " - f"size along this axis. Try reducing parallelism or increasing the " - f"tensor size." + "The data tensor's sharding is incompatible with the MXFP8 block " + "size along this axis. Try reducing parallelism or increasing the " + "tensor size." ) else: result.append(None) @@ -1084,16 +1084,21 @@ def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): but performs a D2H copy of group_sizes (not CUDA-graph safe). """ if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: - assert False, "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got scaling_mode {}".format( - scaling_mode + assert False, ( + "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got" + " scaling_mode {}".format(scaling_mode) ) return False ndim = len(x_shape) eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim total_first_dim = math.prod(x_shape[:eff]) if total_first_dim % 128 != 0: - assert False, "V2 grouped quantize kernel requires total first logical dimension (product of x_shape up to flatten_axis) to be divisible by 128, but got shape {} and flatten_axis {} with total_first_dim {}".format( - x_shape, flatten_axis, total_first_dim + assert False, ( + "V2 grouped quantize kernel requires total first logical dimension (product of" + " x_shape up to flatten_axis) to be divisible by 128, but got shape {} and" + " flatten_axis {} with total_first_dim {}".format( + x_shape, flatten_axis, total_first_dim + ) ) return False # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), @@ -1101,8 +1106,12 @@ def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): if eff > 1: non_group_m = math.prod(x_shape[1:eff]) if non_group_m % 128 != 0: - assert False, "V2 grouped quantize kernel requires non-group dimension (product of x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors, but got shape {} and flatten_axis {} with non_group_m {}".format( - x_shape, flatten_axis, non_group_m + assert False, ( + "V2 grouped quantize kernel requires non-group dimension (product of" + " x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors," + " but got shape {} and flatten_axis {} with non_group_m {}".format( + x_shape, flatten_axis, non_group_m + ) ) return False return True @@ -1157,9 +1166,7 @@ def abstract( # [n_groups int64 group_sizes | n_groups+1 int64 offsets] # = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8. n_groups = group_sizes_aval.size - fifth_out_aval = jax.core.ShapedArray( - shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8 - ) + fifth_out_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8) else: # V1 path: 5th output is amax fifth_out_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp b/transformer_engine/jax/csrc/extensions/gemm.cpp index 786fac5dcd..45625120fd 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -591,8 +591,7 @@ void JAXX_GroupedTensorWrapper::set_columnwise(Buffer_Type const &data, NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), scale_inv_dtype, logical_scale_shape}; nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, - &m_colwise_scale_inv_tensor, - sizeof(m_colwise_scale_inv_tensor)); + &m_colwise_scale_inv_tensor, sizeof(m_colwise_scale_inv_tensor)); } } @@ -836,18 +835,20 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty const bool rhs_use_colwise = is_mxfp8 && !rhs_is_trans; const bool lhs_use_colwise = is_mxfp8 && lhs_is_trans; - auto rhs_tensor = is_mxfp8 - ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, - rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, rhs_axis_boundary) - : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, rhs_axis_boundary); - auto lhs_tensor = is_mxfp8 - ? make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, - lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, lhs_axis_boundary) - : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, - int64_offset, num_gemms, stream, lhs_axis_boundary); + auto rhs_tensor = + is_mxfp8 + ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, rhs_first_dims, + rhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, rhs_axis_boundary) + : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary); + auto lhs_tensor = + is_mxfp8 + ? make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, lhs_first_dims, + lhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, lhs_axis_boundary) + : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary); // Output stays NO_SCALING auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, stream); diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index fb083a958f..06f5906edf 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -495,11 +495,10 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Attr("q_layout") .Attr("flatten_axis")); -Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, - Buffer_Type scale_unused, Buffer_Type group_sizes, - Result_Type rowwise_out, Result_Type colwise_out, - Result_Type rowwise_sinv, Result_Type colwise_sinv, - Result_Type int64_workspace, +Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scale_unused, + Buffer_Type group_sizes, Result_Type rowwise_out, + Result_Type colwise_out, Result_Type rowwise_sinv, + Result_Type colwise_sinv, Result_Type int64_workspace, JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { (void)scale_unused; // scale is unused for MXFP8; accepted to match V1 input arity auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); @@ -539,8 +538,8 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, non_group_m, stream); // Compute exclusive prefix-sum offsets on device (CUDA-graph safe, no D2H). - nvte_compute_grouped_tensor_offsets(int64_ptr, offsets_ptr_out, n_groups, - static_cast(n), stream); + nvte_compute_grouped_tensor_offsets(int64_ptr, offsets_ptr_out, n_groups, static_cast(n), + stream); NVTEShape data_shape{}; data_shape.data[0] = m; @@ -557,9 +556,8 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, offsets_shape.data[0] = n_groups + 1; // Build input grouped tensor (plain float data, no quantization on the input side). - NVTEGroupedTensor in_grouped = - nvte_create_grouped_tensor(get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING), - n_groups, data_shape); + NVTEGroupedTensor in_grouped = nvte_create_grouped_tensor( + get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING), n_groups, data_shape); { NVTEBasicTensor in_data{reinterpret_cast(inputs.untyped_data()), static_cast(in_dtype), data_shape}; @@ -574,9 +572,8 @@ Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, } // Build output grouped tensor. - NVTEGroupedTensor out_grouped = - nvte_create_grouped_tensor(get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING), - n_groups, data_shape); + NVTEGroupedTensor out_grouped = nvte_create_grouped_tensor( + get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING), n_groups, data_shape); // Set group sizes and offsets on output tensor (same device pointers). { diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index dbd6fb1fef..17c9a242f0 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -1361,7 +1361,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): return out, ln_output # Output, layer_norm_output -def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None, quantization_checkpoint_name: Optional[str] = None): +def wrap_function_in_te_state_module( + f, + quantization_recipe, + name: Optional[str] = None, + quantization_checkpoint_name: Optional[str] = None, +): """Wraps the given function `f` to support TransformerEngine quantization. This method does a couple things: @@ -1452,8 +1457,9 @@ def make_grouped_dense_cls(quantization_recipe, quantization_checkpoint_name: Op if quantization_recipe is not None: allowed_grouped_gemm_recipes = [MXFP8BlockScaling] assert any(isinstance(quantization_recipe, r) for r in allowed_grouped_gemm_recipes), ( - f"Only the following quantization recipes are supported for grouped GEMM or `None` for BF16 without quantization: {allowed_grouped_gemm_recipes}. " - f"Got {type(quantization_recipe)}." + "Only the following quantization recipes are supported for grouped GEMM or `None` for" + f" BF16 without quantization: {allowed_grouped_gemm_recipes}. Got" + f" {type(quantization_recipe)}." ) def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): @@ -1471,6 +1477,8 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa return out return wrap_function_in_te_state_module( - te_grouped_dot_general, quantization_recipe, "ragged_dot", + te_grouped_dot_general, + quantization_recipe, + "ragged_dot", quantization_checkpoint_name=quantization_checkpoint_name, )() diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index 4511b63a64..4b604502b0 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -704,8 +704,10 @@ def create_1x( dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if first_dims is not None or last_dims is not None or ( - original_shape is not None and group_axis is not None + if ( + first_dims is not None + or last_dims is not None + or (original_shape is not None and group_axis is not None) ): assert ( original_shape is not None @@ -713,7 +715,9 @@ def create_1x( flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) # Determine num_groups from whichever dims array is provided, or from original_shape - active_dims = first_dims if first_dims is not None and first_dims.size > 0 else last_dims + active_dims = ( + first_dims if first_dims is not None and first_dims.size > 0 else last_dims + ) if active_dims is not None: num_groups = active_dims.size else: