[JAX] Grouped GEMM Refactor to use first_dims and last_dims#2749
jberchtold-nvidia wants to merge 16 commits into NVIDIA:main from jberchtold/gmm-refactor
Conversation
Signed-off-by: Jeremy Berchtold <jberchtold@nvidia.com>
for more information, see https://pre-commit.ci
Greptile Summary: This PR replaces the single `group_sizes` parameter with six per-tensor ragged-dimension arguments. Key changes and issues found:
Confidence Score: 2/5
Important Files Changed
Flowchart

```mermaid
%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["grouped_gemm(lhs, rhs, contracting_dims, ...)"] --> B{lhs type?}
    B -->|GroupedNoScaleTensor| C["scaling_mode = NO_SCALING\nlhs_first/last_dims from lhs.first_dims/last_dims"]
    B -->|GroupedScaledTensor1x| D["scaling_mode = lhs.scaling_mode\nlhs_first/last_dims from lhs.first_dims/last_dims"]
    C --> E{rhs type?}
    D --> E
    E -->|GroupedNoScaleTensor| F["rhs_first/last_dims from rhs.first/last_dims"]
    E -->|GroupedScaledTensor1x| G["rhs_first/last_dims from rhs.first/last_dims\nvalidate scaling_mode match (only if lhs is also GroupedScaledTensor1x)"]
    F --> H["Infer out_first/last_dims:\n• rhs ragged → wgrad path, out=empty\n• lhs_first ragged → out_first=lhs_first\n• lhs_last ragged → out_last=lhs_last"]
    G --> H
    H --> I["Compute lhs_is_trans, rhs_is_trans\nfrom contracting_dims + shape"]
    I --> J["lhs_axis_boundary = get_lhs_axis_boundary()\nrhs_axis_boundary = get_rhs_axis_boundary()"]
    J --> K{_can_use_v2_grouped_gemm?\nNO_SCALING + bf16 + SM100+}
    K -->|Yes| L["V2 FFI: GroupedGemmV2FFI\nalpha/beta buffers\nint64_workspace partitioned per ragged dim"]
    K -->|No| M["V1 FFI: GroupedGemmFFI\nper-group loop in C++\ngroup_sizes d2h copy"]
    L --> N["nvte_grouped_gemm\n(Blackwell grouped kernel)"]
    M --> O["cuBLAS per-group GEMMs\n(Hopper/older)"]
```
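The output-dims inference step in the flowchart can be sketched roughly as follows. This is a minimal numpy stand-in for the jnp logic; the helper name is illustrative, not from the PR:

```python
import numpy as np

def infer_out_dims(lhs_first, lhs_last, rhs_first, rhs_last):
    """Sketch of the out_first/last_dims inference described in the flowchart.

    Empty (0,) arrays are the sentinel for a uniform (non-ragged) axis.
    """
    empty = np.empty((0,), np.int32)
    # rhs ragged along its first (contracting) dim -> wgrad path, output dims empty
    if rhs_first.size > 0:
        return empty, empty
    out_first = lhs_first if lhs_first.size > 0 else empty  # lhs_first ragged -> out_first
    out_last = lhs_last if lhs_last.size > 0 else empty     # lhs_last ragged -> out_last
    return out_first, out_last
```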
Last reviewed commit: 2b84dfd
Force-pushed from 35171af to 88bb7da
Force-pushed from 20fadc7 to 025f598
Force-pushed from a427b9e to 089e530
/te-ci
/te-ci
```python
def _grouped_gemm_lhs_M(lhs_shape_2d: Tuple[int, int], lhs_is_trans: bool) -> int:
    """Non-contracting output size M from the 2-D LHS buffer."""
    return lhs_shape_2d[1] if lhs_is_trans else lhs_shape_2d[0]


def _grouped_gemm_rhs_N(rhs_shape_2d: Tuple[int, int], rhs_is_trans: bool, num_groups: int) -> int:
    """Non-contracting output size N from the 2-D RHS buffer."""
    return rhs_shape_2d[0] // num_groups if rhs_is_trans else rhs_shape_2d[1]
```
Suggest calling these lhs_non_contracting_dims and rhs_non_contracting_dims, as M and N are still ambiguous.
Besides, I think we should not assume that lhs and rhs are 2-D; they can be N-D.
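Generalizing the 2-D helpers to N-D operands, as suggested here, could look something like this. This is a hedged sketch; the function name and signature are hypothetical, not from the PR:

```python
import math
from typing import Sequence, Tuple

def non_contracting_size(shape: Tuple[int, ...], contracting_dims: Sequence[int]) -> int:
    """Product of the non-contracting axes of an N-D operand.

    Works from contracting_dims directly instead of a 2-D transpose flag,
    so it applies to any rank.
    """
    cdims = {d % len(shape) for d in contracting_dims}  # normalize negative axes
    return math.prod(d for i, d in enumerate(shape) if i not in cdims)
```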
```diff
 Args:
-    lhs_data: Left-hand side input matrix data, 1D flattened array
+    lhs_data: Left-hand side input matrix data, 2D array [rows, cols]
```
When the LHS needs to be transposed, we won't be able to have a 2D shape.
Also, I would prefer that we not reshape/merge any axes until C++. Looking to the future, especially once we have a solution for the EP part, we may not need shard_map anymore.
```python
rhs_first_dims_aval,
rhs_last_dims_aval,
out_first_dims_aval,
out_last_dims_aval,
```
Why do the out_xxx_dims_aval need to be inputs to the primitive? Can't the primitive derive them from the other dims plus the contracting-dims info?
Agreed that it doesn't need to be an input to the grouped_gemm API. To avoid differing inner/outer primitive signatures, I've kept this as an arg to the primitive but am now deriving out first and last dims from the inputs inside the grouped_gemm function instead of requiring the user to specify them.
```python
lhs_first_dims: jnp.ndarray = None,  # (G,) int32 if LHS squashed first dim varies, else None/(0,)
lhs_last_dims: jnp.ndarray = None,   # (G,) int32 if LHS squashed last dim varies, else None/(0,)
rhs_first_dims: jnp.ndarray = None,  # (G,) int32 if RHS squashed first dim varies, else None/(0,)
rhs_last_dims: jnp.ndarray = None,   # (G,) int32 if RHS squashed last dim varies, else None/(0,)
out_first_dims: jnp.ndarray = None,  # (G,) int32 if output first dim varies, else None/(0,)
out_last_dims: jnp.ndarray = None,   # (G,) int32 if output last dim varies, else None/(0,)
```
Either the GroupedScaledTensor should carry this information, or one should be able to derive it from group_sizes + contracting_dims.
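Deriving the per-group dims from group_sizes plus the contracting dims, as suggested, might look like this for a 2-D LHS. This is an illustrative numpy sketch; the helper name and signature are hypothetical:

```python
import numpy as np

def lhs_dims_from_group_sizes(group_sizes, lhs_contracting):
    """If only the non-contracting LHS axis is ragged, its per-group sizes
    are exactly group_sizes; the contracting axis stays uniform and uses
    the empty (0,) sentinel."""
    empty = np.empty((0,), np.int32)
    gs = np.asarray(group_sizes, np.int32)
    if 0 in lhs_contracting:   # contracting on the first axis: ragged axis is last
        return empty, gs
    return gs, empty           # contracting on the last axis: ragged axis is first
```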
```python
    first_dims is not None
    or last_dims is not None
    or (original_shape is not None and group_axis is not None)
):
```
group_axis is not None is always True — condition is wider than intended

group_axis has a default value of 0, so group_axis is not None evaluates to True for every caller that does not explicitly pass group_axis=None. This means the third branch of the or, `or (original_shape is not None and group_axis is not None)`, reduces to simply original_shape is not None, which is a much broader guard than the old group_sizes is not None. Any call to ScaledTensorFactory.make_grouped(…, original_shape=shape) — even without first_dims or last_dims — now enters the grouped path and returns a GroupedScaledTensor1x with both dim arrays set to None. This silently changes the return type for callers that provided original_shape for informational purposes only, and those callers will now see num_groups derived implicitly from original_shape[group_axis] instead of receiving a plain ScaledTensor1x.
The condition should be restricted to the cases where grouping is actually requested:

```python
if (
    first_dims is not None
    or last_dims is not None
):
```
If the "uniform grouped" case (kernel rhs without explicit per-group sizes) needs to be handled here, it should be expressed with an explicit sentinel argument rather than overloading original_shape.
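The always-True default can be demonstrated in isolation with a minimal repro (simplified stand-in, not the real factory):

```python
def make_grouped(data, first_dims=None, last_dims=None, original_shape=None, group_axis=0):
    """Simplified stand-in for the reviewed make_grouped dispatch."""
    if (
        first_dims is not None
        or last_dims is not None
        or (original_shape is not None and group_axis is not None)
    ):
        return "grouped"   # grouped path taken
    return "plain"

# original_shape alone triggers the grouped path, because group_axis defaults to 0,
# making `group_axis is not None` vacuously True.
```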
```python
if isinstance(rhs, GroupedNoScaleTensor):
    rhs_data = rhs.data
    rhs_shape = rhs.original_shape
    rhs_scale_inv = jnp.empty((0,), jnp.float32)
    rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
    rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
elif isinstance(rhs, GroupedScaledTensor1x):
    rhs_shape = rhs.original_shape
    rhs_data = rhs.data.reshape(rhs_shape)
    rhs_scale_inv = rhs.scale_inv
    rhs_first_dims = rhs.first_dims if rhs.first_dims is not None else empty_gs
    rhs_last_dims = rhs.last_dims if rhs.last_dims is not None else empty_gs
    if isinstance(lhs, GroupedScaledTensor1x) and lhs.scaling_mode != rhs.scaling_mode:
        raise ValueError(
            f"Mismatched scaling modes: lhs.scaling_mode={lhs.scaling_mode},"
            f" rhs.scaling_mode={rhs.scaling_mode}"
        )
```
scaling_mode left as NO_SCALING when lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x

When lhs is a GroupedNoScaleTensor, the lhs block sets scaling_mode = ScalingMode.NO_SCALING. The subsequent rhs block only overrides scaling_mode when isinstance(lhs, GroupedScaledTensor1x):

```python
if isinstance(lhs, GroupedScaledTensor1x):
    scaling_mode = lhs.scaling_mode  # never executes for GroupedNoScaleTensor lhs
```

So if a caller passes lhs=GroupedNoScaleTensor and rhs=GroupedScaledTensor1x, scaling_mode stays NO_SCALING while rhs_scale_inv holds real scale values. C++ will then use NO_SCALING logic and ignore the rhs scales entirely, producing silently wrong numerical results rather than a clear error.

The scaling-mode consistency check that guards against mismatched GroupedScaledTensor1x pairs does not fire here either, because isinstance(lhs, GroupedScaledTensor1x) is False.

Add an explicit cross-type guard early in the rhs block:

```python
elif isinstance(rhs, GroupedScaledTensor1x):
    if isinstance(lhs, GroupedNoScaleTensor):
        raise TypeError(
            "lhs is GroupedNoScaleTensor but rhs is GroupedScaledTensor1x; "
            "both operands must use the same tensor type."
        )
    ...
```
/te-ci
```python
flatten_axis,
original_shape,
group_axis=0,
last_dims=None,
```
Nit: I think last_dims and first_dims should be positioned together.
```python
self.first_dims = first_dims
self.last_dims = last_dims
self.original_shape = original_shape
self.group_axis = group_axis
```
Since we store first_dims and last_dims now, I think we no longer need the group_axis.
```python
def tree_unflatten(cls, aux_data, children):
    """Reconstructs the tensor from its flattened representation."""
```
I think we should be able to reuse the base tree_unflatten as the order is still cls(*children, *aux_data).
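The base-class pattern this comment refers to is roughly the following; the class names here are hypothetical stand-ins for illustration:

```python
class PyTreeBase:
    @classmethod
    def tree_unflatten(cls, aux_data, children):
        # Reconstruction order: cls(*children, *aux_data), so a subclass
        # whose __init__ takes (dynamic leaves..., static metadata...) in
        # that order can reuse this method unchanged.
        return cls(*children, *aux_data)

class Grouped(PyTreeBase):
    def __init__(self, data, first_dims, last_dims, original_shape):
        # children: array leaves; aux_data: static metadata
        self.data = data
        self.first_dims = first_dims
        self.last_dims = last_dims
        self.original_shape = original_shape
```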
```python
first_dims=ctx_kernel.first_dims,
last_dims=ctx_kernel.last_dims,
```
I think we don't need to pass dims here, as the tensors should already carry them.
```python
if is_noop_quantizer_set:
    grouped_gemm_x = GroupedNoScaleTensor(
        data=grouped_gemm_x,
        first_dims=group_sizes,
        last_dims=None,
        group_axis=0,
        original_shape=grouped_gemm_x.shape,
    )
    grouped_gemm_kernel = GroupedNoScaleTensor(
        data=grouped_gemm_kernel,
        first_dims=None,
        last_dims=None,
        group_axis=0,
        original_shape=grouped_gemm_kernel.shape,
    )
```
How about making grouped_quantize return a GroupedNoScaleTensor when the quantizer set is empty?
```python
out_first_dims_aval,
out_last_dims_aval,
```
But the out_xxx_dims could be return buffers. Why should they be input buffers?
```python
lhs_shape = lhs_data_aval.shape
rhs_shape = rhs_data_aval.shape
```
Can't do this, as the inputs could both be 1D.
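A quick illustration of the point: a flattened (e.g. quantized) buffer exposes only a 1-D shape, so the logical dims must come from carried metadata such as original_shape rather than from the aval shape. Plain numpy stand-in:

```python
import numpy as np

lhs_data = np.arange(4 * 6, dtype=np.float32)  # flattened 1-D buffer (logical M=4, K=6)
assert lhs_data.shape == (24,)                 # the aval shape alone cannot recover M and K
original_shape = (4, 6)                        # must be carried as separate metadata
lhs_2d = lhs_data.reshape(original_shape)
```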
Description
This PR refactors the grouped GEMM API in the JAX backend to support fully ragged (variable-size per group)
dimensions across all tensor axes, replacing the previous single group_sizes parameter with six per-tensor
dimension parameters. The motivation is to generalize the interface so that forward and backward (wgrad) passes
can be expressed uniformly without special-casing, and to eliminate the need for callers to manually compute and
pass matrix dimensions (M, N, K) — these are now derived automatically from XLA buffer descriptors in C++.
Addresses issue: #2648
Type of change
Changes
Please list the changes introduced in this PR:
- Replaced the single group_sizes parameter with six per-tensor dimension arguments — lhs_first_dims, lhs_last_dims, rhs_first_dims, rhs_last_dims, out_first_dims, out_last_dims — each an optional (G,) int32 array describing per-group sizes along that tensor axis (empty (0,) arrays indicate a uniform/non-ragged dimension)
- M, N, and K are now derived automatically from XLA buffer shapes inside the C++ handler, eliminating manual dimension computation in Python
- Ragged dimensions are detected by which arrays are non-empty (a non-empty rhs_first_dims indicates a ragged K contraction dimension, producing a (num_groups, M, N) output)
- The per-tensor dimension arrays are passed to C++ through a single FFI attribute struct, replacing individual attribute bindings
- The V2 path converts the dimension arrays to int64 in partitioned int64_workspace slots, and returns the updated workspace offset to avoid aliasing
- Callers now pass the appropriate new per-tensor parameter (lhs_first_dims/out_first_dims for forward; rhs_first_dims for wgrad)
- Non-ragged axes use jnp.empty((0,), jnp.int32) sentinels
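The call shape described in the Changes list can be sketched as follows. Numpy stands in for jnp here, the keyword names follow the list above, and the surrounding grouped_gemm call itself is elided since its exact signature is defined by the PR:

```python
import numpy as np

group_sizes = np.array([5, 2, 9], dtype=np.int32)  # ragged M per group (G=3)
empty = np.empty((0,), np.int32)                   # sentinel: uniform / non-ragged axis

# Forward pass with a ragged LHS first (M) axis and uniform K and N:
ragged_dim_kwargs = dict(
    lhs_first_dims=group_sizes,  # LHS first (M) axis varies per group
    lhs_last_dims=empty,         # uniform K
    rhs_first_dims=empty,        # uniform K (non-empty here would select the wgrad path)
    rhs_last_dims=empty,         # uniform N
    out_first_dims=group_sizes,  # output inherits the LHS raggedness
    out_last_dims=empty,
)
```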
Checklist: