
[Common][JAX] Add CUB TopK MaxPairs interface#2784

Open
huanghua1994 wants to merge 2 commits into NVIDIA:main from huanghua1994:CUB-topk

Conversation

@huanghua1994
Collaborator

Description

This PR introduces a new CUB TopK API for large N and K values.

Type of change

  • Documentation change (change only to the documentation, either a fix or new content)
  • Bug fix (non-breaking change which fixes an issue)
  • New feature (non-breaking change which adds functionality)
  • Breaking change (fix or feature that would cause existing functionality to not work as expected)
  • Infra/Build change
  • Code refactoring

Changes

  • Added 3rdparty/cccl as a dependency since the CTK on the machine might not be new enough
  • Added transformer_engine/common/util/cub.cu as the entry point to the CUB TopK function
  • Added JAX FFI interface and JAX tests

Checklist:

  • I have read and followed the contributing guidelines
  • The functionality is complete
  • I have commented my code, particularly in hard-to-understand areas
  • I have made corresponding changes to the documentation
  • My changes generate no new warnings
  • I have added tests that prove my fix is effective or that my feature works
  • New and existing unit tests pass locally with my changes

huanghua1994 and others added 2 commits March 19, 2026 18:36
@greptile-apps
Contributor

greptile-apps bot commented Mar 20, 2026

Greptile Summary

This PR introduces a new CUB DeviceTopK::MaxPairs wrapper for large-N/large-K top-K selection, exposed as a JAX FFI custom call (te_cub_topk_ffi). It adds the NVIDIA CCCL library as a vendored submodule to ensure access to a sufficiently new CUB API, wires up the full stack from CUDA kernel through C++ FFI handler to a JAX primitive, and provides a parametrised test suite.

Key items to address before merging:

  • Missing batcher, partition, and shardy_sharding_rule in CubTopkPrimitive — BasePrimitive declares these as abstract; omitting them means vmap and multi-device/sharding usage will silently use the un-callable abstract fallback and raise at runtime.
  • Hardcoded 4 MiB workspace with no runtime guard — get_cub_topk_workspace_bytes() is documented as a heuristic sufficient only for N ≤ 5M and K ≤ 100K, but there is no check preventing callers from exceeding these bounds. CUB will silently write out-of-bounds on the GPU if the workspace is undersized. The safe pattern is to call DeviceTopK::MaxPairs(nullptr, temp_bytes, ...) first to query the actual required size.
  • No k ≤ num_items validation in CubTopkFFI — if a caller passes k > num_items, the behaviour is undefined at the CUB level.
  • cub_topk omitted from __all__ — the user-facing function should be exported alongside CubTopkPrimitive.

Confidence Score: 2/5

  • Not safe to merge — the missing primitive methods break vmap/sharding, and the hardcoded workspace size can silently corrupt GPU memory for out-of-range inputs.
  • The CUDA dispatch logic and FFI wiring are structurally sound, and the test suite covers happy-path scenarios well. However, two P1 issues affect correctness: (1) the abstract batcher/partition methods are never overridden, causing silent failures in vmap and multi-device contexts; and (2) the fixed workspace heuristic has no guard, risking silent GPU OOB writes for inputs exceeding the documented limits. An additional missing bounds check (k ≤ num_items) adds a third correctness gap.
  • transformer_engine/jax/cpp_extensions/cub.py (missing abstract method implementations and workspace guard) and transformer_engine/jax/csrc/extensions/cub.cpp (missing k ≤ num_items check).

Important Files Changed

  • transformer_engine/common/util/cub.cu: Core CUDA entry point that dispatches cub::DeviceTopK::MaxPairs for float32/float16/bfloat16 key types with int32 values; logic is sound, but the caller-supplied workspace is not validated against the actual required size.
  • transformer_engine/jax/cpp_extensions/cub.py: JAX FFI primitive wrapper — missing the required batcher, partition, and shardy_sharding_rule implementations from BasePrimitive; the hardcoded 4 MiB workspace has no runtime guard; cub_topk is not in __all__.
  • transformer_engine/jax/csrc/extensions/cub.cpp: C++ FFI handler that validates dtypes and shapes before delegating to nvte_cub_topk; missing a k <= num_items bounds check that could allow undefined CUB behaviour.
  • transformer_engine/common/include/transformer_engine/cub.h: Clean public C API header declaring nvte_cub_topk; the [in,out] docs on output-only tensors are slightly misleading but otherwise correct.
  • transformer_engine/common/CMakeLists.txt: Adds the CCCL submodule as a header-only dependency with an existence check and correctly replaces the system CTK CCCL include paths with the vendored ones.
  • tests/jax/test_custom_call_compute.py: New TestCubOps class validates topk correctness against jax.lax.top_k with a sorted comparison and a boundary check; test coverage is good across dtypes and sizes up to N=5M.

Flowchart

%%{init: {'theme': 'neutral'}}%%
flowchart TD
    A["JAX: cub_topk call"] --> B["CubTopkPrimitive bind"]
    B --> C["FFI lowering via te_cub_topk_ffi"]
    C --> D["CubTopkFFI: validate dtypes and shapes"]
    D --> E["nvte_cub_topk: dispatch on dtype"]
    E --> F["cub DeviceTopK MaxPairs on GPU"]
    F --> G["Return top-k keys and indices"]

Last reviewed commit: "[pre-commit.ci] auto..."

Comment on lines +27 to +94
class CubTopkPrimitive(BasePrimitive):
    """
    CUB Topk Primitive
    """

    name = "te_cub_topk_ffi"
    multiple_results = True
    impl_static_args = (2,)  # k_value
    inner_primitive = None
    outer_primitive = None

    @staticmethod
    def abstract(
        in_keys_aval,
        in_values_aval,
        *,
        k_value,
    ):
        keys_dtype = dtypes.canonicalize_dtype(in_keys_aval.dtype)
        values_dtype = dtypes.canonicalize_dtype(in_values_aval.dtype)
        assert keys_dtype in [jnp.float32, jnp.float16, jnp.bfloat16]
        assert values_dtype == jnp.int32

        workspace_bytes = get_cub_topk_workspace_bytes()
        out_keys_aval = jax.core.ShapedArray(shape=(k_value,), dtype=keys_dtype)
        out_values_aval = jax.core.ShapedArray(shape=(k_value,), dtype=jnp.int32)
        workspace_aval = jax.core.ShapedArray(shape=(workspace_bytes,), dtype=jnp.uint8)
        return (out_keys_aval, out_values_aval, workspace_aval)

    @staticmethod
    def outer_abstract(*args, **kwargs):
        out_keys_aval, out_values_aval, _workspace_aval = CubTopkPrimitive.abstract(*args, **kwargs)
        return (out_keys_aval, out_values_aval)

    @staticmethod
    def lowering(
        ctx,
        in_keys,
        in_values,
        k_value,
    ):
        workspace_bytes = get_cub_topk_workspace_bytes()
        return ffi.ffi_lowering(
            CubTopkPrimitive.name,
        )(
            ctx,
            in_keys,
            in_values,
            k_value=k_value,
            workbuf_bytes=workspace_bytes,
        )

    @staticmethod
    def impl(
        in_keys,
        in_values,
        k_value,
    ):
        assert CubTopkPrimitive.inner_primitive is not None
        out_keys, out_values, _workspace = CubTopkPrimitive.inner_primitive.bind(
            in_keys,
            in_values,
            k_value=k_value,
        )
        return (out_keys, out_values)


register_primitive(CubTopkPrimitive)

P1 Missing batcher, partition, and shardy_sharding_rule methods

CubTopkPrimitive extends BasePrimitive, which declares batcher(), partition(), and shardy_sharding_rule() as abstract methods. CubTopkPrimitive does not implement any of them.

When register_primitive(CubTopkPrimitive) is called, base.py does:

batching.primitive_batchers[outer_p] = cls.batcher  # resolves to abstract method → returns NotImplemented
outer_p_lower.def_partition(partition=cls.partition, ...)  # same

This means:

  • Any attempt to use vmap over cub_topk will fail at runtime because the registered batcher is the abstract method (which returns NotImplemented, not a callable that returns batched results).
  • Multi-device / sharding via custom_partitioning will similarly fail when partition is invoked.

Every other primitive in the codebase (e.g., in router.py) implements all three of these methods. If sharding and batching are intentionally unsupported for now, the methods should at minimum raise a clear NotImplementedError (rather than silently returning NotImplemented), and this limitation should be documented.
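If the intent is to ship without vmap/sharding support for now, one minimal sketch of the "fail loudly" option looks like the following. The class name and method signatures here are illustrative only (they are not the PR's actual code, and the real BasePrimitive contract may differ); the point is simply that each abstract method is overridden to raise a descriptive NotImplementedError rather than silently returning NotImplemented:

```python
# Hedged sketch, not the PR's implementation: explicitly unsupported
# batching/partitioning hooks that fail with an actionable message.
class CubTopkPrimitiveSketch:
    name = "te_cub_topk_ffi"

    @staticmethod
    def batcher(batched_args, batch_dims, **kwargs):
        # Registered into JAX's primitive batchers; raising here turns a
        # confusing vmap failure into a clear, documented limitation.
        raise NotImplementedError(
            f"vmap over {CubTopkPrimitiveSketch.name} is not supported yet"
        )

    @staticmethod
    def partition(mesh, arg_infos, result_infos, **kwargs):
        # Invoked by custom_partitioning in multi-device contexts.
        raise NotImplementedError(
            f"custom partitioning of {CubTopkPrimitiveSketch.name} is not supported yet"
        )

    @staticmethod
    def shardy_sharding_rule(*args, **kwargs):
        # Invoked when the Shardy partitioner queries the primitive.
        raise NotImplementedError(
            f"Shardy sharding of {CubTopkPrimitiveSketch.name} is not supported yet"
        )
```

Each override raises immediately, so a user hitting vmap or sharding gets a message naming the primitive instead of a silent fallback.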

Comment on lines +17 to +24
def get_cub_topk_workspace_bytes() -> int:
"""
Get the workspace size for CUB Topk
The safe way is calling the CUB kernel to query the workspace size.
For convenience, we use a heuristic value based on experiments.
4 MiB is enough for N up to 5,000,000 and K up to 100,000.
"""
return 4 * 1024 * 1024

P1 Hardcoded workspace size may silently corrupt memory for large inputs

get_cub_topk_workspace_bytes() always returns a fixed 4 MiB and the docstring itself acknowledges this only covers "N up to 5,000,000 and K up to 100,000." However, there is no validation in the Python or C++ layer that the user's actual N and K do not exceed these limits.

If a caller passes N > 5_000_000 or K > 100_000, cub::DeviceTopK::MaxPairs will be given an undersized workspace buffer and will write out-of-bounds on the GPU — a silent CUDA memory corruption with no error raised back to the caller.

The correct approach is to call cub::DeviceTopK::MaxPairs(nullptr, temp_storage_bytes, ...) with a null workspace pointer to query the required size at runtime, then allocate that exact amount. The current heuristic should at minimum be accompanied by a runtime guard that raises an error when the inputs exceed the documented limits.
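A minimal sketch of the interim guard could look like this. Note the signature differs from the PR (the actual helper takes no arguments), and the limit constants are hypothetical names introduced here; the 4 MiB value and the N ≤ 5M / K ≤ 100K bounds come from the PR's own docstring:

```python
# Hedged sketch, assuming the documented heuristic limits; not the PR's code.
_CUB_TOPK_MAX_N = 5_000_000  # largest N the 4 MiB heuristic was validated for
_CUB_TOPK_MAX_K = 100_000    # largest K the 4 MiB heuristic was validated for


def get_cub_topk_workspace_bytes(num_items: int, k: int) -> int:
    """Return the heuristic CUB topk workspace size, guarding its valid range.

    Raises ValueError instead of letting an undersized buffer reach CUB,
    where it would cause a silent out-of-bounds write on the GPU.
    """
    if num_items > _CUB_TOPK_MAX_N or k > _CUB_TOPK_MAX_K:
        raise ValueError(
            f"CUB topk workspace heuristic is only validated for "
            f"N <= {_CUB_TOPK_MAX_N} and K <= {_CUB_TOPK_MAX_K}; "
            f"got N={num_items}, K={k}. Query DeviceTopK::MaxPairs with a "
            f"null workspace pointer to size larger problems."
        )
    return 4 * 1024 * 1024
```

This keeps the convenience of a fixed allocation for the common case while turning the silent-corruption path into an immediate Python-level error.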

Comment on lines +39 to +40
int num_items = static_cast<int>(keys_in_shape[0]);
int k = static_cast<int>(k_value);

P1 No validation that k <= num_items

There is no check that k_value is less than or equal to num_items (the size of the input array). CUB's DeviceTopK::MaxPairs requires k <= num_items; if k > num_items the behavior is undefined and will likely produce a CUDA error or garbage output.

A guard should be added here alongside the existing shape checks:

NVTE_CHECK(k <= num_items, "k (", k, ") must be <= num_items (", num_items, ")");


from .base import BasePrimitive, register_primitive

__all__ = ["CubTopkPrimitive"]

P2 Public function cub_topk not exported in __all__

The user-facing function cub_topk (defined at line 97) is not included in __all__. Only CubTopkPrimitive is listed. Tools and users relying on __all__ for the module's public API won't discover cub_topk. Since the test imports it as a primary API (from transformer_engine.jax.cpp_extensions.cub import cub_topk), it should be exported.

Suggested change
- __all__ = ["CubTopkPrimitive"]
+ __all__ = ["CubTopkPrimitive", "cub_topk"]
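To illustrate why the export matters, the snippet below builds a throwaway module (demo_cub, a hypothetical stand-in for transformer_engine.jax.cpp_extensions.cub) and shows that a name omitted from __all__ is invisible to star imports:

```python
import sys
import types

# Construct a minimal stand-in module whose __all__ omits cub_topk,
# mirroring the situation flagged in the review.
mod = types.ModuleType("demo_cub")
exec(
    "__all__ = ['CubTopkPrimitive']\n"
    "class CubTopkPrimitive: pass\n"
    "def cub_topk(keys, values, k):\n"
    "    return keys[:k], values[:k]\n",
    mod.__dict__,
)
sys.modules["demo_cub"] = mod

# `from demo_cub import *` copies only the names listed in __all__.
ns = {}
exec("from demo_cub import *", ns)
print("CubTopkPrimitive" in ns)  # True: listed in __all__
print("cub_topk" in ns)          # False: omitted from __all__
```

Regular imports (`from demo_cub import cub_topk`) still work either way, but tools that enumerate the public API via __all__, and users of star imports, will miss the function.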
