diff --git a/tests/jax/test_custom_call_compute.py b/tests/jax/test_custom_call_compute.py index 613aefc178..0a76794003 100644 --- a/tests/jax/test_custom_call_compute.py +++ b/tests/jax/test_custom_call_compute.py @@ -36,6 +36,7 @@ ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, + GroupedNoScaleTensor, ScalingMode, QuantizerFactory, QuantizeLayout, @@ -1736,7 +1737,9 @@ def _ref_grouped_dense(self, lhs, rhs, bias, group_sizes, contracting_dims): ref_out.append(jnp.squeeze(out_i)) return ref_out - def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", with_bias=False): + def _generate_grouped_dense_input( + self, dtype, input_shape, data_layout="NN", with_bias=False, group_size_multiplier=32 + ): key = jax.random.PRNGKey(0) subkeys = jax.random.split(key, 4) n_groups, m, n, k = input_shape @@ -1749,9 +1752,12 @@ def _generate_grouped_dense_input(self, dtype, input_shape, data_layout="NN", wi group_sizes = group_sizes.at[1].set(0) assert group_sizes.sum() == m - # *32 to make sure that input shape works for MXFP8 - group_sizes = group_sizes * 32 - m = m * 32 + # Scale group sizes by the multiplier. + # Use group_size_multiplier=128 for MXFP8 V2 tests so that each group's row count + # is divisible by 128, satisfying the V2 kernel's per-group alignment requirement. + # Use group_size_multiplier=32 for V1 tests or non-MXFP8 tests. + group_sizes = group_sizes * group_size_multiplier + m = m * group_size_multiplier lhs_shape = (m if data_layout[0] == "N" else k, k if data_layout[0] == "N" else m) rhs_shape = (n_groups, k if data_layout[1] == "N" else n, n if data_layout[1] == "N" else k) @@ -1787,13 +1793,18 @@ def test_grouped_gemm_fp16(self, dtype, input_shape, layout): ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) # jitting grouped_gemm + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) prim_out = jax.jit( tex.grouped_gemm, static_argnames=("contracting_dims", "use_async_d2h_group_sizes") )( - lhs, - rhs, - group_sizes, - contracting_dims, + lhs_tensor, + rhs_tensor, + contracting_dims=contracting_dims, use_async_d2h_group_sizes=True, ) @@ -1820,13 +1831,24 @@ def test_grouped_gemm_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape, layout quantizer.q_dtype = bwd_dtype out_dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. 
+ is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING lhs, rhs, group_sizes, contracting_dims, _ = self._generate_grouped_dense_input( - out_dtype, input_shape, layout + out_dtype, input_shape, layout, group_size_multiplier=128 if is_mxfp8 else 32 ) ref_out = self._ref_grouped_dense(lhs, rhs, None, group_sizes, contracting_dims) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( - lhs, rhs, group_sizes, contracting_dims, quantizer_set=quantizer_set + lhs_tensor, + rhs_tensor, + contracting_dims=contracting_dims, + quantizer_set=quantizer_set, ) allclose_dtype = jnp.float8_e4m3fn @@ -1886,10 +1908,13 @@ def test_grouped_dense_grad_fp16(self, dtype, input_shape): def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): fwd_dtype, bwd_dtype = fwd_bwd_dtype dtype = jnp.bfloat16 + # MXFP8 V2 kernel requires each group's row count to be divisible by 128. + is_mxfp8 = scaling_mode == ScalingMode.MXFP8_1D_SCALING x, kernel, group_sizes, contracting_dims, bias = self._generate_grouped_dense_input( dtype, input_shape, with_bias=True, + group_size_multiplier=128 if is_mxfp8 else 32, ) quantizer_set = QuantizerFactory.create_set( @@ -1923,6 +1948,186 @@ def test_grouped_dense_grad_fp8(self, fwd_bwd_dtype, scaling_mode, input_shape): assert_allclose(prim_dbias, ref_dbias, dtype=dtype) +# MXFP8 V1 shapes: lhs total_rows = m * 32 and rhs total_rows = n_groups * k are +# NOT divisible by 128, forcing the V1 (non-CUDA-graph-safe) kernel. +GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES = [ + # (n_groups, m, n, k) + # lhs total_rows = m * 32; rhs total_rows = n_groups * k + (5, 6, 128, 64), # lhs: 6*32=192 (not 128-aligned); rhs: 5*64=320 (not 128-aligned) +] + +# MXFP8 V2 shapes: lhs total_rows = m * 128 and rhs total_rows = n_groups * k are +# divisible by 128, allowing the V2 (CUDA-graph-safe) kernel to be used. +# These shapes must be paired with group_size_multiplier=128 so that each group's +# row count is also divisible by 128 (the V2 per-group alignment requirement). +GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES = [ + # (n_groups, m, n, k) + # lhs total_rows = m * 128; rhs total_rows = n_groups * k + (8, 8, 128, 128), # lhs: 8*128=1024 (128-aligned); rhs: 8*128=1024 (128-aligned) + (4, 4, 64, 256), # lhs: 4*128=512 (128-aligned); rhs: 4*256=1024 (128-aligned) +] + + +@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason) +class TestGroupedDenseMXFP8KernelSelection: + """Tests that explicitly verify V1 and V2 MXFP8 grouped quantize kernel selection. + + V2 is the CUDA-graph-safe kernel and requires: + - total_first_dim (= product of input shape up to flatten_axis) % 128 == 0 + - each individual group_size % 128 == 0 (enforced by the kernel at runtime) + V1 is the fallback that supports arbitrary shapes but performs a D2H copy of + group_sizes (not CUDA-graph safe). 
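+
+    Illustrative example (using the shape lists defined above this class): input_shape
+    (8, 8, 128, 128) with group_size_multiplier=128 gives per-group row counts that are
+    multiples of 128 and a 128-aligned total (8 * 128 = 1024 rows), so the V2 kernel is
+    eligible, whereas input_shape (5, 6, 128, 64) with multiplier 32 gives 6 * 32 = 192
+    total rows, which is not 128-aligned, so quantization falls back to V1.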
+ """ + + def _generate_mxfp8_input(self, input_shape, group_size_multiplier): + """Generate inputs with the given group_size_multiplier for MXFP8 tests.""" + key = jax.random.PRNGKey(42) + subkeys = jax.random.split(key, 3) + n_groups, m, n, k = input_shape + + group_sizes = jnp.sort(jax.random.randint(subkeys[0], (n_groups - 1,), 0, m)) + group_sizes = jnp.concatenate([jnp.array([0]), group_sizes, jnp.array([m])]) + group_sizes = jnp.diff(group_sizes) + group_sizes = group_sizes.at[0].set(group_sizes[0] + group_sizes[1]) + group_sizes = group_sizes.at[1].set(0) + group_sizes = group_sizes * group_size_multiplier + m_total = m * group_size_multiplier + + lhs = jax.random.uniform(subkeys[1], (m_total, k), dtype=jnp.bfloat16) + rhs = jax.random.uniform(subkeys[2], (n_groups, k, n), dtype=jnp.bfloat16) + return lhs, rhs, group_sizes + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES, + ids=[f"v1_{s}" for s in GROUPED_DENSE_MXFP8_V1_INPUT_SHAPES], + ) + def test_grouped_gemm_mxfp8_v1_shapes(self, input_shape): + """MXFP8 grouped GEMM with V1-only shapes (total_first_dim not 128-aligned).""" + lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=32) + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=False, + n_groups=input_shape[0], + ) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) + # Reference: unquantized grouped GEMM + n_groups = input_shape[0] + lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(rhs, n_groups, axis=0) + ref_out = jnp.concatenate( + [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], + axis=0, + ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs_tensor, + rhs_tensor, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + # Check output has correct shape and dtype; numerical precision is expected to be lower + # due to FP8 quantization but the result should be finite. 
+ assert prim_out.shape == ref_out.shape + assert jnp.all(jnp.isfinite(prim_out)) + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES, + ids=[f"v2_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES], + ) + def test_grouped_gemm_mxfp8_v2_shapes(self, input_shape): + """MXFP8 grouped GEMM with V2-eligible shapes (total_first_dim 128-aligned, + group_sizes also 128-aligned).""" + lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128) + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=jnp.float8_e4m3fn, + bwd_dtype=jnp.float8_e4m3fn, + is_2x2x=False, + n_groups=input_shape[0], + ) + lhs_tensor = GroupedNoScaleTensor( + data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0, original_shape=lhs.shape + ) + rhs_tensor = GroupedNoScaleTensor( + data=rhs, first_dims=None, last_dims=None, group_axis=0, original_shape=rhs.shape + ) + n_groups = input_shape[0] + lhs_splits = jnp.split(lhs, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(rhs, n_groups, axis=0) + ref_out = jnp.concatenate( + [jnp.squeeze(lhs_i @ rhs_i, axis=0) for lhs_i, rhs_i in zip(lhs_splits, rhs_splits)], + axis=0, + ) + prim_out = jax.jit(tex.grouped_gemm, static_argnames=("contracting_dims",))( + lhs_tensor, + rhs_tensor, + contracting_dims=((1,), (1,)), + quantizer_set=quantizer_set, + ) + assert prim_out.shape == ref_out.shape + assert jnp.all(jnp.isfinite(prim_out)) + # Numerical check within FP8 tolerance + assert_allclose(prim_out, ref_out, dtype=jnp.float8_e4m3fn) + + @pytest.mark.parametrize( + "input_shape", + GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES, + ids=[f"v2_grad_{s}" for s in GROUPED_DENSE_MXFP8_V2_INPUT_SHAPES], + ) + def test_grouped_dense_grad_mxfp8_v2(self, input_shape): + """MXFP8 V2 grouped GEMM gradient test (fwd + dgrad + wgrad).""" + lhs, rhs, group_sizes = self._generate_mxfp8_input(input_shape, group_size_multiplier=128) + n_groups = input_shape[0] + fwd_dtype = jnp.float8_e4m3fn + bwd_dtype = jnp.float8_e4m3fn + + quantizer_set = QuantizerFactory.create_set( + scaling_mode=ScalingMode.MXFP8_1D_SCALING, + fwd_dtype=fwd_dtype, + bwd_dtype=bwd_dtype, + is_2x2x=True, + n_groups=n_groups, + ) + + contracting_dims = ((1,), (1,)) + + def _ref_sum(x, kernel, group_sizes): + lhs_splits = jnp.split(x, jnp.cumulative_sum(group_sizes)[:-1], axis=0) + rhs_splits = jnp.split(kernel, n_groups, axis=0) + out = jnp.concatenate( + [jnp.squeeze(li @ ri, axis=0) for li, ri in zip(lhs_splits, rhs_splits)], axis=0 + ) + return jnp.sum(out) / jnp.sqrt(x.size) + + def _prim_sum(x, kernel, group_sizes): + out = grouped_dense( + x, + kernel, + group_sizes, + contracting_dims, + bias=None, + quantizer_set=quantizer_set, + ) + return jnp.sum(jnp.asarray(out)) / jnp.sqrt(x.size) + + ref_val, (ref_dx, ref_dk) = value_and_grad(_ref_sum, (0, 1))(lhs, rhs, group_sizes) + prim_val, (prim_dx, prim_dk) = jit(value_and_grad(_prim_sum, (0, 1)), static_argnums=())( + lhs, rhs, group_sizes + ) + + assert_allclose(prim_val, ref_val, dtype=fwd_dtype) + assert_allclose(prim_dx, ref_dx, dtype=bwd_dtype) + assert_allclose(prim_dk, ref_dk, dtype=bwd_dtype) + + class TestDebugInspectFFI: @pytest_parametrize_wrapper("shape", [(256, 128)]) diff --git a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh index 129d6724ac..d7eaf028e0 100644 --- a/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh +++ 
b/transformer_engine/common/cast/mxfp8/group_quantize_mxfp8.cuh
@@ -192,6 +192,10 @@ __global__ void update_tma_descriptors(
   const size_t offset_elts = offsets_ptr[tensor_id];
 
   if (leading_thread && (tensor_id < num_tensors)) {
+    // Zero-sized groups: skip TMA descriptor update. The main kernel already returns
+    // early for rows==0 or cols==0, but creating a TMA descriptor with a zero dimension
+    // is invalid and causes CUDA_ERROR_ILLEGAL_ADDRESS.
+    if (rows == 0 || cols == 0) return;
     {
       const uintptr_t global_data_ptr = reinterpret_cast<uintptr_t>(input_data_ptr + offset_elts);
       modify_base_tensor_map(base_tensor_map_input, &g_tensor_maps_input[tensor_id],
diff --git a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
index 7069debc56..5dd1fe9c06 100644
--- a/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
+++ b/transformer_engine/common/gemm/cublaslt_grouped_gemm.cu
@@ -820,6 +820,24 @@ __global__ void convert_int32_to_int64_kernel(const int32_t *src, int64_t *dst,
   if (idx < n) dst[idx] = static_cast<int64_t>(src[idx]);
 }
 
+// Like convert_int32_to_int64_kernel but scales each element by multiplier.
+// Used to convert per-expert slice counts to per-expert row counts for multi-dim tensors.
+__global__ void convert_int32_to_int64_with_multiplier_kernel(const int32_t *src, int64_t *dst,
+                                                              size_t n, int64_t multiplier) {
+  size_t idx = blockIdx.x * blockDim.x + threadIdx.x;
+  if (idx < n) dst[idx] = static_cast<int64_t>(src[idx]) * multiplier;
+}
+
+// Computes exclusive prefix sums: offsets[0]=0, offsets[i]=sum(first_dims[0..i-1]*last_dim).
+// Produces n_groups+1 values. Single-threaded sequential scan; n_groups is typically small.
+__global__ void compute_grouped_tensor_offsets_kernel(const int64_t *first_dims, int64_t *offsets,
+                                                      size_t n_groups, int64_t last_dim) {
+  offsets[0] = 0;
+  for (size_t i = 0; i < n_groups; i++) {
+    offsets[i + 1] = offsets[i] + first_dims[i] * last_dim;
+  }
+}
+
 }  // namespace
 
 void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream) {
@@ -830,3 +848,23 @@ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cud
   convert_int32_to_int64_kernel<<<blocks, threads, 0, stream>>>(src, dst, n);
   NVTE_CHECK_CUDA(cudaGetLastError());
 }
+
+void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n,
+                                                 int64_t multiplier, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_convert_int32_to_int64_with_multiplier);
+  if (n == 0) return;
+  const int threads = 256;
+  const int blocks = static_cast<int>((n + threads - 1) / threads);
+  convert_int32_to_int64_with_multiplier_kernel<<<blocks, threads, 0, stream>>>(src, dst, n,
+                                                                                multiplier);
+  NVTE_CHECK_CUDA(cudaGetLastError());
+}
+
+void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets,
+                                         size_t n_groups, int64_t last_dim, cudaStream_t stream) {
+  NVTE_API_CALL(nvte_compute_grouped_tensor_offsets);
+  // Always write at least offsets[0]=0 (needed even for n_groups==0).
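+  // Worked example (illustrative values): first_dims = {2, 0, 3} and last_dim = 4 yield
+  // offsets = {0, 8, 8, 20}, i.e. the element offset of each group's start in the flat buffer.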
+ compute_grouped_tensor_offsets_kernel<<<1, 1, 0, stream>>>(first_dims, offsets, n_groups, + last_dim); + NVTE_CHECK_CUDA(cudaGetLastError()); +} diff --git a/transformer_engine/common/include/transformer_engine/gemm.h b/transformer_engine/common/include/transformer_engine/gemm.h index 35d327f085..5ee15613b0 100644 --- a/transformer_engine/common/include/transformer_engine/gemm.h +++ b/transformer_engine/common/include/transformer_engine/gemm.h @@ -356,6 +356,35 @@ size_t nvte_get_grouped_gemm_setup_workspace_size(size_t num_tensors); */ void nvte_convert_int32_to_int64(const int32_t *src, int64_t *dst, size_t n, cudaStream_t stream); +/*! \brief Convert int32 array to int64 while scaling each element by a multiplier. + * + * Computes dst[i] = (int64_t)src[i] * multiplier for each i in [0, n). + * CUDA-graph safe (no host-device synchronization). + * + * \param[in] src Device pointer to source int32 array. + * \param[out] dst Device pointer to destination int64 array. + * \param[in] n Number of elements. + * \param[in] multiplier Scale factor applied to each element. + * \param[in] stream CUDA stream. + */ +void nvte_convert_int32_to_int64_with_multiplier(const int32_t *src, int64_t *dst, size_t n, + int64_t multiplier, cudaStream_t stream); + +/*! \brief Compute exclusive prefix-sum offsets from per-group first-dimension sizes. + * + * Writes n_groups+1 values to offsets: offsets[0]=0, + * offsets[i] = sum(first_dims[0..i-1] * last_dim) for i in [1, n_groups]. + * This is CUDA-graph safe (no host-device synchronization). + * + * \param[in] first_dims Device pointer to int64 array of length n_groups. + * \param[out] offsets Device pointer to int64 array of length n_groups+1. + * \param[in] n_groups Number of groups. + * \param[in] last_dim Common last dimension (number of columns). + * \param[in] stream CUDA stream. + */ +void nvte_compute_grouped_tensor_offsets(const int64_t *first_dims, int64_t *offsets, + size_t n_groups, int64_t last_dim, cudaStream_t stream); + void nvte_grouped_gemm(const NVTEGroupedTensor A, int transa, const NVTEGroupedTensor B, int transb, const NVTEGroupedTensor C, NVTEGroupedTensor D, const NVTETensor alpha, const NVTETensor beta, NVTETensor workspace_setup, diff --git a/transformer_engine/jax/cpp_extensions/gemm.py b/transformer_engine/jax/cpp_extensions/gemm.py index 515f02af6e..1029600389 100644 --- a/transformer_engine/jax/cpp_extensions/gemm.py +++ b/transformer_engine/jax/cpp_extensions/gemm.py @@ -37,6 +37,7 @@ ScaledTensor1x, ScaledTensor2x, GroupedScaledTensor1x, + GroupedNoScaleTensor, ScalingMode, Quantizer, GroupedQuantizer, @@ -378,6 +379,25 @@ def swizzled_scale(scale_inv, flatten_axis, is_colwise): return swizzled.reshape(original_shape) +def _swizzle_grouped_scale(scale_inv, scale_2d_shape, is_colwise): + """Swizzle a 1D grouped scale_inv buffer using full-tensor swizzle. + + The grouped scale_inv is 1D (worst-case padded). The meaningful prefix has size + equal to prod(scale_2d_shape). We reshape that prefix to 2D, swizzle it, and + write it back, leaving any trailing padding untouched. + """ + useful_size = math.prod(scale_2d_shape) + if useful_size == scale_inv.shape[0]: + # No trailing padding — reshape, swizzle, flatten. + return swizzled_scale(scale_inv.reshape(scale_2d_shape), 1, is_colwise).reshape( + scale_inv.shape + ) + # Split meaningful prefix from trailing padding, swizzle prefix only. 
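+    # Illustrative (assumed sizes): a scale_inv of length 1152 with scale_2d_shape=(8, 128)
+    # has useful_size=1024, so entries [0:1024] are swizzled and the 128 trailing padding
+    # entries are carried over unchanged.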
+ prefix = scale_inv[:useful_size].reshape(scale_2d_shape) + swizzled = swizzled_scale(prefix, 1, is_colwise).reshape((useful_size,)) + return jnp.concatenate([swizzled, scale_inv[useful_size:]]) + + def get_lhs_axis_boundary(lhs_cdims, is_transposed): """Get the axis boundary for the LHS operand.""" return max(lhs_cdims) + 1 if is_transposed else min(lhs_cdims) @@ -1392,17 +1412,47 @@ def impl( register_primitive(GroupedGemmCopySizesPrimitive) +def _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups: int, +) -> None: + """Assert that all non-empty *_dims arrays have exactly num_groups elements. + + rhs_first_dims / rhs_last_dims describe the ragged contracting K dimension. + K totals need not fill the entire buffer (padding is allowed), so only the + array length is checked, not the per-group sum. + """ + for name, aval in [ + ("lhs_first_dims", lhs_first_dims_aval), + ("lhs_last_dims", lhs_last_dims_aval), + ("out_first_dims", out_first_dims_aval), + ("out_last_dims", out_last_dims_aval), + ("rhs_first_dims", rhs_first_dims_aval), + ("rhs_last_dims", rhs_last_dims_aval), + ]: + if aval.size > 0: + assert ( + aval.size == num_groups + ), f"grouped GEMM {name} has size {aval.size}, expected num_groups={num_groups}" + + class GroupedGemmPrimitive(BasePrimitive): """ Primitive for grouped GEMM using nvte_multi_tensor_gemm (supports all scaling modes) or nvte_grouped_gemm (supporting BF16). """ - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, group_offset, unused_placeholder name = "te_grouped_gemm_ffi" - # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, group_sizes, alpha, beta + # args = lhs_data, lhs_scale_inv, rhs_data, rhs_scale_inv, bias, + # lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, + # out_first_dims, out_last_dims, alpha, beta name_graph_safe = "te_grouped_gemm_v2_ffi" multiple_results = True - impl_static_args = (8, 9, 10, 11, 12, 13, 14, 15, 16, 17, 18) + impl_static_args = (13, 14, 15, 16, 17, 18, 19, 20, 21, 22) inner_primitive = None outer_primitive = None @@ -1413,53 +1463,103 @@ def abstract( rhs_data_aval, rhs_scale_inv_aval, bias_aval, - group_sizes_aval, + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, *additional_args, # group_offset_aval, unused_placeholder OR alpha_aval, beta_aval - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + rhs_group_axis, ): """ Grouped GEMM operation. 
Args: - lhs_data: Left-hand side input matrix data, 1D flattened array + lhs_data: Left-hand side input matrix data, N-D array lhs_scale_inv: Left-hand side input scale_inv matrix, 1D flattened array - rhs_data: Right-hand side input matrix data, 1D flattened array + rhs_data: Right-hand side input matrix data, N-D array rhs_scale_inv: Right-hand side input scale_inv matrix, 1D flattened array bias: Bias matrix of shape (G, N) - group_sizes: 1D array containing the sizes of each group + lhs_first_dims: (G,) int32 if lhs first-dim is ragged, else empty (0,) sentinel + rhs_first_dims: (G,) int32 if rhs first-dim is ragged (wgrad), else empty (0,) sentinel + out_first_dims: (G,) int32 if output first-dim is ragged, else empty (0,) sentinel additional_args: Either * group_offsets: 1D array containing offsets for each group (not yet implemented) OR * alpha: 1D array of shape (G,) containing alpha values for each group * beta: 1D array of shape (G,) containing beta values for each group - M: Number of rows in the output matrix - N: Number of columns in the output matrix - K: Number of columns in the left-hand side matrix lhs_is_trans: Boolean indicating if the left-hand side matrix is transposed rhs_is_trans: Boolean indicating if the right-hand side matrix is transposed scaling_mode: Scaling mode for the GEMM operations out_dtype: Data type of the output tensors has_bias: Boolean indicating if bias tensors are provided - is_grouped_dense_wgrad: Boolean indicating if this is a grouped dense wgrad operation - where both lhs and rhs are 2D matrices and output is (G, M, N) + lhs_axis_boundary: Axis split point for lhs N-D → 2D flattening + rhs_axis_boundary: Axis split point for rhs N-D → 2D flattening + rhs_group_axis: Batch-group axis of rhs to exclude from output non-contracting dims Returns: A jnp.ndarray containing the result of the grouped GEMM operation """ - del lhs_data_aval, rhs_data_aval, bias_aval - del K, lhs_is_trans, rhs_is_trans, has_bias, use_async_d2h_group_sizes + del bias_aval + del has_bias, use_async_d2h_group_sizes + + num_groups = ( + lhs_first_dims_aval.size + or lhs_last_dims_aval.size + or rhs_first_dims_aval.size + or rhs_last_dims_aval.size + or out_first_dims_aval.size + or out_last_dims_aval.size + or additional_args[0].size # alpha (V2) has size G; group_offset (legacy) has size >= 1 + ) + + _assert_grouped_gemm_dims_shapes( + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + num_groups, + ) + + # Derive output shape from N-D buffer shapes using axis_boundary. + lhs_shape = lhs_data_aval.shape + rhs_shape = rhs_data_aval.shape - num_groups = group_sizes_aval.size + # Non-contracting dims for lhs + if lhs_is_trans: + lhs_non_contracting = lhs_shape[lhs_axis_boundary:] + else: + lhs_non_contracting = lhs_shape[:lhs_axis_boundary] + + # Non-contracting dims for rhs (excluding batch-group axis where applicable) + if rhs_is_trans: + rhs_non_contracting = tuple( + rhs_shape[d] + for d in range(rhs_axis_boundary) + if rhs_group_axis is None or d != rhs_group_axis + ) + else: + rhs_non_contracting = rhs_shape[rhs_axis_boundary:] + + # K validation is intentionally skipped: per-group K values may not fill the + # entire buffer (padding is allowed), so sum(rhs_*_dims) != buffer K is acceptable. + if rhs_first_dims_aval.size > 0 or rhs_last_dims_aval.size > 0: + # Wgrad case: rhs has ragged contracting K dimension → output gets G prefix. 
+ out_shape = (num_groups, *lhs_non_contracting, *rhs_non_contracting) + else: + out_shape = (*lhs_non_contracting, *rhs_non_contracting) cublas_workspace_aval = jax.core.ShapedArray( shape=( @@ -1470,9 +1570,6 @@ def abstract( dtype=jnp.uint8, ) - out_shape = (M, N) - if is_grouped_dense_wgrad: - out_shape = (num_groups, M, N) out_aval = jax.core.ShapedArray(shape=out_shape, dtype=out_dtype) if use_v2_ffi: @@ -1480,7 +1577,24 @@ def abstract( shape=(get_grouped_gemm_setup_workspace_size(num_groups),), dtype=jnp.uint8 ) # Temporary buffer for int32 -> int64 conversion of group_sizes on device. - int64_workspace_size = num_groups * jnp.dtype(jnp.int64).itemsize + # Each non-empty *_dims buffer needs its own slot of num_groups int64 elements so that + # make_grouped_tensor can write to a distinct region per ragged dimension. Allocate + # exactly as many slots as there are non-empty buffers (minimum 1 to avoid zero-size). + num_ragged_dim_buffers = sum( + 1 + for aval in [ + lhs_first_dims_aval, + lhs_last_dims_aval, + rhs_first_dims_aval, + rhs_last_dims_aval, + out_first_dims_aval, + out_last_dims_aval, + ] + if aval.size > 0 + ) + int64_workspace_size = ( + max(num_ragged_dim_buffers, 1) * num_groups * jnp.dtype(jnp.int64).itemsize + ) int64_workspace_aval = jax.core.ShapedArray( shape=(int64_workspace_size,), dtype=jnp.uint8 ) @@ -1531,9 +1645,11 @@ def _compute_cublas_workspace_size( workspace_size += lhs_scale_inv_aval.size * tensor_scaling_sinv_aligment workspace_size += rhs_scale_inv_aval.size * tensor_scaling_sinv_aligment elif scaling_mode == ScalingMode.MXFP8_1D_SCALING.value: - # We also pad scale_inv swizzle buffers size for 256 bytes alignment. - workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding - workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + if not use_v2_ffi: + # V1 needs workspace for per-group swizzle output buffers. + # V2: scales are pre-swizzled in JAX, no extra workspace needed. 
+ workspace_size += lhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding + workspace_size += rhs_scale_inv_aval.size + mxfp8_scaling_sinv_alignment_padding return workspace_size @staticmethod @@ -1545,45 +1661,40 @@ def outer_abstract(*args, **kwargs): def lowering( ctx, *args, - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + rhs_group_axis, ): - del out_dtype + del out_dtype, rhs_group_axis # Python-only; not forwarded to C++ if use_v2_ffi: ffi_name = GroupedGemmPrimitive.name_graph_safe return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, ) ffi_name = GroupedGemmPrimitive.name return jax.ffi.ffi_lowering(ffi_name)( ctx, *args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, ) @staticmethod @@ -1593,20 +1704,24 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, # group_offset (non-graph-safe) OR alpha (graph-safe) additional_arg_1, # unused placeholder (non-graph-safe) OR beta (graph-safe) - M, - N, - K, lhs_is_trans, rhs_is_trans, scaling_mode, out_dtype, has_bias, - is_grouped_dense_wgrad, use_async_d2h_group_sizes, use_v2_ffi, + lhs_axis_boundary, + rhs_axis_boundary, + rhs_group_axis, ): if GroupedGemmPrimitive.inner_primitive is None: raise RuntimeError("GroupedGemmPrimitive.inner_primitive has not been registered") @@ -1620,19 +1735,23 @@ def impl( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, *additional_args, - M=M, - N=N, - K=K, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + rhs_group_axis=rhs_group_axis, ) return (out,) @@ -1926,23 +2045,65 @@ def _can_use_v2_grouped_gemm( scaling_mode: ScalingMode, dtype: jnp.dtype, has_bias: bool, + lhs_shape=None, + rhs_shape=None, + lhs_axis_boundary=None, + rhs_axis_boundary=None, ) -> bool: """Determine whether the cuda-graphable grouped GEMM implementation can be used based on the input parameters.""" - # Use the cuda-graphable path for plain BF16 non-quantized inputs; fall back to the legacy - # nvte_multi_tensor_gemm path for all other cases (FP8, MXFP8, etc.) to stay - # feature-compatible with the main branch. + # Use the cuda-graphable path for plain BF16 non-quantized inputs and MXFP8; fall back to + # the legacy nvte_multi_tensor_gemm path for all other cases (tensor-scaled FP8, etc.). # Bias can be supported in a kernel or in pure-JAX in the future. 
if not _v2_grouped_gemm_available: return False - return scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16 and not has_bias + # nvte_grouped_gemm (the v2 kernel) requires SM100+ (Blackwell or newer). + # Fall back to the v1 path on SM90 (Hopper) and older architectures. + if get_device_compute_capability(0) < 100: + return False + + if has_bias: + return False + + if scaling_mode == ScalingMode.NO_SCALING and dtype == jnp.bfloat16: + return True + if scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # V2 MXFP8 requires that the total first dimension of both operands (up to + # axis_boundary) is divisible by 128, matching the quantize V2 kernel requirement. + # Individual group sizes must also be 128-aligned (dynamic constraint). + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_first_dim = math.prod(lhs_shape[:lhs_axis_boundary]) + if lhs_first_dim % 128 != 0: + return False + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_first_dim = math.prod(rhs_shape[:rhs_axis_boundary]) + if rhs_first_dim % 128 != 0: + return False + # V2 MXFP8 also requires that the "last" dimension (after axis_boundary) of both + # operands is a multiple of 128. The V2 GEMM setup kernel computes per-group + # scale pointers as ``data_offset / 32``, which equals ``K_blocks * last_dim``. + # The quantize kernel, however, pads the colwise scale stride to + # ``ceil(last_dim / 128) * 128``, making per-group padded scale larger than + # ``K_blocks * last_dim`` when ``last_dim`` is not 128-aligned. This causes + # adjacent groups' scales to overlap in the flat buffer. Fall back to V1 (which + # swizzles per-group scales individually) when the condition is not met. + if lhs_shape is not None and lhs_axis_boundary is not None: + lhs_last_dim = math.prod(lhs_shape[lhs_axis_boundary:]) + if lhs_last_dim % 128 != 0: + return False + if rhs_shape is not None and rhs_axis_boundary is not None: + rhs_last_dim = math.prod(rhs_shape[rhs_axis_boundary:]) + if rhs_last_dim % 128 != 0: + return False + return True + + return False def grouped_gemm( - lhs: Union[jnp.ndarray, GroupedScaledTensor1x], - rhs: Union[jnp.ndarray, GroupedScaledTensor1x], - group_sizes: jnp.ndarray, + lhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], + rhs: Union[GroupedNoScaleTensor, GroupedScaledTensor1x], contracting_dims: Tuple[Sequence[int], Sequence[int]] = ((1,), (2,)), bias: jnp.ndarray = None, precision: jax.lax.Precision = jax.lax.Precision.DEFAULT, @@ -1955,9 +2116,8 @@ def grouped_gemm( Grouped GEMM operation. 
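+    Example (illustrative, mirroring the unit tests): for a BF16 lhs of shape (M, K) and
+    rhs of shape (G, K, N), wrap the plain arrays before the call:
+
+        lhs_tensor = GroupedNoScaleTensor(
+            data=lhs, first_dims=group_sizes, last_dims=None, group_axis=0,
+            original_shape=lhs.shape,
+        )
+        rhs_tensor = GroupedNoScaleTensor(
+            data=rhs, first_dims=None, last_dims=None, group_axis=0,
+            original_shape=rhs.shape,
+        )
+        out = grouped_gemm(lhs_tensor, rhs_tensor, contracting_dims=((1,), (1,)))
+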
Args: - lhs: Left-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - rhs: Right-hand side input matrix, can be a jnp.ndarray or GroupedScaledTensor1x - group_sizes: 1D array containing the sizes of each group + lhs: Left-hand side input matrix, GroupedNoScaleTensor or GroupedScaledTensor1x + rhs: Right-hand side input matrix, GroupedNoScaleTensor or GroupedScaledTensor1x contracting_dims: Tuple of two sequences representing the contracting dimensions bias: Bias tensor of shape (G, N) precision: JAX precision for the GEMM operation @@ -1967,49 +2127,74 @@ def grouped_gemm( Returns: A jnp.ndarray containing the result of the grouped GEMM operation - - Note: - Tested shapes: - lhs: [M, K] or [K, N] - rhs: [G, N, K] or [G, K, N] or [G * K, N] or [N, G * K] """ # TODO(Phuong): implement the precision del precision - if isinstance(lhs, jnp.ndarray): - if not isinstance(rhs, jnp.ndarray): - raise TypeError( - f"Expected rhs to be jnp.ndarray when lhs is jnp.ndarray, but got type={type(rhs)}" - ) - out_dtype = lhs.dtype - lhs_shape = lhs.shape - rhs_shape = rhs.shape - lhs_data = lhs - rhs_data = rhs - lhs_scale_inv = rhs_scale_inv = jnp.empty((0,), jnp.float32) + empty_gs = jnp.empty((0,), jnp.int32) + + # Extract data, dims, and metadata from tensor objects. + if isinstance(lhs, GroupedNoScaleTensor): + lhs_data = lhs.data + lhs_shape = lhs.original_shape + lhs_scale_inv = jnp.empty((0,), jnp.float32) scaling_mode = ScalingMode.NO_SCALING + out_dtype = lhs.data.dtype + lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs + lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs + rhs_group_axis = getattr(rhs, "group_axis", 0) elif isinstance(lhs, GroupedScaledTensor1x): - if not isinstance(rhs, GroupedScaledTensor1x): - raise TypeError( - "Expected rhs to be GroupedScaledTensor1x when lhs is GroupedScaledTensor1x, but" - f" got type={type(rhs)}" - ) - out_dtype = lhs.dq_dtype lhs_shape = lhs.original_shape - rhs_shape = rhs.original_shape - lhs_data = lhs.data - rhs_data = rhs.data + lhs_data = lhs.data.reshape(lhs_shape) lhs_scale_inv = lhs.scale_inv + scaling_mode = lhs.scaling_mode + out_dtype = lhs.dq_dtype + lhs_first_dims = lhs.first_dims if lhs.first_dims is not None else empty_gs + lhs_last_dims = lhs.last_dims if lhs.last_dims is not None else empty_gs + rhs_group_axis = getattr(rhs, "group_axis", 0) + else: + raise TypeError( + f"lhs must be GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(lhs)}" + ) + + if isinstance(rhs, GroupedNoScaleTensor): + rhs_data = rhs.data + rhs_shape = rhs.original_shape + rhs_scale_inv = jnp.empty((0,), jnp.float32) + rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs + rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs + elif isinstance(rhs, GroupedScaledTensor1x): + rhs_shape = rhs.original_shape + rhs_data = rhs.data.reshape(rhs_shape) rhs_scale_inv = rhs.scale_inv - if lhs.scaling_mode != rhs.scaling_mode: + rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs + rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs + if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode: raise ValueError( f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode}," f" rhs.scaling_mode={rhs.scaling_mode}" ) - scaling_mode = lhs.scaling_mode + if isinstance(lhs, GroupedScaledTensor1x): + scaling_mode = lhs.scaling_mode + else: + raise TypeError( + f"rhs must be 
GroupedNoScaleTensor or GroupedScaledTensor1x, got type={type(rhs)}" + ) + + # Infer output dims from which operand has the ragged non-contracting dim. + if rhs_first_dims.size > 0 or rhs_last_dims.size > 0: + # Wgrad: rhs contracting dim is ragged → output is uniform (G prefix from num_groups) + out_first_dims = empty_gs + out_last_dims = empty_gs + elif lhs_first_dims.size > 0: + out_first_dims = lhs_first_dims + out_last_dims = empty_gs + elif lhs_last_dims.size > 0: + out_first_dims = empty_gs + out_last_dims = lhs_last_dims else: - raise TypeError("Unsupported lhs type object!") + out_first_dims = out_last_dims = empty_gs out_dtype = preferred_element_type or out_dtype @@ -2018,26 +2203,10 @@ def grouped_gemm( lhs_is_trans = lhs_contract_dim[-1] != len(lhs_shape) - 1 lhs_flatten_axis = len(lhs_contract_dim) * (1 if lhs_is_trans else -1) - # rhs_shape [G, K, N] - rhs_is_trans = rhs_contract_dim[0] != 1 + # rhs_is_trans: K is the last dim of rhs (i.e., rhs is in "T" layout). + rhs_is_trans = rhs_contract_dim[-1] == len(rhs_shape) - 1 rhs_flatten_axis = -len(rhs_contract_dim) if rhs_is_trans else 1 + len(rhs_contract_dim) - is_grouped_dense_wgrad = False - if len(rhs_shape) == 2: - rhs_is_trans = rhs_contract_dim[0] != 0 - is_grouped_dense_wgrad = True - - # TODO(Hua): thses are for fp16 dense wgrad, any better way to handle this? - if ( - is_grouped_dense_wgrad - and not isinstance(lhs, ScaledTensor) - and not isinstance(rhs, ScaledTensor) - ): - lhs_is_trans = True - rhs_is_trans = False - lhs_flatten_axis = 1 - rhs_flatten_axis = 1 - if ( not isinstance(lhs, ScaledTensor) and not isinstance(rhs, ScaledTensor) @@ -2068,12 +2237,24 @@ def grouped_gemm( quantizer_set.kernel.q_layout = ( QuantizeLayout.ROWWISE if rhs_is_rowwise else QuantizeLayout.COLWISE ) - lhs_q = grouped_quantize(lhs, quantizer_set.x, group_sizes, lhs_flatten_axis) + active_group_sizes = next( + ( + gs + for gs in [lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims] + if gs.size > 0 + ), + empty_gs, + ) + lhs_input_data = lhs.data if isinstance(lhs, GroupedNoScaleTensor) else lhs_data + rhs_input_data = rhs.data if isinstance(rhs, GroupedNoScaleTensor) else rhs_data + lhs_q = grouped_quantize( + lhs_input_data, quantizer_set.x, active_group_sizes, lhs_flatten_axis + ) rhs_q = grouped_quantize( - rhs, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis + rhs_input_data, quantizer_set.kernel, group_sizes=None, flatten_axis=rhs_flatten_axis ) - lhs_data = lhs_q.data - rhs_data = rhs_q.data + lhs_data = lhs_q.data.reshape(lhs_q.original_shape) + rhs_data = rhs_q.data.reshape(rhs_q.original_shape) lhs_scale_inv = lhs_q.scale_inv rhs_scale_inv = rhs_q.scale_inv lhs_shape = lhs_q.original_shape @@ -2105,38 +2286,48 @@ def grouped_gemm( lhs_contract_dim = tuple((lhs_ndim - 1 - i) % lhs_ndim for i in lhs_contract_dim) if rhs_layout_is_T: # For rhs [G, K, N], need to exclude the G dim from contract_dim - if group_sizes.size == rhs_shape[0]: + if ( + lhs_first_dims.size > 0 or lhs_last_dims.size > 0 + ): # fwd/dgrad: rhs has G as first dim rhs_contract_dim = tuple( (rhs_ndim - 1 - i) % (rhs_ndim - 1) + 1 for i in rhs_contract_dim ) else: rhs_contract_dim = tuple((rhs_ndim - 1 - i) % rhs_ndim for i in rhs_contract_dim) - # Calling GroupedGEMM Custom Call - K_lhs = math.prod(lhs_shape[i] for i in lhs_contract_dim) - K_rhs = math.prod(rhs_shape[i] for i in rhs_contract_dim) - if K_lhs != K_rhs: + # Compute N-D axis boundaries from final (post-adjustment) contracting dims. 
+ lhs_axis_boundary = get_lhs_axis_boundary(lhs_contract_dim, lhs_is_trans) + rhs_axis_boundary = get_rhs_axis_boundary(rhs_contract_dim, rhs_is_trans) + + num_gemms = ( + lhs_first_dims.size + or lhs_last_dims.size + or rhs_first_dims.size + or rhs_last_dims.size + or out_first_dims.size + or out_last_dims.size + ) + if num_gemms == 0: raise ValueError( - f"Mismatched contracting dimensions: K_lhs={K_lhs}, K_rhs={K_rhs} (from" - f" lhs_shape={lhs_shape}, rhs_shape={rhs_shape})" + "grouped_gemm requires at least one non-empty dimension array. " + "Ensure lhs or rhs tensor objects carry first_dims or last_dims." ) - M = math.prod(_calculate_remaining_shape(lhs_shape, lhs_contract_dim)) - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)[1:]) # Exclude G - - if is_grouped_dense_wgrad: - N = math.prod(_calculate_remaining_shape(rhs_shape, rhs_contract_dim)) - else: - if group_sizes.size != rhs_shape[0]: - raise ValueError( - "Expected group_sizes.size == rhs_shape[0], but got" - f" group_sizes.size={group_sizes.size}, rhs_shape[0]={rhs_shape[0]}" - ) has_bias = bias is not None - if has_bias and bias.shape != (group_sizes.size, N): - raise ValueError( - f"Expected bias.shape=({group_sizes.size}, {N}), but got bias.shape={bias.shape}" - ) + if has_bias: + # Compute N from rhs non-contracting dims. + if rhs_is_trans: + N_dim = math.prod( + rhs_data.shape[d] + for d in range(rhs_axis_boundary) + if rhs_group_axis is None or d != rhs_group_axis + ) + else: + N_dim = math.prod(rhs_data.shape[rhs_axis_boundary:]) + assert bias.shape == ( + num_gemms, + N_dim, + ), f"bias shape {bias.shape} does not match expected shape {(num_gemms, N_dim)}" bias = jnp.empty((), jnp.float32) if bias is None else bias if group_offset is not None: @@ -2146,9 +2337,48 @@ def grouped_gemm( " and padded with zeros to not affect the result of the MoE block." ) - use_v2_ffi = _can_use_v2_grouped_gemm(scaling_mode, lhs_data.dtype, has_bias) + use_v2_ffi = _can_use_v2_grouped_gemm( + scaling_mode, + lhs_data.dtype, + has_bias, + lhs_shape=lhs_data.shape, + rhs_shape=rhs_data.shape, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + ) + if use_v2_ffi and scaling_mode == ScalingMode.MXFP8_1D_SCALING: + # Pre-swizzle full scale tensors in JAX (CUDA-graph safe). + # Grouped scale_inv is 1D (flat, worst-case padded). When all group sizes are + # multiples of 128 (V2 requirement), the per-group scales are contiguous with no + # inter-group padding gaps. We reshape the meaningful prefix to 2D, swizzle, and + # write it back into the original 1D buffer (extra trailing zeros stay untouched). + lhs_is_colwise = lhs_is_trans + rhs_is_colwise = not rhs_is_trans + lhs_scale_shape = scaling_mode.get_scale_shape( + lhs_data.shape, + is_colwise=lhs_is_colwise, + is_padded=True, + flatten_axis=lhs_axis_boundary, + ) + rhs_scale_shape = scaling_mode.get_scale_shape( + rhs_data.shape, + is_colwise=rhs_is_colwise, + is_padded=True, + flatten_axis=rhs_axis_boundary, + ) + # get_scale_shape may return a multi-dim shape (e.g. (8, 4, 128) for a 3D + # input), but _swizzle_grouped_scale needs a flat 2D shape (rows, cols) where + # cols = n_block_y (last dim) and rows = prod(all other dims). This correctly + # flattens the group/K-block axes into a single row dimension so the swizzle + # pattern operates on the full (K-blocks-across-groups × N-blocks) matrix. 
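+        # Illustrative: a padded scale shape of (8, 4, 128) flattens to (32, 128) here.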
+ lhs_n_block_y = lhs_scale_shape[-1] + rhs_n_block_y = rhs_scale_shape[-1] + lhs_scale_2d = (math.prod(lhs_scale_shape) // lhs_n_block_y, lhs_n_block_y) + rhs_scale_2d = (math.prod(rhs_scale_shape) // rhs_n_block_y, rhs_n_block_y) + lhs_scale_inv = _swizzle_grouped_scale(lhs_scale_inv, lhs_scale_2d, lhs_is_colwise) + rhs_scale_inv = _swizzle_grouped_scale(rhs_scale_inv, rhs_scale_2d, rhs_is_colwise) + if use_v2_ffi: - num_gemms = group_sizes.shape[0] additional_arg_0 = jnp.ones((num_gemms,), jnp.float32) # alpha additional_arg_1 = jnp.zeros((num_gemms,), jnp.float32) # beta else: @@ -2161,19 +2391,23 @@ def grouped_gemm( rhs_data, rhs_scale_inv, bias, - group_sizes, + lhs_first_dims, + lhs_last_dims, + rhs_first_dims, + rhs_last_dims, + out_first_dims, + out_last_dims, additional_arg_0, additional_arg_1, - M=M, - N=N, - K=K_lhs, lhs_is_trans=lhs_is_trans, rhs_is_trans=rhs_is_trans, scaling_mode=scaling_mode.value, out_dtype=out_dtype, has_bias=has_bias, - is_grouped_dense_wgrad=is_grouped_dense_wgrad, use_async_d2h_group_sizes=use_async_d2h_group_sizes, use_v2_ffi=use_v2_ffi, + lhs_axis_boundary=lhs_axis_boundary, + rhs_axis_boundary=rhs_axis_boundary, + rhs_group_axis=rhs_group_axis, ) return out diff --git a/transformer_engine/jax/cpp_extensions/quantization.py b/transformer_engine/jax/cpp_extensions/quantization.py index bf4e833c89..a2da6b8830 100644 --- a/transformer_engine/jax/cpp_extensions/quantization.py +++ b/transformer_engine/jax/cpp_extensions/quantization.py @@ -51,6 +51,53 @@ __all__ = ["quantize", "quantize_dbias", "grouped_quantize", "grouped_dbias"] +def _build_scale_spec(x_spec, scale_shape, mesh): + """Build a PartitionSpec for the MXFP8 scale tensor compatible with its shape. + + The scale tensor has smaller dimensions than the data tensor (each dimension + divided by the MXFP8 block size). This function ensures that we only shard a + scale dimension by a mesh axis (or tuple of axes) if scale_shape[i] is + divisible by the total axis size. If not, a ValueError is raised with a + helpful diagnostic message. + """ + result = [] + for axis, scale_dim in zip(x_spec, scale_shape): + if axis is None: + result.append(None) + elif isinstance(axis, str): + axis_size = mesh.shape.get(axis, 1) + if scale_dim % axis_size == 0: + result.append(axis) + else: + raise ValueError( + f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " + f"by mesh axis '{axis}' of size {axis_size}: " + f"scale dim {scale_dim} is not divisible by {axis_size}. " + "The data tensor's sharding is incompatible with the MXFP8 block " + "size along this axis. Try reducing expert parallelism (EP) so that " + "EP divides the scale dimension, or increase the tensor size." + ) + elif isinstance(axis, (tuple, list)): + # Multi-axis sharding (e.g. ('fsdp', 'expert')): check total combined size. + total_size = 1 + for a in axis: + total_size *= mesh.shape.get(a, 1) + if scale_dim % total_size == 0: + result.append(axis) + else: + raise ValueError( + f"Cannot partition MXFP8 scale tensor (shape={tuple(scale_shape)}) " + f"by mesh axes {tuple(axis)} of combined size {total_size}: " + f"scale dim {scale_dim} is not divisible by {total_size}. " + "The data tensor's sharding is incompatible with the MXFP8 block " + "size along this axis. Try reducing parallelism or increasing the " + "tensor size." 
+ ) + else: + result.append(None) + return tuple(result) + + class BaseDBiasQuantizePrimitive(BasePrimitive): """ Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias @@ -446,7 +493,13 @@ def infer_sharding_from_operands( scale_inv_spec = colwise_scale_inv_spec = (None,) if ScalingMode(scaling_mode).is_block_scaling: - scale_inv_spec = x_spec + rowwise_scale_shape, _ = ScalingMode(scaling_mode).get_scale_shape_2x( + arg_infos[0].shape, + is_padded=False, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + scale_inv_spec = _build_scale_spec(x_spec, rowwise_scale_shape, mesh) if q_layout.has_colwise: if ( @@ -528,7 +581,13 @@ def partition( scale_inv_spec = colwise_scale_inv_spec = (None,) if ScalingMode(scaling_mode).is_block_scaling: - scale_inv_spec = x_spec + rowwise_scale_shape, _ = ScalingMode(scaling_mode).get_scale_shape_2x( + arg_infos[0].shape, + is_padded=False, + flatten_axis=flatten_axis, + broadcast_2d_scale_shape_to_1d=True, + ) + scale_inv_spec = _build_scale_spec(x_spec, rowwise_scale_shape, mesh) if q_layout.has_colwise: if ( @@ -993,7 +1052,8 @@ class GroupedQuantizePrimitive(BasePrimitive): Cast Primitive wrapping nvte_quantize and nvte_quantize_dbias """ - name = "te_grouped_quantize_ffi" + name = "te_grouped_quantize_ffi" # V1: fallback path (supports all shapes, not CUDA-graph safe) + name_v2 = "te_grouped_quantize_v2_ffi" # V2: MXFP8, CUDA-graph safe multiple_results = True impl_static_args = ( 3, @@ -1006,6 +1066,56 @@ class GroupedQuantizePrimitive(BasePrimitive): inner_primitive = None outer_primitive = None + @staticmethod + def _use_v2_kernel(scaling_mode, x_shape, flatten_axis): + """Return True when the V2 (CUDA-graph-safe) MXFP8 kernel can be used. + + V2 requires: + 1. The total first logical dimension (product of x_shape up to flatten_axis) + is divisible by 128. + 2. For multi-dim group tensors (eff > 1, e.g., kernel shape G×K×N), the + per-group row count non_group_m = prod(x_shape[1:eff]) must also be + divisible by 128 (because group_sizes[i] counts slices, not rows, and + actual rows per group = group_sizes[i] * non_group_m). + 3. For lhs-style tensors (eff == 1, shape M×K), individual group sizes must + be 128-aligned -- this is a dynamic constraint assumed by the caller. + + Falls back to V1 when constraints are not met. V1 supports arbitrary shapes + but performs a D2H copy of group_sizes (not CUDA-graph safe). + """ + if ScalingMode(scaling_mode) != ScalingMode.MXFP8_1D_SCALING: + assert False, ( + "V2 grouped quantize kernel currently only supports MXFP8 1D scaling mode, but got" + " scaling_mode {}".format(scaling_mode) + ) + return False + ndim = len(x_shape) + eff = flatten_axis if flatten_axis >= 0 else flatten_axis + ndim + total_first_dim = math.prod(x_shape[:eff]) + if total_first_dim % 128 != 0: + assert False, ( + "V2 grouped quantize kernel requires total first logical dimension (product of" + " x_shape up to flatten_axis) to be divisible by 128, but got shape {} and" + " flatten_axis {} with total_first_dim {}".format( + x_shape, flatten_axis, total_first_dim + ) + ) + return False + # For multi-dim group tensors (e.g., kernel shape G×K×N with eff=2), + # non_group_m = K must also be 128-aligned. 
+ if eff > 1: + non_group_m = math.prod(x_shape[1:eff]) + if non_group_m % 128 != 0: + assert False, ( + "V2 grouped quantize kernel requires non-group dimension (product of" + " x_shape[1:flatten_axis]) to be divisible by 128 for multi-dim group tensors," + " but got shape {} and flatten_axis {} with non_group_m {}".format( + x_shape, flatten_axis, non_group_m + ) + ) + return False + return True + @staticmethod def abstract( x_aval, @@ -1050,7 +1160,16 @@ def abstract( rowwise_scale_inv_shape = (1,) rowwise_out_aval = jax.core.ShapedArray(shape=rowwise_out_shape, dtype=out_dtype) - amax_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2 path: 5th output is int64_workspace laid out as: + # [n_groups int64 group_sizes | n_groups+1 int64 offsets] + # = (2*n_groups + 1) * sizeof(int64_t) bytes stored as uint8. + n_groups = group_sizes_aval.size + fifth_out_aval = jax.core.ShapedArray(shape=((2 * n_groups + 1) * 8,), dtype=jnp.uint8) + else: + # V1 path: 5th output is amax + fifth_out_aval = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) if q_layout.has_colwise: colwise_out_shape = out_shape @@ -1070,7 +1189,7 @@ def abstract( colwise_out_aval, rowwise_scale_inv_aval, colwise_scale_inv_aval, - amax_aval, + fifth_out_aval, ) @staticmethod @@ -1084,9 +1203,17 @@ def outer_abstract(*args, **kwargs): colwise_out, scale_inv, colwise_scale_inv, - updated_amax, + fifth_out, ) = GroupedQuantizePrimitive.abstract(*args, **kwargs) - return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, updated_amax + # When V2 is used, the inner abstract returns int64_workspace as the 5th output. + # The outer interface always presents amax (float32, n_groups) for a consistent API. + scaling_mode = kwargs.get("scaling_mode") + x_aval = args[0] + group_sizes_aval = args[2] + flatten_axis = kwargs.get("flatten_axis") + if GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis): + fifth_out = jax.core.ShapedArray(shape=(group_sizes_aval.size,), dtype=jnp.float32) + return rowwise_out, colwise_out, scale_inv, colwise_scale_inv, fifth_out @staticmethod def lowering( @@ -1111,6 +1238,21 @@ def lowering( assert scale_aval.dtype == jnp.float32 assert group_sizes_aval.dtype == jnp.int32 assert group_axis == 0 + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x_aval.shape, flatten_axis) + if use_v2: + # V2: CUDA-graph safe; scale is passed but ignored by the C++ handler. + # Requires total_first_dim % 128 == 0 (checked above) and all individual + # group sizes % 128 == 0 (dynamic constraint, enforced by the kernel). + return ffi.ffi_lowering(GroupedQuantizePrimitive.name_v2)( + ctx, + x, + scale, + group_sizes, + q_layout=q_layout.value.value, + flatten_axis=flatten_axis, + ) + # V1: supports arbitrary shapes but not CUDA-graph safe (performs D2H copy of group_sizes). + # Used for non-MXFP8 scaling modes and for MXFP8 when total_first_dim % 128 != 0. 
return ffi.ffi_lowering(GroupedQuantizePrimitive.name)( ctx, x, @@ -1142,7 +1284,7 @@ def impl( colwise_out, rowwise_scale_inv, colwise_scale_inv, - updated_amax, + fifth, ) = GroupedQuantizePrimitive.inner_primitive.bind( x, scale, @@ -1154,6 +1296,12 @@ def impl( group_axis=group_axis, scale_dtype=scale_dtype, ) + use_v2 = GroupedQuantizePrimitive._use_v2_kernel(scaling_mode, x.shape, flatten_axis) + if use_v2: + # fifth is int64_workspace; return a dummy zero amax for interface compatibility + updated_amax = jnp.zeros((group_sizes.size,), jnp.float32) + else: + updated_amax = fifth return (rowwise_out, colwise_out, rowwise_scale_inv, colwise_scale_inv, updated_amax) @@ -1203,6 +1351,7 @@ def grouped_quantize( ), f"Only flatten_axis = -1 is supported for now, got {flatten_axis}" group_axis = 0 + ragged_first_dims = group_sizes # None if no explicit group_sizes (kernel case) if group_sizes is None: group_sizes = jnp.ones(x.shape[group_axis], dtype=jnp.int32) @@ -1280,7 +1429,7 @@ def grouped_quantize( q_layout=quantizer.q_layout, data_layout=quantizer.get_data_layout(), flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=ragged_first_dims, original_shape=original_shape, group_axis=group_axis, ) diff --git a/transformer_engine/jax/csrc/extensions.h b/transformer_engine/jax/csrc/extensions.h index 0fe4e99239..c832b4ebb2 100644 --- a/transformer_engine/jax/csrc/extensions.h +++ b/transformer_engine/jax/csrc/extensions.h @@ -55,6 +55,24 @@ struct GemmConfig { bool use_split_accumulator; }; +struct GroupedGemmV2Config { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; + int64_t lhs_axis_boundary; + int64_t rhs_axis_boundary; +}; + +struct GroupedGemmConfig { + bool lhs_is_trans; + bool rhs_is_trans; + JAXX_Scaling_Mode scaling_mode; + bool has_bias; + bool use_async_d2h_group_sizes; + int64_t lhs_axis_boundary; + int64_t rhs_axis_boundary; +}; + inline bool use_fp8(DType type) { return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2; } // Activation @@ -93,6 +111,8 @@ XLA_FFI_DECLARE_HANDLER_SYMBOL(DBiasQuantizeHandler); XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeHandler); +XLA_FFI_DECLARE_HANDLER_SYMBOL(GroupedQuantizeV2Handler); + XLA_FFI_DECLARE_HANDLER_SYMBOL(DequantizeHandler); pybind11::tuple GetDBiasQuantizeWorkspaceSizes(size_t batch_size, size_t hidden_size, @@ -192,6 +212,22 @@ XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( ::xla::ffi::StructMember("rhs_transposed"), ::xla::ffi::StructMember("use_split_accumulator")); +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmV2Config, ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("lhs_axis_boundary"), + ::xla::ffi::StructMember("rhs_axis_boundary")); + +XLA_FFI_REGISTER_STRUCT_ATTR_DECODING( + transformer_engine::jax::GroupedGemmConfig, ::xla::ffi::StructMember("lhs_is_trans"), + ::xla::ffi::StructMember("rhs_is_trans"), + ::xla::ffi::StructMember("scaling_mode"), + ::xla::ffi::StructMember("has_bias"), + ::xla::ffi::StructMember("use_async_d2h_group_sizes"), + ::xla::ffi::StructMember("lhs_axis_boundary"), + ::xla::ffi::StructMember("rhs_axis_boundary")); + // ENUM_ATTR and DICT_ATTR recoding need to be registered in the global namespace XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Scaling_Mode); XLA_FFI_REGISTER_ENUM_ATTR_DECODING(transformer_engine::jax::JAXX_Score_Function); diff --git a/transformer_engine/jax/csrc/extensions/gemm.cpp 
b/transformer_engine/jax/csrc/extensions/gemm.cpp index 2acefa2d30..45625120fd 100644 --- a/transformer_engine/jax/csrc/extensions/gemm.cpp +++ b/transformer_engine/jax/csrc/extensions/gemm.cpp @@ -481,6 +481,8 @@ class JAXX_GroupedTensorWrapper { m_grouped_tensor(other.m_grouped_tensor), m_data_tensor(other.m_data_tensor), m_scale_inv_tensor(other.m_scale_inv_tensor), + m_colwise_data_tensor(other.m_colwise_data_tensor), + m_colwise_scale_inv_tensor(other.m_colwise_scale_inv_tensor), m_sizes_tensor(other.m_sizes_tensor), m_offsets_tensor(other.m_offsets_tensor) { other.m_grouped_tensor = nullptr; @@ -489,6 +491,8 @@ class JAXX_GroupedTensorWrapper { ~JAXX_GroupedTensorWrapper(); void set_rowwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_columnwise(Buffer_Type const &data, std::optional const &scale_inv); + void set_with_gemm_swizzled_scales(bool val); void set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name); // Set only group sizes (no offsets); the setup kernel will compute offsets from sizes. @@ -505,6 +509,8 @@ class JAXX_GroupedTensorWrapper { // Internal tensors. These need to be kept alive as long as the grouped tensor is alive. NVTEBasicTensor m_data_tensor{}; NVTEBasicTensor m_scale_inv_tensor{}; + NVTEBasicTensor m_colwise_data_tensor{}; + NVTEBasicTensor m_colwise_scale_inv_tensor{}; NVTEBasicTensor m_sizes_tensor{}; NVTEBasicTensor m_offsets_tensor{}; @@ -556,6 +562,45 @@ void JAXX_GroupedTensorWrapper::set_rowwise(Buffer_Type const &data, } } +void JAXX_GroupedTensorWrapper::set_columnwise(Buffer_Type const &data, + std::optional const &scale_inv) { + NVTEDType data_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(data.element_type())); + m_colwise_data_tensor = + NVTEBasicTensor{reinterpret_cast(data.untyped_data()), data_dtype, m_data_shape}; + + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseData, + &m_colwise_data_tensor, sizeof(m_colwise_data_tensor)); + + if (scale_inv.has_value()) { + NVTEDType scale_inv_dtype = + static_cast(convert_ffi_datatype_to_te_dtype(scale_inv->element_type())); + NVTEShape logical_scale_shape{}; + if (scale_inv->dimensions().size() == 1) { + logical_scale_shape.ndim = 1; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + } else if (scale_inv->dimensions().size() == 2) { + logical_scale_shape.ndim = 2; + logical_scale_shape.data[0] = scale_inv->dimensions()[0]; + logical_scale_shape.data[1] = scale_inv->dimensions()[1]; + } else { + NVTE_CHECK(false, "Expected 1D or 2D tensor for GEMM columnwise scale_inv but received ndim=", + scale_inv->dimensions().size()); + } + m_colwise_scale_inv_tensor = + NVTEBasicTensor{reinterpret_cast(scale_inv->untyped_data()), scale_inv_dtype, + logical_scale_shape}; + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedColumnwiseScaleInv, + &m_colwise_scale_inv_tensor, sizeof(m_colwise_scale_inv_tensor)); + } +} + +void JAXX_GroupedTensorWrapper::set_with_gemm_swizzled_scales(bool val) { + auto v = static_cast(val); + nvte_set_grouped_tensor_param(m_grouped_tensor, kNVTEGroupedWithGEMMSwizzledScales, &v, + sizeof(v)); +} + void JAXX_GroupedTensorWrapper::set_group_info(Buffer_Type const &group_sizes, Buffer_Type const &group_offsets, NVTEGroupedTensorParam group_sizes_param_name) { @@ -619,137 +664,151 @@ JAXX_GroupedTensorWrapper make_grouped_tensor(Buffer_Type const &data, return std::move(grouped_tensor_wrapper); } -// This FFI is EXPERIMENTAL and subject to 
change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. -Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, - Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type alpha, Buffer_Type beta, - Result_Type output, Result_Type cublas_workspace, - Result_Type setup_workspace, Result_Type int64_workspace, size_t m, - size_t n, size_t k, bool lhs_is_trans, bool rhs_is_trans, - JAXX_Scaling_Mode scaling_mode, bool is_grouped_dense_wgrad) { - // Notes on matrix layouts and transpose: - // Jax uses row-major data_layout, on entering this function, each input matrix pair: - // A: row-major [m, k] for N - [k, m] for T - // B: row-major [k, n] for N - [n, k] for T - // on exiting this function, JAX expect: - // C: row-major with size [m, n]. - // cuBLAS uses column-major data_layout, in this view, each input matrix pair: - // A: column-major with size [k, m] for T - [m, k] for N - // B: column-major with size [n, k] for T - [k, n] for N - // - // If we call cuBLAS GEMM for A * B, the output will be: - // C: column-major with size [m, n] --> row-major with size [n, m]. - // To make the output compatible with JAX, we need to swap A and B in cuBLAS GEMM call. - - // Inputs - auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); - auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); - auto lhs_sinv_ptr = reinterpret_cast(lhs_sinv.untyped_data()); - auto rhs_sinv_ptr = reinterpret_cast(rhs_sinv.untyped_data()); - auto lhs_dtype = convert_ffi_datatype_to_te_dtype(lhs_data.element_type()); - auto rhs_dtype = convert_ffi_datatype_to_te_dtype(rhs_data.element_type()); - auto lhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(lhs_sinv.element_type()); - auto rhs_sinv_dtype = convert_ffi_datatype_to_te_dtype(rhs_sinv.element_type()); - bool has_bias = product(bias.dimensions()) > 0; - auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; - auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; - - // Convert int32 group_sizes to int64 into the dedicated output buffer. - NVTE_CHECK(group_sizes.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); - auto *int64_sizes_ptr = reinterpret_cast(int64_workspace->untyped_data()); - nvte_convert_int32_to_int64(reinterpret_cast(group_sizes.untyped_data()), - int64_sizes_ptr, num_gemms, stream); - - NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING, - "Only non-quantized grouped GEMM is supported in current implementation."); +// V2 variant (NO_SCALING): derives data shape from the XLA buffer directly, converts group_sizes +// int32→int64 per-tensor into a dedicated slot of int64_workspace, and wires first_dims/last_dims. +// int64_offset (in int64 elements) is updated on return to the next available slot so callers can +// thread it through successive make_grouped_tensor calls without aliasing. Bounds are checked +// before each slot is used. Only NO_SCALING is supported by this overload. 
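For readers tracing the workspace accounting, a minimal Python sketch of the slot bookkeeping this comment describes follows; the function and buffer names are illustrative stand-ins, not the C++ API.

def assign_int64_slots(num_gemms, ragged_buffer_names, capacity):
    # Each non-empty *_dims buffer gets its own slot of num_gemms int64 elements;
    # the running offset is what keeps successive make_grouped_tensor calls from aliasing.
    offset = 0
    slots = {}
    for name in ragged_buffer_names:
        assert offset + num_gemms <= capacity, "int64_workspace overflow"
        slots[name] = (offset, offset + num_gemms)  # half-open range, in int64 elements
        offset += num_gemms
    return slots, offset

# Hypothetical call with two ragged buffers and G = 8 groups:
slots, used = assign_int64_slots(8, ["lhs_first_dims", "out_first_dims"], capacity=3 * 8)
# slots == {"lhs_first_dims": (0, 8), "out_first_dims": (8, 16)}, used == 16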
+JAXX_GroupedTensorWrapper make_grouped_tensor( + Buffer_Type const &data, Buffer_Type const &first_dims, Buffer_Type const &last_dims, + int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { + auto dims = data.dimensions(); + NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); + // Flatten dims at axis_boundary to produce a 2D NVTE shape. + // axis_boundary=-1 (default) collapses dims[0..N-2] → rows and keeps dims[N-1] → cols, + // preserving the prior behaviour for output buffers (e.g. [G, K, N] for wgrad). + size_t ab = (axis_boundary < 0) ? dims.size() - 1 : static_cast(axis_boundary); + NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; + JAXX_GroupedTensorWrapper wrapper(JAXX_Scaling_Mode::NO_SCALING, num_gemms, dataShape); + wrapper.set_rowwise(data, std::nullopt); + if (first_dims.element_count() > 0) { + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for first_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(first_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); + int64_offset += num_gemms; + } + if (last_dims.element_count() > 0) { + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for last_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(last_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); + int64_offset += num_gemms; + } + return wrapper; +} - // It is weird that TE/Common GEMM only use colwise for MXFP8 - const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); - const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || - scaling_mode == JAXX_Scaling_Mode::CURRENT_TENSOR_SCALING; - const bool is_mxfp8_scaling = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; - const bool rhs_use_colwise = is_mxfp8_scaling && !rhs_is_trans; - const bool lhs_use_colwise = is_mxfp8_scaling && lhs_is_trans; +// V2 variant with scaling support (MXFP8 or NO_SCALING). Accepts scale_inv buffer and +// use_colwise flag to wire rowwise or columnwise data+scales for the grouped tensor. +// Pre-swizzled scales are indicated via set_with_gemm_swizzled_scales(true). +JAXX_GroupedTensorWrapper make_grouped_tensor( + Buffer_Type const &data, Buffer_Type const &scale_inv, JAXX_Scaling_Mode scaling_mode, + bool use_colwise, Buffer_Type const &first_dims, Buffer_Type const &last_dims, + int64_t *int64_workspace_base, size_t int64_workspace_capacity, size_t &int64_offset, + size_t num_gemms, cudaStream_t stream, int64_t axis_boundary = -1) { + auto dims = data.dimensions(); + NVTE_CHECK(dims.size() >= 2, "grouped GEMM data buffer must be at least 2D."); + size_t ab = (axis_boundary < 0) ? 
dims.size() - 1 : static_cast(axis_boundary); + NVTEShape dataShape{.data = {product(dims, 0, ab), product(dims, ab, dims.size())}, .ndim = 2}; + JAXX_GroupedTensorWrapper wrapper(scaling_mode, num_gemms, dataShape); + + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + if (is_mxfp8 && use_colwise) { + wrapper.set_columnwise(data, scale_inv); + } else if (is_mxfp8) { + wrapper.set_rowwise(data, scale_inv); + } else { + // NO_SCALING: no scale_inv needed + wrapper.set_rowwise(data, std::nullopt); + } + if (is_mxfp8) { + wrapper.set_with_gemm_swizzled_scales(true); + } - // Outputs - auto out_ptr = reinterpret_cast(output->untyped_data()); - auto out_dtype = convert_ffi_datatype_to_te_dtype(output->element_type()); - auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); - // Here we clear the lower 8 bits of the buffer address to ensure the buffer is 256-aligned - auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); - cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); - auto workspace_total_size = product(cublas_workspace->dimensions()); + if (first_dims.element_count() > 0) { + NVTE_CHECK(first_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for first_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(first_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedFirstDims); + int64_offset += num_gemms; + } + if (last_dims.element_count() > 0) { + NVTE_CHECK(last_dims.element_type() == xla::ffi::DataType::S32, "group_sizes must be int32."); + NVTE_CHECK(int64_offset + num_gemms <= int64_workspace_capacity, + "int64_workspace overflow: not enough space for last_dims conversion."); + auto *slot = int64_workspace_base + int64_offset; + nvte_convert_int32_to_int64(reinterpret_cast(last_dims.untyped_data()), slot, + num_gemms, stream); + wrapper.set_group_sizes_only(slot, num_gemms, kNVTEGroupedLastDims); + int64_offset += num_gemms; + } + return wrapper; +} - auto lhs_sinv_size = product(lhs_sinv.dimensions()); - auto rhs_sinv_size = product(rhs_sinv.dimensions()); - const size_t workspace_alignment_padding = 256; - const size_t tensor_scaling_sinv_aligment = 16; - const size_t mxfp8_scaling_sinv_alignment_padding = 256; - auto workspace_size = workspace_total_size - workspace_alignment_padding; - if (is_mxfp8_scaling) { - // For MXFP8 swizzled scale_inv buffers, only the first pointer needs to be with 256B alignment padding. Later pointers are guaranteed to be 256-aligned as the scale_inv shapes are padded by 128x4. - workspace_size -= (lhs_sinv_size + rhs_sinv_size + 2 * mxfp8_scaling_sinv_alignment_padding); - } else if (is_tensor_scaling) { - // For tensor scaling, each matrix has a single scale value, and all scales need to be aligned - // by 16 bytes to meet the requirement of CUDA 12.9.1 and later. - workspace_size -= tensor_scaling_sinv_aligment * (lhs_sinv_size + rhs_sinv_size); +// Returns num_gemms from the first non-empty per-tensor group_sizes buffer, +// falling back to the element count of alpha for the uniform-batch case. 
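As a quick illustration of the axis_boundary flattening that both make_grouped_tensor overloads above perform (a sketch of the rule stated in their comments, not the C++ code itself):

import math

def flatten_at(shape, axis_boundary=-1):
    # Collapse an N-D buffer shape to a 2D NVTE shape:
    # rows = prod(shape[:ab]), cols = prod(shape[ab:]).
    ab = len(shape) - 1 if axis_boundary < 0 else axis_boundary
    return (math.prod(shape[:ab]), math.prod(shape[ab:]))

flatten_at((8, 128, 256))      # default boundary, e.g. a [G, K, N] wgrad output -> (1024, 256)
flatten_at((512, 128), 1)      # a 2D [M, K] operand with axis_boundary=1 stays (512, 128)
flatten_at((4, 64, 128), 2)    # explicit boundary 2 -> (256, 128)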
+size_t grouped_gemm_num_gemms(Buffer_Type const &lhs_first_dims, Buffer_Type const &lhs_last_dims, + Buffer_Type const &rhs_first_dims, Buffer_Type const &rhs_last_dims, + Buffer_Type const &out_first_dims, Buffer_Type const &out_last_dims, + Buffer_Type const &alpha) { + if (lhs_first_dims.element_count() > 0) { + return lhs_first_dims.dimensions()[0]; + } else if (lhs_last_dims.element_count() > 0) { + return lhs_last_dims.dimensions()[0]; + } else if (rhs_first_dims.element_count() > 0) { + return rhs_first_dims.dimensions()[0]; + } else if (rhs_last_dims.element_count() > 0) { + return rhs_last_dims.dimensions()[0]; + } else if (out_first_dims.element_count() > 0) { + return out_first_dims.dimensions()[0]; + } else if (out_last_dims.element_count() > 0) { + return out_last_dims.dimensions()[0]; + } else { + return alpha.element_count(); // uniform batch: no ragged tensor } - auto swizzled_lhs_sinv_ptr = cublas_workspace_ptr + workspace_size; - swizzled_lhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_lhs_sinv_ptr); - auto swizzled_rhs_sinv_ptr = swizzled_lhs_sinv_ptr + lhs_sinv_size; - swizzled_rhs_sinv_ptr = move_ptr_to_next_256B_aligned(swizzled_rhs_sinv_ptr); - auto lhs_scatter_aligned_ptr = swizzled_lhs_sinv_ptr; // Already 256B aligned - auto rhs_scatter_aligned_ptr = lhs_scatter_aligned_ptr + num_gemms * tensor_scaling_sinv_aligment; +} - size_t lhs_dtype_bytes = te_dtype_bytes(lhs_dtype); - size_t rhs_dtype_bytes = te_dtype_bytes(rhs_dtype); - size_t lhs_sinv_dtype_bytes = te_dtype_bytes(lhs_sinv_dtype); - size_t rhs_sinv_dtype_bytes = te_dtype_bytes(rhs_sinv_dtype); - size_t bias_dtype_bytes = te_dtype_bytes(bias_dtype); - size_t out_dtype_bytes = te_dtype_bytes(out_dtype); +} // namespace jax +} // namespace transformer_engine - NVTE_CHECK(lhs_dtype_bytes == rhs_dtype_bytes, "sizeof(lhs_dtype) != sizeof(rhs_dtype)"); - NVTE_CHECK(lhs_sinv_dtype_bytes == rhs_sinv_dtype_bytes, - "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); +namespace transformer_engine { +namespace jax { - size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); - size_t actual_lhs_size = product(lhs_data.dimensions()); - size_t actual_rhs_size = product(rhs_data.dimensions()); - size_t actual_out_size = product(output->dimensions()); - NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", - expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, - "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, - " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, "Unexpected output size! Expect m * n = ", m, - " * ", n, " = ", expected_out_size, ", got ", actual_out_size); - } else { - NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect k * n = ", k, - " * ", n, " = ", expected_rhs_size, ", got ", actual_rhs_size); - NVTE_CHECK(expected_out_size == actual_out_size, - "Unexpected output size! Expect num_gemms * m * n = ", num_gemms, " * ", m, " * ", n, - " = ", expected_out_size, ", got ", actual_out_size); - } +// This FFI is EXPERIMENTAL and subject to change without deprecation, intended for use in JAX's internal implementation of grouped GEMM. 
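The same fallback chain as grouped_gemm_num_gemms above, shown as a short worked example (a Python stand-in, not the C++ helper):

def num_gemms_from(dims_buffers, alpha_count):
    # dims_buffers in the order checked above:
    # lhs_first, lhs_last, rhs_first, rhs_last, out_first, out_last
    for buf in dims_buffers:
        if buf:                 # first non-empty per-tensor group_sizes buffer wins
            return len(buf)
    return alpha_count          # uniform batch: no ragged tensor at all

num_gemms_from([[], [], [64] * 8, [], [], []], alpha_count=8)   # -> 8, from rhs_first_dims
num_gemms_from([[], [], [], [], [], []], alpha_count=4)         # -> 4, from alpha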
+Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, + Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, + Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type alpha, Buffer_Type beta, Result_Type output, + Result_Type cublas_workspace, Result_Type setup_workspace, + Result_Type int64_workspace, GroupedGemmV2Config config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode, lhs_axis_boundary, rhs_axis_boundary] = config; - auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); - bool grad = false; - bool accumulate = false; - bool use_split_accumulator = false; - auto bias_shape = std::vector{has_bias ? n : 0}; - const int arch = cuda::sm_arch(); + NVTE_CHECK(scaling_mode == JAXX_Scaling_Mode::NO_SCALING || + scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING, + "Only NO_SCALING and MXFP8_1D_SCALING are supported in the V2 grouped GEMM."); - if (arch < 100 && is_fp8_gemm) { - NVTE_CHECK(!lhs_is_trans && rhs_is_trans, - "For SM90 or older archs and FP8 input, only NT (row-major) GEMM is supported, ", - "got lhs_is_trans=", lhs_is_trans, ", rhs_is_trans=", rhs_is_trans); - } + const bool is_mxfp8 = scaling_mode == JAXX_Scaling_Mode::MXFP8_1D_SCALING; + size_t num_gemms = grouped_gemm_num_gemms(lhs_first_dims, lhs_last_dims, rhs_first_dims, + rhs_last_dims, out_first_dims, out_last_dims, alpha); + + // Workspaces. + auto setup_workspace_ptr = reinterpret_cast(setup_workspace->untyped_data()); + auto cublas_workspace_ptr = reinterpret_cast(cublas_workspace->untyped_data()); + cublas_workspace_ptr = move_ptr_to_next_256B_aligned(cublas_workspace_ptr); + auto workspace_size = product(cublas_workspace->dimensions()) - 256; TensorWrapper workspace_setup(setup_workspace_ptr, std::vector{product(setup_workspace->dimensions())}, DType::kByte); @@ -763,59 +822,36 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty std::vector{num_gemms}, convert_ffi_datatype_to_te_dtype(beta.element_type())); - if (is_grouped_dense_wgrad) { - NVTE_CHECK(lhs_is_trans && !rhs_is_trans, - "For grouped dense wgrad, only TN GEMM is supported in TE/JAX currently."); - - //// RHS - NVTEShape rhsShape{.data = {k, n}, .ndim = 2}; - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, num_gemms, rhsShape); - rhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - - //// LHS - NVTEShape lhsShape{.data = {k, m}, .ndim = 2}; - lhs_is_trans = true; - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); - - //// OUTPUT - NVTEShape outShape{.data = {num_gemms * m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - - nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, - alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), - workspace_cublas.data(), - nullptr, // config (use defaults) - stream); - - return ffi_with_cuda_error_check(); - } - - // Nominal case for FWD or DGRAD - - //// RHS - NVTEShape rhsShape{.data = {num_gemms * k, n}, .ndim = 2}; - if (rhs_is_trans) { - rhsShape.data[0] = num_gemms * n; - rhsShape.data[1] = k; - } - auto rhs_tensor = make_grouped_tensor(rhs_data, rhs_sinv, 
scaling_mode, num_gemms, rhsShape); - - //// LHS - NVTEShape lhsShape{.data = {m, k}, .ndim = 2}; - if (lhs_is_trans) { - std::swap(lhsShape.data[0], lhsShape.data[1]); - } - auto lhs_tensor = make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, num_gemms, lhsShape); - lhs_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, - lhs_is_trans ? kNVTEGroupedLastDims : kNVTEGroupedFirstDims); - - //// OUTPUT - NVTEShape outShape{.data = {m, n}, .ndim = 2}; - auto out_tensor = make_grouped_tensor(*output, std::nullopt, JAXX_Scaling_Mode::NO_SCALING, - num_gemms, outShape); - out_tensor.set_group_sizes_only(int64_sizes_ptr, num_gemms, kNVTEGroupedFirstDims); + // Build grouped tensors from XLA buffer shapes and group_sizes — no m/n/k derivation needed. + // int64_workspace is partitioned into per-ragged-buffer slots of num_gemms int64 elements each. + // int64_offset is threaded through the three make_grouped_tensor calls so each non-empty *_dims + // buffer gets its own non-aliasing slot; bounds are checked inside make_grouped_tensor. + auto *int64_base = reinterpret_cast(int64_workspace->untyped_data()); + size_t int64_capacity = int64_workspace->element_count() / sizeof(int64_t); + size_t int64_offset = 0; + + // For MXFP8: in JAX, rhs=cuBLAS_A, lhs=cuBLAS_B (swapped). + // Colwise is needed when the operand's contracting dim is NOT the last dim in its layout. + const bool rhs_use_colwise = is_mxfp8 && !rhs_is_trans; + const bool lhs_use_colwise = is_mxfp8 && lhs_is_trans; + + auto rhs_tensor = + is_mxfp8 + ? make_grouped_tensor(rhs_data, rhs_sinv, scaling_mode, rhs_use_colwise, rhs_first_dims, + rhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, rhs_axis_boundary) + : make_grouped_tensor(rhs_data, rhs_first_dims, rhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, rhs_axis_boundary); + auto lhs_tensor = + is_mxfp8 + ? 
make_grouped_tensor(lhs_data, lhs_sinv, scaling_mode, lhs_use_colwise, lhs_first_dims, + lhs_last_dims, int64_base, int64_capacity, int64_offset, num_gemms, + stream, lhs_axis_boundary) + : make_grouped_tensor(lhs_data, lhs_first_dims, lhs_last_dims, int64_base, int64_capacity, + int64_offset, num_gemms, stream, lhs_axis_boundary); + // Output stays NO_SCALING + auto out_tensor = make_grouped_tensor(*output, out_first_dims, out_last_dims, int64_base, + int64_capacity, int64_offset, num_gemms, stream); nvte_grouped_gemm(rhs_tensor, rhs_is_trans, lhs_tensor, lhs_is_trans, nullptr, out_tensor, alpha_tensor.data(), beta_tensor.data(), workspace_setup.data(), @@ -829,33 +865,35 @@ Error_Type GroupedGemmV2FFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Ty XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmV2Handler, GroupedGemmV2FFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data + .Arg() // lhs_data (2D) .Arg() // lhs_sinv - .Arg() // rhs_data + .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes (int32) + .Arg() // lhs_first_dims (G,) or empty (0,) + .Arg() // lhs_last_dims (G,) or empty (0,) + .Arg() // rhs_first_dims (G,) or empty (0,) + .Arg() // rhs_last_dims (G,) or empty (0,) + .Arg() // out_first_dims (G,) or empty (0,) + .Arg() // out_last_dims (G,) or empty (0,) .Arg() // alpha .Arg() // beta .Ret() // output .Ret() // cublas_workspace .Ret() // setup_workspace .Ret() // int64_workspace - .Attr("M") - .Attr("N") - .Attr("K") - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") - .Attr("scaling_mode") - .Attr("is_grouped_dense_wgrad"), + .Attrs(), FFI_CudaGraph_Traits); Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type lhs_sinv, Buffer_Type rhs_data, Buffer_Type rhs_sinv, Buffer_Type bias, - Buffer_Type group_sizes, Buffer_Type group_offset, Result_Type output, - Result_Type workspace, size_t m, size_t n, size_t k, bool lhs_is_trans, - bool rhs_is_trans, JAXX_Scaling_Mode scaling_mode, bool has_bias, - bool is_grouped_dense_wgrad, bool use_async_d2h_group_sizes) { + Buffer_Type lhs_first_dims, Buffer_Type lhs_last_dims, + Buffer_Type rhs_first_dims, Buffer_Type rhs_last_dims, + Buffer_Type out_first_dims, Buffer_Type out_last_dims, + Buffer_Type group_offset, Result_Type output, Result_Type workspace, + GroupedGemmConfig config) { + auto [lhs_is_trans, rhs_is_trans, scaling_mode, has_bias, use_async_d2h_group_sizes, + lhs_axis_boundary, rhs_axis_boundary] = config; // Notes on matrix layouts and transpose: // Jax uses row-major data_layout, on entering this function, each input matrix pair: // A: row-major [m, k] for N - [k, m] for T @@ -872,6 +910,61 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type int num_streams = nvte_get_num_compute_streams(); + // Determine which group_sizes buffers are active (non-empty = ragged dimension). 
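To make the colwise-selection rule used in the V2 path above concrete (a sketch of the stated rule, not TE's implementation): under MXFP8 an operand needs the columnwise copy exactly when its contracting dimension is not its last dimension, which reduces to the transpose flags.

def mxfp8_colwise_usage(lhs_is_trans, rhs_is_trans):
    # Mirrors: rhs_use_colwise = is_mxfp8 && !rhs_is_trans
    #          lhs_use_colwise = is_mxfp8 &&  lhs_is_trans
    return {"lhs_use_colwise": lhs_is_trans, "rhs_use_colwise": not rhs_is_trans}

mxfp8_colwise_usage(False, True)   # NT -> neither operand needs columnwise data/scales
mxfp8_colwise_usage(False, False)  # NN -> rhs only
mxfp8_colwise_usage(True, False)   # TN -> both
mxfp8_colwise_usage(True, True)    # TT -> lhs only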
+ bool is_lhs_first_ragged = lhs_first_dims.element_count() > 0; + bool is_lhs_last_ragged = lhs_last_dims.element_count() > 0; + bool is_rhs_first_ragged = rhs_first_dims.element_count() > 0; + bool is_rhs_last_ragged = rhs_last_dims.element_count() > 0; + bool is_lhs_ragged = is_lhs_first_ragged || is_lhs_last_ragged; + bool is_rhs_ragged = is_rhs_first_ragged || is_rhs_last_ragged; + bool any_ragged = is_lhs_ragged || is_rhs_ragged; + + size_t num_gemms; + if (is_lhs_first_ragged) + num_gemms = lhs_first_dims.dimensions()[0]; + else if (is_lhs_last_ragged) + num_gemms = lhs_last_dims.dimensions()[0]; + else if (is_rhs_first_ragged) + num_gemms = rhs_first_dims.dimensions()[0]; + else if (is_rhs_last_ragged) + num_gemms = rhs_last_dims.dimensions()[0]; + else + NVTE_CHECK(false, + "GroupedGemmFFI (v1): At least one of the group size buffers must be non-empty to " + "determine num_gemms."); + + const Buffer_Type *active_gs_ptr = nullptr; + if (is_lhs_first_ragged) + active_gs_ptr = &lhs_first_dims; + else if (is_lhs_last_ragged) + active_gs_ptr = &lhs_last_dims; + else if (is_rhs_first_ragged) + active_gs_ptr = &rhs_first_dims; + else if (is_rhs_last_ragged) + active_gs_ptr = &rhs_last_dims; + + // Derive m, n, k from N-D buffer dimensions using axis_boundary. + // axis_boundary splits contracting dims from non-contracting dims. + auto lhs_dims = lhs_data.dimensions(); + auto rhs_dims = rhs_data.dimensions(); + NVTE_CHECK(lhs_dims.size() >= 2, "lhs_data must be at least 2D."); + NVTE_CHECK(rhs_dims.size() >= 2, "rhs_data must be at least 2D."); + size_t lab = static_cast(lhs_axis_boundary); + size_t rab = static_cast(rhs_axis_boundary); + // k = product of contracting dims of lhs + size_t k = lhs_is_trans ? product(lhs_dims, 0, lab) : product(lhs_dims, lab, lhs_dims.size()); + size_t m, n; + if (is_rhs_ragged) { + // wgrad: non-contracting lhs dims form M; non-contracting rhs dims form N + m = lhs_is_trans ? product(lhs_dims, lab, lhs_dims.size()) : product(lhs_dims, 0, lab); + n = rhs_is_trans ? product(rhs_dims, 0, rab) : product(rhs_dims, rab, rhs_dims.size()); + } else { + m = lhs_is_trans ? product(lhs_dims, lab, lhs_dims.size()) + : product(lhs_dims, 0, lab); // total M (sum of group sizes) + n = rhs_is_trans ? product(rhs_dims, 0, rab) / num_gemms + : product(rhs_dims, rab, rhs_dims.size()); + } + // Inputs auto lhs_ptr = reinterpret_cast(lhs_data.untyped_data()); auto rhs_ptr = reinterpret_cast(rhs_data.untyped_data()); @@ -884,9 +977,6 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto bias_ptr = has_bias ? reinterpret_cast(bias.untyped_data()) : nullptr; auto bias_dtype = convert_ffi_datatype_to_te_dtype(bias.element_type()); - NVTE_CHECK(group_sizes.dimensions().size() == 1); - size_t num_gemms = group_sizes.dimensions()[0]; - // It is weird that TE/Common GEMM only use colwise for MXFP8 const bool is_fp8_gemm = is_fp8_dtype(lhs_dtype); const bool is_tensor_scaling = scaling_mode == JAXX_Scaling_Mode::DELAYED_TENSOR_SCALING || @@ -953,14 +1043,14 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type "sizeof(lhs_sinv_dtype) != sizeof(rhs_sinv_dtype)"); size_t expected_lhs_size = m * k; - size_t expected_rhs_size = is_grouped_dense_wgrad ? (k * n) : (num_gemms * k * n); - size_t expected_out_size = is_grouped_dense_wgrad ? (num_gemms * m * n) : (m * n); + size_t expected_rhs_size = is_rhs_ragged ? (k * n) : (num_gemms * k * n); + size_t expected_out_size = is_rhs_ragged ? 
(num_gemms * m * n) : (m * n); size_t actual_lhs_size = product(lhs_data.dimensions()); size_t actual_rhs_size = product(rhs_data.dimensions()); size_t actual_out_size = product(output->dimensions()); NVTE_CHECK(expected_lhs_size == actual_lhs_size, "Unexpected lhs size! Expect ", expected_lhs_size, ", got ", actual_lhs_size); - if (!is_grouped_dense_wgrad) { + if (!is_rhs_ragged) { NVTE_CHECK(expected_rhs_size == actual_rhs_size, "Unexpected rhs size! Expect num_gemms * n * k = ", num_gemms, " * ", n, " * ", k, " = ", expected_rhs_size, ", got ", actual_rhs_size); @@ -976,25 +1066,28 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type size_t dim_list_bytes = sizeof(int32_t) * num_gemms; std::vector dim_list_host(num_gemms); - size_t host_num_gemms = 0; - if (use_async_d2h_group_sizes) { - host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); - NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, - " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); - } else { - auto dim_list_ptr = reinterpret_cast(group_sizes.untyped_data()); - cudaMemcpyAsync(dim_list_host.data(), dim_list_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, - stream); - // Note: This may break cudaGraph. - cudaStreamSynchronize(stream); - } - size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); - if (!is_grouped_dense_wgrad) { - NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, - ", got sum(group_sizes)=", sum_group_sizes); - } else { - NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, - ", got sum(group_sizes)=", sum_group_sizes); + if (any_ragged) { + size_t host_num_gemms = 0; + if (use_async_d2h_group_sizes) { + host_num_gemms = GroupedGemmGetGroupSizes(stream, num_gemms, nullptr, dim_list_host.data()); + NVTE_CHECK(host_num_gemms == num_gemms, "num_gemms ", num_gemms, + " does not match the return of GroupedGemmGetGroupSizes ", host_num_gemms, "."); + } else { + NVTE_CHECK(active_gs_ptr != nullptr, "active_gs_ptr is null but any_ragged is true."); + auto gs_data_ptr = reinterpret_cast(active_gs_ptr->untyped_data()); + cudaMemcpyAsync(dim_list_host.data(), gs_data_ptr, dim_list_bytes, cudaMemcpyDeviceToHost, + stream); + // Note: This may break cudaGraph. + cudaStreamSynchronize(stream); + } + // size_t sum_group_sizes = std::accumulate(dim_list_host.begin(), dim_list_host.end(), 0); + // if (!is_rhs_ragged) { + // NVTE_CHECK(m == sum_group_sizes, "Unexpected group_sizes! M = ", m, + // ", got sum(group_sizes)=", sum_group_sizes); + // } else { + // NVTE_CHECK(k == sum_group_sizes, "Unexpected group_sizes! K = ", k, + // ", got sum(group_sizes)=", sum_group_sizes); + // } } auto num_math_sm = cuda::sm_count() - getenv("NVTE_EXT_MARGIN_SM", 0); @@ -1042,7 +1135,7 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type auto lhs_shape_i = std::vector{m_i, k}; auto rhs_shape_i = std::vector{rhs_is_trans ? n : k, rhs_is_trans ? k : n}; auto out_shape_i = std::vector{m_i, n}; - if (is_grouped_dense_wgrad) { + if (is_rhs_ragged) { size_t k_i = dim_list_host[i]; lhs_shape_i[0] = lhs_is_trans ? k_i : m; lhs_shape_i[1] = lhs_is_trans ? 
m : k_i; @@ -1232,24 +1325,21 @@ Error_Type GroupedGemmFFI(cudaStream_t stream, Buffer_Type lhs_data, Buffer_Type XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedGemmHandler, GroupedGemmFFI, FFI::Bind() .Ctx() // stream - .Arg() // lhs_data + .Arg() // lhs_data (2D) .Arg() // lhs_sinv - .Arg() // rhs_data + .Arg() // rhs_data (2D) .Arg() // rhs_sinv .Arg() // bias - .Arg() // group_sizes + .Arg() // lhs_first_dims (G,) or empty (0,) + .Arg() // lhs_last_dims (G,) or empty (0,) + .Arg() // rhs_first_dims (G,) or empty (0,) + .Arg() // rhs_last_dims (G,) or empty (0,) + .Arg() // out_first_dims (G,) or empty (0,) + .Arg() // out_last_dims (G,) or empty (0,) .Arg() // group_offset .Ret() // output .Ret() // workspace - .Attr("M") - .Attr("N") - .Attr("K") - .Attr("lhs_is_trans") - .Attr("rhs_is_trans") - .Attr("scaling_mode") - .Attr("has_bias") - .Attr("is_grouped_dense_wgrad") - .Attr("use_async_d2h_group_sizes")); + .Attrs()); } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/csrc/extensions/pybind.cpp b/transformer_engine/jax/csrc/extensions/pybind.cpp index 28cb39b5d1..e3bc122403 100644 --- a/transformer_engine/jax/csrc/extensions/pybind.cpp +++ b/transformer_engine/jax/csrc/extensions/pybind.cpp @@ -33,6 +33,7 @@ pybind11::dict Registrations() { // Quantization dict["te_dbias_quantize_ffi"] = EncapsulateFFI(DBiasQuantizeHandler); dict["te_grouped_quantize_ffi"] = EncapsulateFFI(GroupedQuantizeHandler); + dict["te_grouped_quantize_v2_ffi"] = EncapsulateFFI(GroupedQuantizeV2Handler); dict["te_dequantize_ffi"] = EncapsulateFFI(DequantizeHandler); // Softmax diff --git a/transformer_engine/jax/csrc/extensions/quantization.cpp b/transformer_engine/jax/csrc/extensions/quantization.cpp index c5a766f7f2..06f5906edf 100644 --- a/transformer_engine/jax/csrc/extensions/quantization.cpp +++ b/transformer_engine/jax/csrc/extensions/quantization.cpp @@ -9,6 +9,7 @@ #include "../extensions.h" #include "transformer_engine/cast.h" +#include "transformer_engine/gemm.h" #include "transformer_engine/hadamard_transform.h" #include "transformer_engine/recipe.h" #include "transformer_engine/transformer_engine.h" @@ -494,5 +495,165 @@ XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeHandler, GroupedQuantizeFFI, .Attr("q_layout") .Attr("flatten_axis")); +Error_Type GroupedQuantizeV2FFI(cudaStream_t stream, Buffer_Type inputs, Buffer_Type scale_unused, + Buffer_Type group_sizes, Result_Type rowwise_out, + Result_Type colwise_out, Result_Type rowwise_sinv, + Result_Type colwise_sinv, Result_Type int64_workspace, + JAXX_Quantize_Layout quantize_layout, int64_t flatten_axis) { + (void)scale_unused; // scale is unused for MXFP8; accepted to match V1 input arity + auto in_dtype = convert_ffi_datatype_to_te_dtype(inputs.element_type()); + auto out_dtype = convert_ffi_datatype_to_te_dtype(rowwise_out->element_type()); + auto sinv_dtype = convert_ffi_datatype_to_te_dtype(rowwise_sinv->element_type()); + + NVTE_CHECK(is_fp8_dtype(out_dtype), "Output datatype must be FP8 for GroupedQuantizeV2."); + NVTE_CHECK(sinv_dtype == DType::kFloat8E8M0, + "scale_inv must be E8M0 for MXFP8 grouped quantize."); + + auto input_dims = inputs.dimensions(); + int64_t input_ndim = input_dims.size(); + if (flatten_axis < 0) flatten_axis += input_ndim; + NVTE_CHECK(flatten_axis < input_ndim && flatten_axis > 0, "flatten_axis is out of bounds!"); + + auto m = product(input_dims, 0, flatten_axis); + auto n = product(input_dims, flatten_axis, input_ndim); + size_t n_groups = group_sizes.dimensions()[0]; + + // Workspace 
layout (CUDA-graph safe, all device-side): + // int64_ptr[0 .. n_groups-1] : per-group ROW counts (int64) + // int64_ptr[n_groups .. 2*n_groups] : exclusive prefix-sum offsets (n_groups+1 values) + auto *int64_ptr = reinterpret_cast(int64_workspace->untyped_data()); + auto *offsets_ptr_out = int64_ptr + n_groups; // n_groups+1 values follow group_sizes + + // non_group_m handles multi-dim tensors (e.g., kernel shape G×K×N with flatten_axis=2): + // group_sizes[i] counts "slices" along the outermost group axis (e.g., 1 per expert), + // while the kernel expects actual ROW counts (e.g., K rows per expert). + // non_group_m = product(input_dims[1..flatten_axis)) converts slice→row count. + // For the lhs case (shape M×K, flatten_axis=1), non_group_m=1 (no-op). + int64_t non_group_m = + (flatten_axis > 1) ? product(input_dims, 1, static_cast(flatten_axis)) : 1; + + // Convert int32 group_sizes to int64 row counts on device (CUDA-graph safe, no D2H). + nvte_convert_int32_to_int64_with_multiplier( + reinterpret_cast(group_sizes.untyped_data()), int64_ptr, n_groups, + non_group_m, stream); + + // Compute exclusive prefix-sum offsets on device (CUDA-graph safe, no D2H). + nvte_compute_grouped_tensor_offsets(int64_ptr, offsets_ptr_out, n_groups, static_cast(n), + stream); + + NVTEShape data_shape{}; + data_shape.data[0] = m; + data_shape.data[1] = n; + data_shape.ndim = 2; + + NVTEShape sz_shape{}; + sz_shape.ndim = 1; + sz_shape.data[0] = n_groups; + + // Offsets tensor has n_groups+1 elements (exclusive prefix sums with sentinel). + NVTEShape offsets_shape{}; + offsets_shape.ndim = 1; + offsets_shape.data[0] = n_groups + 1; + + // Build input grouped tensor (plain float data, no quantization on the input side). + NVTEGroupedTensor in_grouped = nvte_create_grouped_tensor( + get_nvte_scaling_mode(JAXX_Scaling_Mode::NO_SCALING), n_groups, data_shape); + { + NVTEBasicTensor in_data{reinterpret_cast(inputs.untyped_data()), + static_cast(in_dtype), data_shape}; + nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedRowwiseData, &in_data, sizeof(in_data)); + NVTEBasicTensor sz_tensor{reinterpret_cast(int64_ptr), NVTEDType::kNVTEInt64, + sz_shape}; + nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedFirstDims, &sz_tensor, sizeof(sz_tensor)); + NVTEBasicTensor offsets_tensor{reinterpret_cast(offsets_ptr_out), + NVTEDType::kNVTEInt64, offsets_shape}; + nvte_set_grouped_tensor_param(in_grouped, kNVTEGroupedTensorOffsets, &offsets_tensor, + sizeof(offsets_tensor)); + } + + // Build output grouped tensor. + NVTEGroupedTensor out_grouped = nvte_create_grouped_tensor( + get_nvte_scaling_mode(JAXX_Scaling_Mode::MXFP8_1D_SCALING), n_groups, data_shape); + + // Set group sizes and offsets on output tensor (same device pointers). + { + NVTEBasicTensor sz_tensor{reinterpret_cast(int64_ptr), NVTEDType::kNVTEInt64, + sz_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedFirstDims, &sz_tensor, + sizeof(sz_tensor)); + NVTEBasicTensor offsets_tensor{reinterpret_cast(offsets_ptr_out), + NVTEDType::kNVTEInt64, offsets_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedTensorOffsets, &offsets_tensor, + sizeof(offsets_tensor)); + } + + // Rowwise output data + scale_inv. 
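A worked example of the slice-to-row conversion and offsets layout described earlier in this handler (a NumPy sketch; the exact unit the device kernel stores for offsets is determined by nvte_compute_grouped_tensor_offsets, so treat the prefix sum below as illustrative):

import numpy as np

# Kernel-style input (G, K, N) = (4, 128, 256) quantized with flatten_axis=2:
group_sizes = np.ones(4, dtype=np.int32)                  # one "slice" per expert along axis 0
non_group_m = 128                                         # product(input_dims[1:flatten_axis]) = K
row_counts = group_sizes.astype(np.int64) * non_group_m   # [128, 128, 128, 128] actual rows per group
offsets = np.concatenate(([0], np.cumsum(row_counts)))    # n_groups + 1 exclusive prefix sums

# lhs-style input (M, K) with flatten_axis=1: non_group_m = 1, so row_counts == group_sizes.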
+ if (is_quantize_rowwise(quantize_layout)) { + NVTEBasicTensor rw_data{reinterpret_cast(rowwise_out->untyped_data()), + static_cast(out_dtype), data_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseData, &rw_data, sizeof(rw_data)); + + auto sinv_dims = rowwise_sinv->dimensions(); + NVTEShape rw_sinv_shape{}; + rw_sinv_shape.ndim = 2; + rw_sinv_shape.data[0] = product(sinv_dims, 0, sinv_dims.size() - 1); + rw_sinv_shape.data[1] = sinv_dims.back(); + NVTEBasicTensor rw_sinv{reinterpret_cast(rowwise_sinv->untyped_data()), + static_cast(sinv_dtype), rw_sinv_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedRowwiseScaleInv, &rw_sinv, + sizeof(rw_sinv)); + } + + // Colwise output data + scale_inv. + if (is_quantize_colwise(quantize_layout)) { + NVTEBasicTensor cw_data{reinterpret_cast(colwise_out->untyped_data()), + static_cast(out_dtype), data_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseData, &cw_data, + sizeof(cw_data)); + + auto cw_sinv_dims = colwise_sinv->dimensions(); + NVTEShape cw_sinv_shape{}; + cw_sinv_shape.ndim = 2; + cw_sinv_shape.data[0] = product(cw_sinv_dims, 0, cw_sinv_dims.size() - 1); + cw_sinv_shape.data[1] = cw_sinv_dims.back(); + NVTEBasicTensor cw_sinv{reinterpret_cast(colwise_sinv->untyped_data()), + static_cast(sinv_dtype), cw_sinv_shape}; + nvte_set_grouped_tensor_param(out_grouped, kNVTEGroupedColumnwiseScaleInv, &cw_sinv, + sizeof(cw_sinv)); + } + + // Zero-initialize scale_inv buffers (mirrors V1 behaviour for MXFP8). + size_t total_rowwise_sinv_size = + is_quantize_rowwise(quantize_layout) ? product(rowwise_sinv->dimensions()) : 0; + size_t total_colwise_sinv_size = + is_quantize_colwise(quantize_layout) ? product(colwise_sinv->dimensions()) : 0; + if (total_rowwise_sinv_size > 0) + nvte_memset(rowwise_sinv->untyped_data(), 0, total_rowwise_sinv_size, stream); + if (total_colwise_sinv_size > 0) + nvte_memset(colwise_sinv->untyped_data(), 0, total_colwise_sinv_size, stream); + + nvte_group_quantize(in_grouped, out_grouped, stream); + + nvte_destroy_grouped_tensor(in_grouped); + nvte_destroy_grouped_tensor(out_grouped); + + return ffi_with_cuda_error_check(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL(GroupedQuantizeV2Handler, GroupedQuantizeV2FFI, + FFI::Bind() + .Ctx() // stream + .Arg() // inputs + .Arg() // scale (unused, for input arity match) + .Arg() // group_sizes (int32) + .Ret() // rowwise_out + .Ret() // colwise_out + .Ret() // rowwise_sinv + .Ret() // colwise_sinv + .Ret() // int64_workspace + .Attr("q_layout") + .Attr("flatten_axis"), + FFI_CudaGraph_Traits); + } // namespace jax } // namespace transformer_engine diff --git a/transformer_engine/jax/dense.py b/transformer_engine/jax/dense.py index fe02e61fc0..76c984486f 100644 --- a/transformer_engine/jax/dense.py +++ b/transformer_engine/jax/dense.py @@ -27,6 +27,7 @@ is_fp8_gemm_with_all_layouts_supported, TensorUsage, QuantizeLayout, + GroupedNoScaleTensor, ) @@ -490,7 +491,8 @@ def _grouped_dense_fwd_rule( is_colwise=False, data_layout="N", flatten_axis=ctx_kernel.flatten_axis, - group_sizes=ctx_kernel.group_sizes, + first_dims=ctx_kernel.first_dims, + last_dims=ctx_kernel.last_dims, original_shape=kernel_shape, group_axis=ctx_kernel.group_axis, ) @@ -507,7 +509,8 @@ def _grouped_dense_fwd_rule( is_colwise=True, data_layout="T", flatten_axis=ctx_kernel.flatten_axis, - group_sizes=ctx_kernel.group_sizes, + first_dims=ctx_kernel.first_dims, + last_dims=ctx_kernel.last_dims, original_shape=kernel_shape, group_axis=ctx_kernel.group_axis, ) @@ 
-518,15 +521,29 @@ def _grouped_dense_fwd_rule( # This is needed especially when kernel_fsdp_enabled == True AND FP8 enabled. quantizer_set.kernel.q_layout = original_quantizer_set_kernel_q_layout + if is_noop_quantizer_set: + grouped_gemm_x = GroupedNoScaleTensor( + data=grouped_gemm_x, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=grouped_gemm_x.shape, + ) + grouped_gemm_kernel = GroupedNoScaleTensor( + data=grouped_gemm_kernel, + first_dims=None, + last_dims=None, + group_axis=0, + original_shape=grouped_gemm_kernel.shape, + ) output = tex.grouped_gemm( grouped_gemm_x, grouped_gemm_kernel, - group_sizes, - contracting_dims, - bias, - precision, - preferred_element_type, - group_offset, + contracting_dims=contracting_dims, + bias=bias, + precision=precision, + preferred_element_type=preferred_element_type, + group_offset=group_offset, ) ctx = ( @@ -610,11 +627,39 @@ def _grouped_dense_bwd_rule( wgrad_x_T = ctx_x wgrad_grad = casted_grad.get_tensor(usage=TensorUsage.RHS) + if is_noop_quantizer_set: + dgrad_grad = GroupedNoScaleTensor( + data=dgrad_grad, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=dgrad_grad.shape, + ) + dgrad_kernel_T = GroupedNoScaleTensor( + data=dgrad_kernel_T, + first_dims=None, + last_dims=None, + group_axis=0, + original_shape=dgrad_kernel_T.shape, + ) + wgrad_x_T = GroupedNoScaleTensor( + data=wgrad_x_T, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=wgrad_x_T.shape, + ) + wgrad_grad = GroupedNoScaleTensor( + data=wgrad_grad, + first_dims=group_sizes, + last_dims=None, + group_axis=0, + original_shape=wgrad_grad.shape, + ) dgrad = tex.grouped_gemm( dgrad_grad, dgrad_kernel_T, - group_sizes, - dgrad_contracting_dims, + contracting_dims=dgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset, @@ -623,8 +668,7 @@ def _grouped_dense_bwd_rule( wgrad = tex.grouped_gemm( wgrad_x_T, wgrad_grad, - group_sizes, - wgrad_contracting_dims, + contracting_dims=wgrad_contracting_dims, precision=precision, preferred_element_type=preferred_element_type, group_offset=group_offset, diff --git a/transformer_engine/jax/flax/module.py b/transformer_engine/jax/flax/module.py index 31ce6e72e9..17c9a242f0 100644 --- a/transformer_engine/jax/flax/module.py +++ b/transformer_engine/jax/flax/module.py @@ -16,6 +16,9 @@ from jax import random as jax_random from jax.ad_checkpoint import checkpoint_name +from transformer_engine.common.recipe import ( + MXFP8BlockScaling, +) from ..dense import dense, grouped_dense @@ -1358,7 +1361,12 @@ def kernel_1_init(key, num_kernels, stack_axis, *init_args): return out, ln_output # Output, layer_norm_output -def wrap_function_in_te_state_module(f, quantization_recipe, name: Optional[str] = None): +def wrap_function_in_te_state_module( + f, + quantization_recipe, + name: Optional[str] = None, + quantization_checkpoint_name: Optional[str] = None, +): """Wraps the given function `f` to support TransformerEngine quantization. 
This method does a couple things: @@ -1386,6 +1394,7 @@ def generate_quantizer_set(self, postfix: str = "", n_groups: int = None): return super().generate_quantizer_set( postfix=postfix, variable_collection=OVERWRITE_WITH_GRADIENT, + quantization_checkpoint_name=quantization_checkpoint_name, fp8_recipe=quantization_recipe, n_groups=n_groups, ) @@ -1443,10 +1452,15 @@ def te_dot_general(generate_quantizer_set, x, kernel, dims, **kwargs): return wrap_function_in_te_state_module(te_dot_general, quantization_recipe, "dot_general") -def make_grouped_dense_cls(quantization_recipe): +def make_grouped_dense_cls(quantization_recipe, quantization_checkpoint_name: Optional[str] = None): """Creates a grouped dense (grouped GEMM) instance for use with TE state module.""" if quantization_recipe is not None: - raise ValueError("Ragged dot grouped GEMM does not support quantization yet") + allowed_grouped_gemm_recipes = [MXFP8BlockScaling] + assert any(isinstance(quantization_recipe, r) for r in allowed_grouped_gemm_recipes), ( + "Only the following quantization recipes are supported for grouped GEMM or `None` for" + f" BF16 without quantization: {allowed_grouped_gemm_recipes}. Got" + f" {type(quantization_recipe)}." + ) def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwargs): del kwargs # Unused @@ -1463,5 +1477,8 @@ def te_grouped_dot_general(generate_quantizer_set, x, kernel, group_sizes, **kwa return out return wrap_function_in_te_state_module( - te_grouped_dot_general, quantization_recipe, "ragged_dot" + te_grouped_dot_general, + quantization_recipe, + "ragged_dot", + quantization_checkpoint_name=quantization_checkpoint_name, )() diff --git a/transformer_engine/jax/permutation.py b/transformer_engine/jax/permutation.py index 6a0a3229d9..2732c4acc5 100644 --- a/transformer_engine/jax/permutation.py +++ b/transformer_engine/jax/permutation.py @@ -497,8 +497,9 @@ def _token_combine_bwd_rule( hidden_size, ) # The backward kernel only writes to positions that tokens map to. - # Padded positions may contain uninitialized (NaN) values - replace with zeros. - inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + # Padded positions may contain uninitialized values (NaN, inf, or garbage). + # Replace any non-finite values with zeros. + inp_grad = jnp.where(jnp.isfinite(inp_grad), inp_grad, 0.0) else: inp_grad, merging_probs_grad = unpermute_bwd_with_merging_probs( output_grad, @@ -527,8 +528,9 @@ def _token_combine_bwd_rule( align_size=128, # Default, sizes already computed in forward ) # The permute kernel only writes to positions that tokens map to. - # Padded positions may contain uninitialized (NaN) values - replace with zeros. - inp_grad = jnp.where(jnp.isnan(inp_grad), 0.0, inp_grad) + # Padded positions may contain uninitialized values (NaN, inf, or garbage). + # Replace any non-finite values with zeros. 
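A tiny illustration of why the mask is isfinite rather than isnan (example values only, not part of the kernel path):

import jax.numpy as jnp

g = jnp.array([1.0, jnp.nan, jnp.inf, -jnp.inf, 2.0])
jnp.where(jnp.isfinite(g), g, 0.0)   # -> [1., 0., 0., 0., 2.]  (an isnan mask alone would keep the infs)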
+ inp_grad = jnp.where(jnp.isfinite(inp_grad), inp_grad, 0.0) else: inp_grad, _ = permute_with_mask_map( output_grad, diff --git a/transformer_engine/jax/quantize/dequantizer.py b/transformer_engine/jax/quantize/dequantizer.py index 74787b9308..5075f1a664 100644 --- a/transformer_engine/jax/quantize/dequantizer.py +++ b/transformer_engine/jax/quantize/dequantizer.py @@ -275,7 +275,17 @@ def _grouped_dequantize(grouped_scaled_tensor): """ data = grouped_scaled_tensor.data scale_inv = grouped_scaled_tensor.scale_inv - group_sizes = grouped_scaled_tensor.group_sizes + group_sizes = ( + grouped_scaled_tensor.first_dims + if grouped_scaled_tensor.first_dims is not None + and grouped_scaled_tensor.first_dims.size > 0 + else grouped_scaled_tensor.last_dims + ) + # For non-ragged groups (kernel case), group_sizes is not stored; derive from original_shape + if group_sizes is None: + group_sizes = jnp.ones( + grouped_scaled_tensor.original_shape[grouped_scaled_tensor.group_axis], dtype=jnp.int32 + ) flatten_axis = grouped_scaled_tensor.flatten_axis scaling_mode = grouped_scaled_tensor.scaling_mode original_shape = grouped_scaled_tensor.original_shape diff --git a/transformer_engine/jax/quantize/quantizer.py b/transformer_engine/jax/quantize/quantizer.py index f5ca6aeaed..55dd7f5618 100644 --- a/transformer_engine/jax/quantize/quantizer.py +++ b/transformer_engine/jax/quantize/quantizer.py @@ -948,7 +948,7 @@ def _create_grouped_tensor_from_tensor_list( is_colwise=tensor_list[0].is_colwise, data_layout=tensor_list[0].data_layout, flatten_axis=tensor_list[0].flatten_axis, - group_sizes=group_sizes, + first_dims=group_sizes, original_shape=original_shape, group_axis=group_axis, ) diff --git a/transformer_engine/jax/quantize/tensor.py b/transformer_engine/jax/quantize/tensor.py index c26cb8a531..4b604502b0 100644 --- a/transformer_engine/jax/quantize/tensor.py +++ b/transformer_engine/jax/quantize/tensor.py @@ -9,7 +9,7 @@ rowwise and colwise quantization modes with proper scaling and dequantization. """ from dataclasses import dataclass -from typing import Callable, Tuple +from typing import Callable, Optional, Tuple from abc import ABC, abstractmethod import jax.numpy as jnp @@ -32,6 +32,7 @@ "ScaledTensor1x", "ScaledTensor2x", "GroupedScaledTensor1x", + "GroupedNoScaleTensor", "ScaledTensorFactory", "with_sharding_constraint_by_logical_axes", ] @@ -365,12 +366,14 @@ class GroupedScaledTensor1x(ScaledTensor1x): where elements are grouped along a specified axis. 
Attributes: - group_sizes: Array containing the size of each group + first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged + last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged original_shape: The original shape of the tensor before grouping group_axis: The axis along which grouping is performed (default: 0) """ - group_sizes: jnp.ndarray + first_dims: Optional[jnp.ndarray] + last_dims: Optional[jnp.ndarray] original_shape: Tuple group_axis: int @@ -379,7 +382,7 @@ def __init__( data, scale_inv, amax, - group_sizes, + first_dims, scaling_mode, dq_dtype, _dq_func, @@ -388,9 +391,11 @@ def __init__( flatten_axis, original_shape, group_axis=0, + last_dims=None, ): self.flatten_axis = flatten_axis - self.group_sizes = group_sizes + self.first_dims = first_dims + self.last_dims = last_dims self.original_shape = original_shape self.group_axis = group_axis # TODO(Phuong):Handle RHT for grouped quantization once grouped quantization supports NVFP4 @@ -407,6 +412,18 @@ def __init__( has_rht_applied=False, ) + @property + def group_sizes(self) -> jnp.ndarray: + """Per-group sizes along the group axis. + + When first_dims is set (ragged groups), returns first_dims. + When first_dims is None (equal-sized groups), returns an array of ones with + length equal to the number of groups. + """ + if self.first_dims is not None and self.first_dims.size > 0: + return self.first_dims + return jnp.ones((self.original_shape[self.group_axis],), dtype=jnp.int32) + def __post_init__(self): assert self.scale_inv.ndim == 1, "Only support flattened scale_inv" assert self.data.ndim == 1, "Only support flattened data" @@ -422,9 +439,19 @@ def __post_init__(self): 0 <= self.group_axis < data_ndim ), f"group_axis {self.group_axis} is out of bounds for shape {self.original_shape}" + active_dims = ( + self.first_dims + if self.first_dims is not None and self.first_dims.size > 0 + else self.last_dims + ) + if active_dims is not None: + num_groups = active_dims.size + else: + num_groups = self.original_shape[self.group_axis] + expected_scale_shape = self.scaling_mode.get_grouped_scale_shape( self.original_shape, - self.group_sizes.size, + num_groups, self.group_axis, self.is_colwise, is_padded=True, @@ -442,7 +469,7 @@ def tree_flatten(self): Returns: A tuple containing (children, aux_data) for tree operations """ - children = (self.data, self.scale_inv, self.amax, self.group_sizes) + children = (self.data, self.scale_inv, self.amax, self.first_dims, self.last_dims) aux_data = ( self.scaling_mode, self.dq_dtype, @@ -455,6 +482,36 @@ def tree_flatten(self): ) return (children, aux_data) + @classmethod + def tree_unflatten(cls, aux_data, children): + """Reconstructs the tensor from its flattened representation.""" + data, scale_inv, amax, first_dims, last_dims = children + ( + scaling_mode, + dq_dtype, + _dq_func, + is_colwise, + data_layout, + flatten_axis, + original_shape, + group_axis, + ) = aux_data + return cls( + data=data, + scale_inv=scale_inv, + amax=amax, + first_dims=first_dims, + last_dims=last_dims, + scaling_mode=scaling_mode, + dq_dtype=dq_dtype, + _dq_func=_dq_func, + is_colwise=is_colwise, + data_layout=data_layout, + flatten_axis=flatten_axis, + original_shape=original_shape, + group_axis=group_axis, + ) + def apply_sharding_constraint_by_logical_axes(self, logical_axis_names: Tuple[str, ...]): raise NotImplementedError @@ -473,6 +530,52 @@ def checkpoint(self, quantizer): return jax_checkpoint_name(self, name=quantizer.checkpoint_name) +@register_pytree_node_class 
+@dataclass +class GroupedNoScaleTensor: + """Unquantized grouped tensor. + + Stores N-D data with per-group dimension sizes so that grouped_gemm() + can extract first/last dims automatically without explicit parameters. + + Attributes: + data: The raw (unquantized) tensor data in N-D layout + first_dims: Per-group sizes of the first (row) 2D dim, or None if not ragged + last_dims: Per-group sizes of the last (col) 2D dim, or None if not ragged + group_axis: Which axis of original_shape is the group batch prefix + original_shape: Shape of data (same as data.shape for N-D unquantized) + """ + + data: jnp.ndarray + first_dims: Optional[jnp.ndarray] + last_dims: Optional[jnp.ndarray] + group_axis: int + original_shape: Tuple + + def tree_flatten(self): + """Flattens the tensor for JAX tree operations.""" + children = (self.data, self.first_dims, self.last_dims) + aux_data = (self.group_axis, self.original_shape) + return (children, aux_data) + + @classmethod + def tree_unflatten(cls, aux_data, children): + """Reconstructs the tensor from its flattened representation.""" + group_axis, original_shape = aux_data + data, first_dims, last_dims = children + return cls( + data=data, + first_dims=first_dims, + last_dims=last_dims, + group_axis=group_axis, + original_shape=original_shape, + ) + + def dequantize(self): + """No-op dequantization — returns the raw data.""" + return self.data + + @register_pytree_node_class @dataclass class ScaledTensor2x(AbstractBaseTensor, ScaledTensor): @@ -570,7 +673,8 @@ def create_1x( is_colwise=False, data_layout="N", flatten_axis=-1, - group_sizes=None, + first_dims=None, + last_dims=None, original_shape=None, group_axis=0, has_rht_applied=False, @@ -586,29 +690,44 @@ def create_1x( is_colwise: Whether to use column-wise quantization (default: False) data_layout: The data_layout specification (default: "N") flatten_axis: The quantization axis for the tensor - group_sizes: Array of ints containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) has_rht_applied: Whether the tensor had the Randomized Hadamard Transform (RHT) applied during quantization (default: False) Returns: - A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether group_sizes is provided + A ScaledTensor1x or GroupedScaledTensor1x instance depending on whether first_dims or last_dims is provided """ if amax is None: amax = jnp.empty((1,), dtype=jnp.float32) dequantizer = ScalingModeToDequantizerMap.get(scaling_mode) - if group_sizes is not None: - flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) + if ( + first_dims is not None + or last_dims is not None + or (original_shape is not None and group_axis is not None) + ): assert ( original_shape is not None ), "original_shape is not given for GroupedScaledTensor1x" + flatten_axis = (len(original_shape) + flatten_axis) % len(original_shape) + + # Determine num_groups from whichever dims array is provided, or from original_shape + active_dims = ( + first_dims if first_dims is not None and first_dims.size > 0 else last_dims + ) + if active_dims is not None: + num_groups = active_dims.size + else: + norm_group_axis = (len(original_shape) + group_axis) % len(original_shape) + num_groups = original_shape[norm_group_axis] # Handling attrs 
of transposed tensors group_axis = (len(original_shape) + group_axis) % len(original_shape) if data_layout == "T": - if original_shape[0] == group_sizes.size: + if original_shape[0] == num_groups: original_shape = ( original_shape[0], *original_shape[flatten_axis:], @@ -633,7 +752,8 @@ def create_1x( is_colwise=is_colwise, data_layout=data_layout, flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, ) @@ -668,7 +788,8 @@ def create_2x( dq_dtype=jnp.bfloat16, data_layout="NN", flatten_axis=-1, - group_sizes=None, + first_dims=None, + last_dims=None, original_shape=None, group_axis=0, rowwise_has_rht_applied=False, @@ -686,7 +807,8 @@ def create_2x( dq_dtype: The data type for dequantized values (default: bfloat16) data_layout: The data_layout specification (default: "NN") flatten_axis: The quantization axis for the tensor - group_sizes: Array containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -710,7 +832,8 @@ def create_2x( is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied, @@ -724,7 +847,8 @@ def create_2x( is_colwise=True, data_layout=data_layout[1], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=colwise_has_rht_applied, @@ -744,7 +868,8 @@ def create( data_layout: str = "NN", q_layout: QuantizeLayout = QuantizeLayout.ROWWISE, flatten_axis: int = -1, - group_sizes: jnp.ndarray = None, + first_dims: jnp.ndarray = None, + last_dims: jnp.ndarray = None, original_shape: Tuple[int] = None, group_axis: int = 0, rowwise_has_rht_applied: bool = False, @@ -762,7 +887,8 @@ def create( data_layout: The data_layout specification (default: "NN") q_layout: The quantization axis (default: ROWWISE) flatten_axis: The axis along which the tensor could be flattened to 2D (default: -1) - group_sizes: Array containing the size of each group (default: None) + first_dims: Per-group sizes of the first (row) 2D dim (default: None) + last_dims: Per-group sizes of the last (col) 2D dim (default: None) original_shape: The original shape of the tensor before grouping (default: None) group_axis: The axis along which grouping is performed (default: 0) rowwise_has_rht_applied: Whether the row-wise tensor uses the Randomized Hadamard Transform (RHT) (default: False) @@ -785,7 +911,8 @@ def create( dq_dtype, data_layout=data_layout, flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, rowwise_has_rht_applied=rowwise_has_rht_applied, @@ -802,7 +929,8 @@ def create( is_colwise=True, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=colwise_has_rht_applied, @@ 
-817,7 +945,8 @@ def create( is_colwise=False, data_layout=data_layout[0], flatten_axis=flatten_axis, - group_sizes=group_sizes, + first_dims=first_dims, + last_dims=last_dims, original_shape=original_shape, group_axis=group_axis, has_rht_applied=rowwise_has_rht_applied,