diff --git a/.flake8 b/.flake8 index 1656330a998..eeec63740e8 100644 --- a/.flake8 +++ b/.flake8 @@ -5,3 +5,4 @@ max-line-length = 119 # E402: module level import not at top of file per-file-ignores = __init__.py:F401,F403,E402 + fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py:E241,E121,E131,E266 diff --git a/fastdeploy/envs.py b/fastdeploy/envs.py index 05dff41a251..be3c2e688df 100644 --- a/fastdeploy/envs.py +++ b/fastdeploy/envs.py @@ -50,7 +50,7 @@ # Set attention backend. "NATIVE_ATTN", "APPEND_ATTN" # and "MLA_ATTN" can be set currently. "FD_ATTENTION_BACKEND": lambda: os.getenv("FD_ATTENTION_BACKEND", "APPEND_ATTN"), - # Set sampling class. "base", "base_non_truncated", "air" and "rejection" can be set currently. + # Set sampling class. "base", "base_non_truncated", "air", "rejection" and "triton" can be set currently. "FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"), # Set moe backend."cutlass","marlin" and "triton" can be set currently. "FD_MOE_BACKEND": lambda: os.getenv("FD_MOE_BACKEND", "cutlass"), diff --git a/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py new file mode 100644 index 00000000000..b78baaee4a5 --- /dev/null +++ b/fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py @@ -0,0 +1,939 @@ +# SPDX-License-Identifier: Apache-2.0 +# SPDX-FileCopyrightText: Copyright contributors to the vLLM project +""" +Combined Top-K and Top-P Triton kernels. + +Based on the paper "Qrita: High-performance Top-k and Top-p Algorithm for GPUs +using Pivot-based Truncation and Selection" By Park et al. +(https://arxiv.org/abs/2602.01518) + +""" + +import warnings + +import paddle +from paddle.utils.deprecated import VisibleDeprecationWarning + +# Suppress the VisibleDeprecationWarning from use_triton_in_paddle that fires +# on every Triton kernel launch (paddle.device.cuda.current_stream / +# synchronize). In serving hot-paths this produces thousands of log lines per +# second and the I/O overhead alone can cause client-visible timeouts. +warnings.filterwarnings("ignore", category=VisibleDeprecationWarning) + +import triton # noqa: E402 +import triton.language as tl # noqa: E402 + +_TRITON_TABLE_CACHE: dict[tuple[paddle.device], tuple[paddle.Tensor, paddle.Tensor]] = {} +_TRITON_BUFFER_CACHE: dict[tuple[paddle.device, paddle.dtype, int], paddle.Tensor] = {} + +# fmt: off +_NORMAL_CDF_TO_SIGMA_TABLE = [ + 3.656, 3.650, 3.650, 3.650, 3.626, 3.626, 3.626, 3.514, 3.514, 3.503, + 3.503, 3.434, 3.434, 3.428, 3.428, 3.387, 3.380, 3.380, 3.376, 3.373, + 3.373, 3.356, 3.354, 3.354, 3.291, 3.249, 3.234, 3.214, 3.198, 3.198, + 3.185, 3.177, 3.177, 3.165, 3.164, 3.161, 3.138, 3.120, 3.115, 3.113, + 3.093, 3.066, 3.054, 3.043, 3.037, 3.023, 2.993, 2.991, 2.976, 2.970, + 2.952, 2.946, 2.932, 2.908, 2.902, 2.895, 2.886, 2.874, 2.861, 2.844, + 2.836, 2.810, 2.801, 2.790, 2.784, 2.779, 2.767, 2.757, 2.745, 2.733, + 2.723, 2.716, 2.693, 2.678, 2.671, 2.656, 2.649, 2.629, 2.611, 2.595, + 2.592, 2.585, 2.574, 2.550, 2.543, 2.534, 2.521, 2.518, 2.497, 2.485, + 2.468, 2.450, 2.441, 2.430, 2.412, 2.402, 2.389, 2.383, 2.377, 2.364, + 2.349, 2.338, 2.332, 2.319, 2.310, 2.301, 2.282, 2.274, 2.266, 2.250, + 2.242, 2.236, 2.226, 2.215, 2.207, 2.196, 2.179, 2.171, 2.162, 2.147, + 2.135, 2.121, 2.109, 2.095, 2.085, 2.073, 2.063, 2.045, 2.030, 2.016, + 2.003, 1.992, 1.983, 1.972, 1.960, 1.949, 1.940, 1.928, 1.912, 1.897, + 1.881, 1.869, 1.854, 1.838, 1.824, 1.807, 1.792, 1.779, 1.764, 1.751, + 1.739, 1.726, 1.711, 1.697, 1.685, 1.668, 1.652, 1.636, 1.622, 1.603, + 1.585, 1.568, 1.551, 1.534, 1.513, 1.499, 1.480, 1.464, 1.441, 1.422, + 1.394, 1.373, 1.347, 1.320, 1.296, 1.270, 1.246, 1.219, 1.190, 1.163, + 1.135, 1.104, 1.073, 1.041, 1.006, 0.969, 0.931, 0.894, 0.851, 0.806, + 0.757, 0.702, 0.643, 0.574, 0.498, 0.405, 0.288, 0.134, -0.110, -3.813 +] + +_PERCENTILE_TO_STD_TABLE = [ + 2.576, 2.319, 2.178, 2.064, 1.968, 1.892, 1.819, 1.757, 1.708, 1.659, + 1.616, 1.568, 1.526, 1.492, 1.456, 1.420, 1.382, 1.342, 1.309, 1.280, + 1.249, 1.221, 1.193, 1.169, 1.145, 1.121, 1.095, 1.073, 1.050, 1.030, + 1.008, 0.987, 0.966, 0.945, 0.926, 0.910, 0.891, 0.871, 0.854, 0.837, + 0.819, 0.803, 0.784, 0.767, 0.753, 0.734, 0.719, 0.702, 0.690, 0.675, + 0.658, 0.640, 0.625, 0.609, 0.595, 0.578, 0.564, 0.550, 0.537, 0.521, + 0.509, 0.495, 0.481, 0.466, 0.453, 0.439, 0.424, 0.410, 0.397, 0.383, + 0.370, 0.356, 0.343, 0.330, 0.316, 0.302, 0.289, 0.274, 0.261, 0.247, + 0.235, 0.223, 0.209, 0.196, 0.184, 0.172, 0.159, 0.149, 0.137, 0.124, + 0.112, 0.100, 0.086, 0.074, 0.062, 0.050, 0.035, 0.023, 0.009, -0.003, + -0.015, -0.027, -0.039, -0.052, -0.063, -0.074, -0.085, -0.097, -0.109, -0.122, + -0.134, -0.147, -0.158, -0.171, -0.184, -0.196, -0.210, -0.223, -0.235, -0.248, + -0.261, -0.275, -0.289, -0.302, -0.317, -0.328, -0.341, -0.353, -0.368, -0.382, + -0.396, -0.410, -0.426, -0.439, -0.452, -0.465, -0.480, -0.493, -0.507, -0.521, + -0.537, -0.551, -0.568, -0.582, -0.597, -0.614, -0.628, -0.643, -0.658, -0.673, + -0.691, -0.706, -0.721, -0.738, -0.754, -0.769, -0.789, -0.808, -0.824, -0.838, + -0.857, -0.877, -0.893, -0.912, -0.929, -0.947, -0.965, -0.983, -1.003, -1.027, + -1.050, -1.070, -1.092, -1.117, -1.139, -1.162, -1.189, -1.216, -1.241, -1.272, + -1.300, -1.330, -1.367, -1.404, -1.441, -1.485, -1.523, -1.564, -1.607, -1.658, + -1.710, -1.778, -1.832, -1.901, -1.978, -2.068, -2.174, -2.325, -2.577, -3.813 +] +# fmt: on + + +@triton.jit +def _topk_topp_kernel( + LOGITS, + BUFFER, + MASK_OUT, + PERCENTILE_TO_STD_TABLE, + NORMAL_CDF_TO_SIGMA_TABLE, + K, + P, + BATCH_SIZE, + VOCAB_SIZE: tl.constexpr, + MASK_VALUE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, + BLOCK_SIZE_TRUNC: tl.constexpr, + TOPK_ENABLED: tl.constexpr, + TOPP_ENABLED: tl.constexpr, + WRITE_MASK: tl.constexpr, +): + NUM_TILES: tl.constexpr = (VOCAB_SIZE + BLOCK_SIZE - 1) // BLOCK_SIZE + pid = tl.program_id(0) + num_programs = tl.num_programs(0) + for row_id in tl.range(pid, BATCH_SIZE, num_programs): + LOGITS_ROW = LOGITS + row_id * VOCAB_SIZE + BUFFER_ROW = BUFFER + pid * VOCAB_SIZE + + final_pivot = -float("inf") + duplicate_logit = float("inf") + num_duplicate_logit = tl.zeros((), dtype=tl.uint32) + num_keep = tl.zeros((), dtype=tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + max_logit = -float("inf") + min_logit = float("inf") + + if TOPK_ENABLED: + k = tl.load(K + row_id) + if k < VOCAB_SIZE: + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=-float("inf")) + # Exclude -inf values (e.g. from grammar bitmasks) from + # statistics to avoid NaN in pivot computation. + finite_mask = (logits_blk0 > -float("inf")) & mask_n + num_finite = tl.sum(finite_mask) + finite_logits = tl.where(finite_mask, logits_blk0, 0.0) + avg_logit = tl.where(num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0) + sq_avg_logit = tl.where( + num_finite > 0, + tl.sum(finite_logits * finite_logits) / num_finite, + 0.0, + ) + std_logit = tl.sqrt(tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0)) + + # Calculate outlier pivot t for Gaussian sigma-truncation + percentile = tl.cast(k / VOCAB_SIZE * 200, tl.uint32) + percentile = tl.minimum(percentile, 199) + sigma = tl.load(PERCENTILE_TO_STD_TABLE + percentile) + sigma = sigma + tl.abs(sigma) * -0.15 + outlier_pivot = avg_logit + std_logit * sigma + num_outliers = tl.zeros((), dtype=tl.uint32) + + # First pass: compute max and min logits and gather outliers + num_finite_total = tl.zeros((), dtype=tl.uint32) + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + # Exclude -inf from min to keep binary search bounds + # finite (avoids NaN pivots). + finite_blk_mask = logits_blk > -float("inf") + finite_blk = tl.where(finite_blk_mask, logits_blk, float("inf")) + min_logit = tl.minimum(min_logit, tl.min(finite_blk)) + num_finite_total += tl.sum(finite_blk_mask & mask_n) + + outlier_mask = (logits_blk > outlier_pivot) & mask_n + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, logits_blk, mask=outlier_mask) + + # If no finite logits exist (all -inf), clamp min to + # max so the search converges to -inf (no masking). + min_logit = tl.minimum(min_logit, max_logit) + + # Second passes: Ternary search for pivots + num_iters = 0 + k_pivot = float("inf") + k_pivots_num = tl.zeros((), dtype=tl.uint32) + min_larger = float("inf") + num_min_larger = tl.zeros((), dtype=tl.uint32) + if num_outliers > k: + max_range = max_logit + min_range = outlier_pivot + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + found_pivot = 0 + while found_pivot == 0: + k_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate k_pivots_num and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf")) + + k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) + + min_larger_0 = tl.minimum(min_larger_0, tl.min(logits_blk2)) + min_larger_1 = tl.minimum(min_larger_1, tl.min(logits_blk2)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + logits_blk2 = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=-float("inf")) + + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k and k_pivots_num_0 - num_min_larger_0 < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + found_pivot = 1 + if k_pivots_num_1 >= k and k_pivots_num_1 - num_min_larger_1 < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + found_pivot = 1 + + # Update range + if k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + + num_iters += 1 + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: + k_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + else: + # If top-k outlier gathering failed, search whole logit space + max_range = max_logit + min_range = min_logit + found_pivot = 0 + while found_pivot == 0: + k_pivot_0 = (max_range - min_range) * 1.0 / 4.0 + min_range + k_pivots_num_0 = tl.zeros((), dtype=tl.uint32) + min_larger_0 = float("inf") + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + k_pivot_1 = (max_range - min_range) * 2.0 / 4.0 + min_range + k_pivots_num_1 = tl.zeros((), dtype=tl.uint32) + min_larger_1 = float("inf") + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate k_pivots_num and min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + + k_pivots_num_0 += tl.sum(logits_blk2 > k_pivot_0) + k_pivots_num_1 += tl.sum(logits_blk2 > k_pivot_1) + + # Exclude -inf from min_larger to avoid + # poisoning the convergence check. + finite_blk2 = tl.where(logits_blk2 > -float("inf"), logits_blk2, float("inf")) + min_larger_0 = tl.minimum(min_larger_0, tl.min(finite_blk2)) + min_larger_1 = tl.minimum(min_larger_1, tl.min(finite_blk2)) + + # Second pass: Calculate num_min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk2 = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + + num_min_larger_0 += tl.sum(tl.abs(logits_blk2 - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(logits_blk2 - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if k_pivots_num_0 >= k and k_pivots_num_0 - num_min_larger_0 < k: + k_pivot = k_pivot_0 + k_pivots_num = k_pivots_num_0 + min_larger = min_larger_0 + num_min_larger = num_min_larger_0 + found_pivot = 1 + if k_pivots_num_1 >= k and k_pivots_num_1 - num_min_larger_1 < k: + k_pivot = k_pivot_1 + k_pivots_num = k_pivots_num_1 + min_larger = min_larger_1 + num_min_larger = num_min_larger_1 + found_pivot = 1 + + # Update range + if k_pivots_num_1 > k: + min_range = k_pivot_1 + elif k_pivots_num_0 > k: + min_range = k_pivot_0 + + if k_pivots_num_0 < k: + max_range = k_pivot_0 + elif k_pivots_num_1 < k: + max_range = k_pivot_1 + + num_iters += 1 + if num_iters >= 18 or tl.abs(min_range - max_range) < 1e-9: + k_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + + duplicate_logit = min_larger + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - (k_pivots_num - k) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-k only path. If there are fewer finite values + # than k (e.g. grammar mask), keep everything. + final_pivot = k_pivot if num_finite_total > k else -float("inf") + + if TOPP_ENABLED and num_finite_total > k: + #### TOP-P SAMPLING AFTER TOP-K #### + p = tl.load(P + row_id) + if p < 1.0: + min_logit = k_pivot + sum_exp_logits = 0.0 + num_outliers_2 = tl.zeros((), dtype=tl.uint32) + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + + # Third pass: Calculate exp logits and sum, gather outliers + if num_outliers > k: + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load( + BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float("inf"), + ) + + outlier_mask = (probs_blk > min_logit) & mask_n_2 + + # Duplicate logit handling for Top-k + if num_keep < num_duplicate_logit: + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float("inf")) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load( + BUFFER_ROW + offs_n, + mask=mask_n_2, + other=-float("inf"), + ) + + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + else: + # If top-k outlier gathering failed, + # retry gathering using top-k pivot + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load( + LOGITS_ROW + offs_n, + mask=mask_n, + other=-float("inf"), + ) + + outlier_mask = (probs_blk > min_logit) & mask_n + + # Duplicate logit handling for Top-k + duplicate_mask = tl.abs(probs_blk - duplicate_logit) < 1e-9 + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + outlier_mask = outlier_mask & (~duplicate_remove_mask) + num_kept += tl.sum(duplicate_keep_mask) + + probs_blk = tl.where(outlier_mask, probs_blk, -float("inf")) + probs_blk = probs_blk - max_logit + probs_blk = tl.exp(probs_blk) + sum_exp_logits += tl.sum(probs_blk) + + cumulative_pos = tl.cast( + tl.cumsum(outlier_mask) - 1 + num_outliers_2, + tl.int32, + ) + num_outliers_2 += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + search_range = tl.cast(num_outliers_2, tl.int32) + search_iters = tl.cast( + (num_outliers_2 + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + + # Fourth pass: Calculate BUFFER and get outliers + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n_2) + + max_range = tl.exp(max_logit - max_logit) / sum_exp_logits + min_range = tl.exp(min_logit - max_logit) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Fifth passes: Search for p_pivot + found_pivot = 0 + while found_pivot == 0: + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_1 >= p and (p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p): + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 + if p_pivots_sum_0 >= p and (p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p): + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 + + # Update range + if p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + + num_iters += 1 + if (max_range - min_range) < 1e-9 or num_iters >= 18: + p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-k + Top-p path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_logit + + if TOPP_ENABLED and final_pivot == -float("inf"): + #### STANDALONE TOP-P SAMPLING #### + p = tl.load(P + row_id) + if p < 1.0: + # Zeroth pass: Compute avg and std from a sample block + offs = tl.arange(0, BLOCK_SIZE) + mask_n = offs < VOCAB_SIZE + logits_blk0 = tl.load(LOGITS_ROW + offs, mask=mask_n, other=-float("inf")) + # Exclude -inf values (e.g. from grammar bitmasks) from + # statistics to avoid NaN in pivot computation. + finite_mask = (logits_blk0 > -float("inf")) & mask_n + num_finite = tl.sum(finite_mask) + finite_logits = tl.where(finite_mask, logits_blk0, 0.0) + avg_logit = tl.where(num_finite > 0, tl.sum(finite_logits) / num_finite, 0.0) + sq_avg_logit = tl.where( + num_finite > 0, + tl.sum(finite_logits * finite_logits) / num_finite, + 0.0, + ) + std_logit = tl.sqrt(tl.maximum(sq_avg_logit - avg_logit * avg_logit, 0.0)) + max_sample = avg_logit + std_logit * 10.0 + sum_exp_logits = 0.0 + + # First pass: compute max and min logits and sum_exp_logits + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + max_logit = tl.maximum(max_logit, tl.max(logits_blk)) + # Exclude -inf from min to keep binary search bounds + # finite (avoids NaN pivots). + finite_blk = tl.where(logits_blk > -float("inf"), logits_blk, float("inf")) + min_logit = tl.minimum(min_logit, tl.min(finite_blk)) + + probs_blk = tl.exp(logits_blk - max_sample) + probs_blk = tl.where(mask_n, probs_blk, 0.0) + sum_exp_logits += tl.sum(probs_blk) + + # If no finite logits exist (all -inf), clamp min to + # max so the search converges to -inf (no masking). + min_logit = tl.minimum(min_logit, max_logit) + + idx = tl.cast(p * 200, tl.int32) + idx = tl.maximum(0, tl.minimum(idx, 199)) + sigma = tl.load(NORMAL_CDF_TO_SIGMA_TABLE + idx) + sigma = sigma + tl.abs(sigma) * -0.25 + outlier_pivot = avg_logit + std_logit * sigma + + outlier_prob = tl.exp(outlier_pivot - max_sample) / sum_exp_logits + sum_outlier_probs = 0.0 + num_outliers = tl.zeros((), dtype=tl.uint32) + + # Second pass: Calculate softmax and gather outliers + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + + outlier_mask = (probs_blk > outlier_prob) & mask_n + sum_outlier_probs += tl.sum(outlier_mask * probs_blk) + cumulative_pos = tl.cast(tl.cumsum(outlier_mask) - 1 + num_outliers, tl.int32) + num_outliers += tl.sum(outlier_mask) + write_pos = tl.where(outlier_mask, cumulative_pos, -1) + tl.store(BUFFER_ROW + write_pos, probs_blk, mask=outlier_mask) + + max_range = tl.exp(max_logit - max_sample) / sum_exp_logits + min_range = tl.exp(min_logit - max_sample) / sum_exp_logits + + p_pivot = 1.0 + num_iters = 0 + min_larger_prob = 1.0 + num_min_larger = tl.zeros((), dtype=tl.uint32) + p_pivots_sum = 0.0 + + # Third pass: Search for p_pivot + if sum_outlier_probs > p: + min_range = outlier_prob + search_range = tl.cast(num_outliers, tl.int32) + search_iters = tl.cast( + (num_outliers + BLOCK_SIZE_TRUNC - 1) // BLOCK_SIZE_TRUNC, + tl.int32, + ) + + found_pivot = 0 + while found_pivot == 0: + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + # Second pass: Calculate num_min_larger + for i in range(0, search_iters): + offs_n = i * BLOCK_SIZE_TRUNC + tl.arange(0, BLOCK_SIZE_TRUNC) + mask_n_2 = offs_n < search_range + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n_2, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_1 >= p and p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 + if p_pivots_sum_0 >= p and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 + + # Update range + if p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + + num_iters += 1 + if (max_range - min_range) < 1e-9 or num_iters >= 18: + p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + else: + # Re-populate the buffer with full softmax probabilities + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + + probs_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + probs_blk = tl.exp(probs_blk - max_sample) + probs_blk = probs_blk / sum_exp_logits + tl.store(BUFFER_ROW + offs_n, probs_blk, mask=mask_n) + + found_pivot = 0 + while found_pivot == 0: + p_pivot_0 = (max_range - min_range) * 1.0 / 3.0 + min_range + p_pivots_sum_0 = 0.0 + min_larger_0 = 1.0 + num_min_larger_0 = tl.zeros((), dtype=tl.uint32) + + p_pivot_1 = (max_range - min_range) * 2.0 / 3.0 + min_range + p_pivots_sum_1 = 0.0 + min_larger_1 = 1.0 + num_min_larger_1 = tl.zeros((), dtype=tl.uint32) + + # First pass: Calculate p_pivots_sum and min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + p_pivots_sum_0 += tl.sum(probs_blk * (probs_blk > p_pivot_0)) + masked_larger_0 = tl.where(probs_blk > p_pivot_0, probs_blk, 1.0) + min_larger_0 = tl.minimum(min_larger_0, tl.min(masked_larger_0)) + + p_pivots_sum_1 += tl.sum(probs_blk * (probs_blk > p_pivot_1)) + masked_larger_1 = tl.where(probs_blk > p_pivot_1, probs_blk, 1.0) + min_larger_1 = tl.minimum(min_larger_1, tl.min(masked_larger_1)) + + # Second pass: Calculate num_min_larger + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + probs_blk = tl.load(BUFFER_ROW + offs_n, mask=mask_n, other=0.0) + + num_min_larger_0 += tl.sum(tl.abs(probs_blk - min_larger_0) < 1e-9) + num_min_larger_1 += tl.sum(tl.abs(probs_blk - min_larger_1) < 1e-9) + + # Check if any of the pivots satisfy termination condition + if p_pivots_sum_1 >= p and p_pivots_sum_1 - (min_larger_1 * num_min_larger_1) < p: + p_pivot = p_pivot_1 + min_larger_prob = min_larger_1 + num_min_larger = num_min_larger_1 + p_pivots_sum = p_pivots_sum_1 + found_pivot = 1 + if p_pivots_sum_0 >= p and p_pivots_sum_0 - (min_larger_0 * num_min_larger_0) < p: + p_pivot = p_pivot_0 + min_larger_prob = min_larger_0 + num_min_larger = num_min_larger_0 + p_pivots_sum = p_pivots_sum_0 + found_pivot = 1 + + # Update range + if p_pivots_sum_1 > p: + min_range = p_pivot_1 + elif p_pivots_sum_0 > p: + min_range = p_pivot_0 + + if p_pivots_sum_0 < p: + max_range = p_pivot_0 + elif p_pivots_sum_1 < p: + max_range = p_pivot_1 + + num_iters += 1 + if (max_range - min_range) < 1e-9 or num_iters >= 18: + p_pivot = (max_range + min_range) / 2.0 + found_pivot = 1 + + duplicate_logit = tl.log(min_larger_prob * sum_exp_logits) + max_logit + num_duplicate_logit = num_min_larger + num_keep = num_duplicate_logit - tl.cast((p_pivots_sum - p) / min_larger_prob, tl.uint32) + num_kept = tl.zeros((), dtype=tl.uint32) + + # Top-p only path + final_pivot = tl.log(p_pivot * sum_exp_logits) + max_sample + + # Sixth pass: Apply mask and store final output. + # If the pivot >= max logit (or is NaN), no token would + # survive the strict `>` keep_mask. Skip masking. + # Using `not <` instead of `>=` so that NaN is also caught. + if not (final_pivot < max_logit): + final_pivot = -float("inf") + elif final_pivot != -float("inf"): + if WRITE_MASK: + MASK_ROW = MASK_OUT + row_id * VOCAB_SIZE + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + logits_blk = tl.load(LOGITS_ROW + offs_n, mask=mask_n, other=-float("inf")) + keep_mask = (logits_blk > final_pivot) & mask_n + + # Duplicate logit handling + if num_keep < num_duplicate_logit: + duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n + duplicate_count = tl.cumsum(duplicate_mask) + num_kept + duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask + duplicate_remove_mask = duplicate_mask & ~duplicate_keep_mask + num_kept += tl.sum(duplicate_keep_mask) + keep_mask = keep_mask & (~duplicate_remove_mask) + + logits_blk = tl.where(keep_mask, logits_blk, MASK_VALUE) + tl.store(LOGITS_ROW + offs_n, logits_blk, mask=mask_n) + if WRITE_MASK: + tl.store(MASK_ROW + offs_n, keep_mask, mask=mask_n) + + # When no masking was applied (final_pivot == -inf), all tokens are kept. + if WRITE_MASK and final_pivot == -float("inf"): + MASK_ROW = MASK_OUT + row_id * VOCAB_SIZE + for i in range(0, NUM_TILES): + offs_n = i * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE) + mask_n = offs_n < VOCAB_SIZE + tl.store(MASK_ROW + offs_n, mask_n, mask=mask_n) + + +def apply_top_k_top_p_triton( + logits: paddle.Tensor, + k: paddle.Tensor | None, + p: paddle.Tensor | None, + mask_value: float = float("-inf"), + return_mask: bool = False, +) -> paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor]: + """ + Apply combined top-k and top-p masking using Triton. + + Top-k is applied first (by logit value), then top-p is applied + to the remaining k values (by probability). + + Args: + logits: [batch_size, vocab_size] float32 tensor, modified in-place + k: [batch_size] int32 tensor of top-k values per row, or None to disable top-k + p: [batch_size] float32 tensor of top-p values per row (0 to 1), + or None to disable top-p + mask_value: Value for masked positions (default: -inf) + return_mask: If True, also return a bool mask [batch_size, vocab_size] + where True = retained token. The mask is computed inside the kernel + with zero extra memory bandwidth cost. + + Returns: + logits if return_mask is False, else (logits, mask). + """ + assert logits.ndim == 2 + assert logits.dtype == paddle.float32 + + batch_size, vocab_size = logits.shape + + topk_enabled = k is not None + topp_enabled = p is not None + + if batch_size == 0 or not (topk_enabled or topp_enabled): + if return_mask: + mask = paddle.ones(logits.shape, dtype=paddle.bool) + return logits, mask + return logits + + if k is not None: + assert k.ndim == 1 and k.shape[0] == batch_size + k_ptr = k.to(paddle.int32) + else: + k_ptr = logits # Dummy pointer (won't be read) + + if p is not None: + assert p.ndim == 1 and p.shape[0] == batch_size + p_ptr = p.to(paddle.float32) + else: + p_ptr = logits # Dummy pointer (won't be read) + + num_sm = paddle.device.cuda.get_device_properties(logits.device.index).multi_processor_count + NUM_PROGRAMS = min(num_sm, batch_size) + + # Cache per-Triton Program buffer on each device. + buf_key = (logits.device, logits.dtype, vocab_size) + buffer = _TRITON_BUFFER_CACHE.get(buf_key) + if buffer is None or buffer.shape[0] < NUM_PROGRAMS: + size = min(triton.next_power_of_2(NUM_PROGRAMS), num_sm) + buffer = paddle.empty((size, vocab_size), dtype=logits.dtype) + _TRITON_BUFFER_CACHE[buf_key] = buffer + if buffer.shape[0] > NUM_PROGRAMS: + buffer = buffer[:NUM_PROGRAMS] + + # Allocate mask output if requested. + write_mask = return_mask + if write_mask: + mask_out = paddle.empty(logits.shape, dtype=paddle.int8) + else: + mask_out = logits # Dummy pointer (won't be written) + + # Cache lookup table entries on each device. + tables = _TRITON_TABLE_CACHE.get(logits.device) + if tables is None: + normal_cdf_to_sigma_table = paddle.to_tensor( + _NORMAL_CDF_TO_SIGMA_TABLE, dtype=logits.dtype, place=logits.place + ) + percentile_to_std_table = paddle.to_tensor(_PERCENTILE_TO_STD_TABLE, dtype=logits.dtype, place=logits.place) + _TRITON_TABLE_CACHE[logits.device] = ( + normal_cdf_to_sigma_table, + percentile_to_std_table, + ) + else: + normal_cdf_to_sigma_table, percentile_to_std_table = tables + + _topk_topp_kernel[(NUM_PROGRAMS,)]( + logits, + buffer, + mask_out, + percentile_to_std_table, + normal_cdf_to_sigma_table, + k_ptr, + p_ptr, + BATCH_SIZE=batch_size, + MASK_VALUE=mask_value, + VOCAB_SIZE=vocab_size, + BLOCK_SIZE=8192, + BLOCK_SIZE_TRUNC=4096, + TOPK_ENABLED=topk_enabled, + TOPP_ENABLED=topp_enabled, + WRITE_MASK=write_mask, + ) + + if return_mask: + return logits, mask_out.astype(paddle.bool) + return logits + + +@triton.jit +def _seeded_gumbel_kernel( + OUT_ptr, + SEEDS_ptr, + stride_out_batch, + VOCAB_SIZE: tl.constexpr, + BLOCK_SIZE: tl.constexpr, +): + """Generate -log(u) with per-row Philox seeds, fully on GPU.""" + pid = tl.program_id(0) + seed = tl.load(SEEDS_ptr + pid) + seed = seed.to(tl.int32) + for start in tl.range(0, VOCAB_SIZE, BLOCK_SIZE): + offsets = start + tl.arange(0, BLOCK_SIZE) + mask = offsets < VOCAB_SIZE + u = tl.rand(seed, offsets) + u = tl.maximum(u, 1e-10) + q = -tl.log(u) + tl.store(OUT_ptr + pid * stride_out_batch + offsets, q, mask=mask) + + +def seeded_gumbel_noise(probs: paddle.Tensor, seeds: paddle.Tensor) -> paddle.Tensor: + """ + Generate Gumbel noise q = -log(u) with per-row Philox seeds on GPU. + + Args: + probs: [batch_size, vocab_size] — used only for shape/dtype. + seeds: [batch_size] int64 per-request seeds (GPU). + + Returns: + q: [batch_size, vocab_size] float tensor of Gumbel noise. + """ + batch_size, vocab_size = probs.shape + q = paddle.empty_like(probs) + BLOCK_SIZE = min(triton.next_power_of_2(vocab_size), 4096) + _seeded_gumbel_kernel[(batch_size,)]( + q, + seeds, + q.strides[0], + VOCAB_SIZE=vocab_size, + BLOCK_SIZE=BLOCK_SIZE, + ) + return q + + +def reset_buffer_cache(): + _TRITON_BUFFER_CACHE.clear() + _TRITON_TABLE_CACHE.clear() + paddle.accelerator.empty_cache() diff --git a/fastdeploy/model_executor/layers/sample/sampler.py b/fastdeploy/model_executor/layers/sample/sampler.py index acf94b2fe2b..7f5cf9c9b32 100644 --- a/fastdeploy/model_executor/layers/sample/sampler.py +++ b/fastdeploy/model_executor/layers/sample/sampler.py @@ -26,7 +26,7 @@ from paddleformers.utils.log import logger from fastdeploy.config import FDConfig -from fastdeploy.envs import FD_FILL_BITMASK_BATCH +from fastdeploy.envs import FD_FILL_BITMASK_BATCH, FD_SAMPLING_CLASS from fastdeploy.model_executor.guided_decoding import LogitsProcessorBase from fastdeploy.model_executor.layers.sample.early_stopper import ( get_early_stopper_cls_from_stragegy, @@ -41,11 +41,80 @@ speculate_insert_first_token, top_k_top_p_sampling, ) +from fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton import ( + apply_top_k_top_p_triton, + seeded_gumbel_noise, +) from fastdeploy.platforms import current_platform from fastdeploy.reasoning import ReasoningParser from fastdeploy.worker.output import LogprobsTensors, SamplerOutput +def _apply_triton_top_k_top_p( + logits: paddle.Tensor, + top_p: paddle.Tensor, + top_k: Optional[paddle.Tensor] = None, + top_k_list: Optional[list] = None, + return_mask: bool = False, +) -> paddle.Tensor | tuple[paddle.Tensor, paddle.Tensor]: + """ + Apply combined top-k/top-p masking on logits using the Triton kernel. + Masked positions are set to -inf in-place. Call this BEFORE softmax. + + Args: + return_mask: If True, return (logits, mask) where mask is a bool + tensor [B, V] computed inside the Triton kernel (zero extra cost). + + Returns: + logits if return_mask is False, else (logits, mask). + """ + batch_size = logits.shape[0] + + top_p = top_p[:batch_size].squeeze(axis=-1) + + has_top_k = top_k is not None and top_k_list and any(x > 0 for x in top_k_list) + if has_top_k: + top_k = top_k[:batch_size].squeeze(axis=-1) + else: + top_k = None + + return apply_top_k_top_p_triton(logits.astype("float32"), k=top_k, p=top_p, return_mask=return_mask) + + +def _random_sample( + probs: paddle.Tensor, + topp_seed: Optional[paddle.Tensor] = None, +) -> paddle.Tensor: + """ + Sample from probabilities using the Gumbel-max trick. + + Equivalent to multinomial sampling but avoids CPU-GPU synchronization. + When ``topp_seed`` is provided and Triton is available, a Triton kernel + generates per-row deterministic Gumbel noise using Philox PRNG entirely + on GPU, eliminating the Python for-loop and CPU-GPU sync overhead. + + Args: + probs: [batch_size, vocab_size] float32 probabilities. + topp_seed: [batch_size, 1] int64 per-request seeds, or None. + + Returns: + Token ids of shape [batch_size, 1]. + + Reference: vllm/v1/sample/ops/topk_topp_sampler.py::random_sample + """ + # Sample from Exp(1): q = -log(u), u ~ Uniform(0, 1) + if topp_seed is not None: + seeds = topp_seed[: probs.shape[0]].reshape([-1]) + if not seeds.place.is_gpu_place(): + seeds = seeds.cuda() + q = seeded_gumbel_noise(probs, seeds) + else: + u = paddle.uniform(probs.shape, dtype=probs.dtype, min=0.0, max=1.0) + q = -paddle.log(u.clip(min=1e-10)) + # Gumbel-max: argmax(probs / q) is equivalent to multinomial(probs) + return (probs / q).argmax(axis=-1).reshape([-1, 1]) + + def top_p_normalize_probs_paddle( probs: paddle.Tensor, top_ps: paddle.Tensor, @@ -94,6 +163,53 @@ def padding_sampling_params(top_p, top_k, infer_seed, seq_lens_this_time, seq_le return top_p_padding, top_k_padding, topp_seed +@paddle.no_grad() +def _compute_sampling_mask_from_probs( + probs: paddle.Tensor, + mask: paddle.Tensor, +) -> tuple[List[np.ndarray], None]: + """ + Compute sampling mask using a pre-computed boolean mask from the Triton + kernel (derived from logits space, numerically exact). + + Uses topk instead of nonzero to avoid dynamic allocation and multi-pass + scanning. For typical top-k/top-p masks (~50 True per row out of 151 K + vocab), this is ~4x faster than the nonzero + bincount approach. + + Args: + probs: [B, V] softmax probabilities (GPU). + Only used to determine real_bsz for API compatibility. + mask: [B, V] bool tensor — True = retained by top-k/top-p (GPU). + Obtained from logits != -inf BEFORE softmax. + + Returns: + - sparse_indices: List[np.ndarray], retained vocab indices per request. + - logz_per_batch: None + """ + real_bsz = probs.shape[0] + mask = mask[:real_bsz] + + # 单次 reduction:在 bool 上 sum(自动提升 int64,不会溢出; + # 不能用 float16 sum,V>65504 时会溢出为 inf) + counts = mask.sum(axis=-1).astype("int32") # [B] + + # float16: topk 不支持 bool,float16 比 int32 省一半显存且同样正确 (值仅 0/1) + mask_fp16 = mask.astype("float16") + max_k = int(counts.max().item()) # 1 个标量 D2H + + # topk 对 0/1 张量:所有 1 排在前面,输出固定形状 [B, max_k],无动态分配 + _, indices = paddle.topk(mask_fp16, k=max_k, axis=-1) # [B, max_k] + + # 批量 D2H + indices_cpu = indices.numpy() # [B, max_k] + counts_cpu = counts.numpy() # [B] + + # 按每行实际 count 截取 + sparse_indices = [indices_cpu[i, : counts_cpu[i]] for i in range(real_bsz)] + + return sparse_indices, None + + def _compute_sampling_mask( probs: paddle.Tensor, top_p: paddle.Tensor, @@ -164,9 +280,28 @@ def _compute_sampling_mask( # ------------------------------------------------------------------ # Stage 3: top-p mask on already-sorted renormed probs (no re-sort). + # + # The sampling mask must be a *superset* of the tokens that the + # actual sampling kernel (Paddle's top_p_sampling) might choose. + # Two sources of divergence make an exact match impossible: + # + # (a) Cumsum precision: the kernel uses CUB BlockScan (parallel + # prefix-sum) while we use paddle.cumsum (sequential). + # Float rounding differs by ~1e-7, shifting the boundary. + # Fix: use <= instead of < to include 1 extra token. + # + # (b) Sort tie-breaking: when multiple tokens share the same + # probability, CUB RadixSort and paddle.argsort (Thrust) + # order them differently. This reshuffles which equal-prob + # tokens sit at the boundary, altering the cumsum path. + # Fix: after computing the mask, extend it to include ALL + # tokens with prob >= the boundary token's probability. + # + # Since the mask is only used for *reporting* (not for actual + # sampling), the slight over-inclusion is harmless. # ------------------------------------------------------------------ cum_probs = paddle.cumsum(renorm_sorted_probs, axis=-1) # [B, V] - topp_mask = (cum_probs - renorm_sorted_probs) < top_p # [B, V] + topp_mask = (cum_probs - renorm_sorted_probs) <= top_p # [B, V] # When top_p[i] >= 1.0, keep the entire row. topp_mask = paddle.where( (top_p >= 1.0).expand_as(topp_mask), @@ -174,6 +309,20 @@ def _compute_sampling_mask( topp_mask, ) + # Extend mask to cover sort tie-breaking: include all tokens whose + # probability >= the boundary token's probability (last retained + # in sorted order). In descending-sorted probs this just extends + # the contiguous True block by the run of equal-prob tokens. + k_per_row = topp_mask.astype("int32").sum(axis=-1, keepdim=True) # [B,1] + # boundary_idx = last True position (k-1), clamp for safety + boundary_idx = (k_per_row - 1).clip(min=0) # [B, 1] + boundary_prob = paddle.take_along_axis( + renorm_sorted_probs, + boundary_idx, + axis=-1, + ) # [B, 1] + topp_mask = topp_mask | (renorm_sorted_probs >= boundary_prob) + # ------------------------------------------------------------------ # Stage 4: intersect on GPU, then minimal D2H. # ------------------------------------------------------------------ @@ -628,6 +777,17 @@ def forward_cuda( elif self.logprobs_mode == "processed_logits": raw_logprobs = logits.clone() + # Triton path: mask logits in-place BEFORE softmax (no probs→log round-trip). + triton_mask = None + if FD_SAMPLING_CLASS.lower() == "triton": + logits, triton_mask = _apply_triton_top_k_top_p( + logits, + sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + top_k_list=sampling_metadata.top_k_list, + return_mask=True, + ) + probs = F.softmax(logits) probs = min_p_sampling(probs, sampling_metadata.min_p, sampling_metadata.min_p_list) @@ -637,21 +797,29 @@ def forward_cuda( sampling_mask = None logz_per_batch = None if sampling_metadata.keep_sampling_mask: - sampling_mask, logz_per_batch = _compute_sampling_mask( + if FD_SAMPLING_CLASS.lower() == "triton": + # Triton path: use pre-computed mask from logits space (exact). + sampling_mask, logz_per_batch = _compute_sampling_mask_from_probs(probs, triton_mask) + else: + sampling_mask, logz_per_batch = _compute_sampling_mask( + probs, + sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + top_k_list=sampling_metadata.top_k_list, + ) + + if FD_SAMPLING_CLASS.lower() == "triton": + # top-k/top-p already applied on logits; directly sample. + next_tokens = _random_sample(probs, topp_seed=sampling_metadata.seed) + else: + _, next_tokens = top_k_top_p_sampling( probs, sampling_metadata.top_p, - top_k=sampling_metadata.top_k, - top_k_list=sampling_metadata.top_k_list, + sampling_metadata.top_k, + sampling_metadata.top_k_list, + topp_seed=sampling_metadata.seed, ) - _, next_tokens = top_k_top_p_sampling( - probs, - sampling_metadata.top_p, - sampling_metadata.top_k, - sampling_metadata.top_k_list, - topp_seed=sampling_metadata.seed, - ) - logprobs_tensors = ( None if num_logprobs is None else self.gather_logprobs(raw_logprobs, num_logprobs, token_ids=next_tokens) ) @@ -878,8 +1046,6 @@ def forward_cuda( max_model_len, ) - probs = F.softmax(logits) - top_p, top_k, topp_seed = padding_sampling_params( sampling_metadata.top_p, sampling_metadata.top_k, @@ -887,9 +1053,23 @@ def forward_cuda( share_inputs["seq_lens_this_time"], share_inputs["seq_lens_encoder"], ) - _, sampled_token_ids = top_k_top_p_sampling( - probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed - ) + + if FD_SAMPLING_CLASS.lower() == "triton": + logits = _apply_triton_top_k_top_p( + logits, + top_p, + top_k=top_k, + top_k_list=sampling_metadata.top_k_list, + ) + + probs = F.softmax(logits) + + if FD_SAMPLING_CLASS.lower() == "triton": + sampled_token_ids = _random_sample(probs, topp_seed=topp_seed) + else: + _, sampled_token_ids = top_k_top_p_sampling( + probs, top_p=top_p, top_k=top_k, top_k_list=sampling_metadata.top_k_list, topp_seed=topp_seed + ) verify_scores, verify_tokens, actual_candidate_len = top_p_candidates( probs, @@ -964,7 +1144,30 @@ def forward_cuda( share_inputs["accept_num"], ) if keep_sampling_mask: - # Derive target probs from already-extracted target_logits; avoids a second kernel call. + # Expand top_p/top_k from [batch, 1] to [total_accepted, 1]. + accept_top_p = ( + sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + ) + accept_top_k = None + if ( + sampling_metadata.top_k is not None + and sampling_metadata.top_k_list + and any(x > 0 for x in sampling_metadata.top_k_list) + ): + accept_top_k = ( + sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + ) + + # Triton path: mask target_logits and get mask from kernel directly. + spec_triton_mask = None + if FD_SAMPLING_CLASS.lower() == "triton": + target_logits, spec_triton_mask = _apply_triton_top_k_top_p( + target_logits, + accept_top_p, + top_k=accept_top_k, + top_k_list=sampling_metadata.top_k_list, + return_mask=True, + ) target_probs = F.softmax(target_logits, axis=-1) raw_logprobs = None @@ -987,23 +1190,16 @@ def forward_cuda( sampling_mask = None logz_per_batch = None if keep_sampling_mask: - # Expand top_p from [batch, 1] to [total_accepted, 1]. - accept_top_p = sampling_metadata.top_p[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) - accept_top_k = None - if ( - sampling_metadata.top_k is not None - and sampling_metadata.top_k_list - and any(x > 0 for x in sampling_metadata.top_k_list) - ): - accept_top_k = ( - sampling_metadata.top_k[:real_bsz].squeeze(1).repeat_interleave(accept_nums).unsqueeze(1) + if FD_SAMPLING_CLASS.lower() == "triton": + # Triton path: use pre-computed mask from kernel (exact). + sampling_mask, logz_per_batch = _compute_sampling_mask_from_probs(target_probs, spec_triton_mask) + else: + sampling_mask, logz_per_batch = _compute_sampling_mask( + target_probs, + accept_top_p, + top_k=accept_top_k, + top_k_list=sampling_metadata.top_k_list, ) - sampling_mask, logz_per_batch = _compute_sampling_mask( - target_probs, - accept_top_p, - top_k=accept_top_k, - top_k_list=sampling_metadata.top_k_list, - ) sampler_output = SamplerOutput( sampled_token_ids=share_inputs["accept_tokens"], @@ -1316,11 +1512,22 @@ def forward_xpu( share_inputs["output_cum_offsets"], max_model_len, ) + if FD_SAMPLING_CLASS.lower() == "triton": + logits = _apply_triton_top_k_top_p( + logits, + sampling_metadata.top_p, + top_k=sampling_metadata.top_k, + top_k_list=sampling_metadata.top_k_list, + ) + probs = F.softmax(logits) - _, next_tokens = top_k_top_p_sampling( - probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list - ) + if FD_SAMPLING_CLASS.lower() == "triton": + next_tokens = _random_sample(probs, topp_seed=sampling_metadata.seed) + else: + _, next_tokens = top_k_top_p_sampling( + probs, sampling_metadata.top_p, sampling_metadata.top_k, sampling_metadata.top_k_list + ) # TODO(chenhuan09): add support for logprobs token_ids = None logprobs_tensors = None diff --git a/tests/layers/sample/test_top_k_top_p_triton.py b/tests/layers/sample/test_top_k_top_p_triton.py new file mode 100644 index 00000000000..ff42690132c --- /dev/null +++ b/tests/layers/sample/test_top_k_top_p_triton.py @@ -0,0 +1,450 @@ +"""Unit tests for apply_top_k_top_p_triton.""" + +import os +import sys + +import numpy as np +import paddle +import pytest + +sys.path.insert(0, os.path.join(os.path.dirname(__file__), "../../..")) + +from fastdeploy.model_executor.layers.sample.ops.top_k_top_p_triton import ( + apply_top_k_top_p_triton, +) + + +@pytest.fixture(autouse=True) +def _use_gpu(): + paddle.set_device("gpu:0") + + +def _make_logits(batch_size: int, vocab_size: int, seed: int = 42) -> paddle.Tensor: + """Create deterministic random float32 logits on GPU.""" + np.random.seed(seed) + return paddle.to_tensor(np.random.randn(batch_size, vocab_size).astype("float32")) + + +# --------------------------------------------------------------------------- +# Reference implementation (CPU / NumPy) for correctness comparison +# --------------------------------------------------------------------------- + + +def _ref_top_k_top_p( + logits_np: np.ndarray, + k: np.ndarray | None, + p: np.ndarray | None, +) -> np.ndarray: + """ + Pure-NumPy reference: top-k first, then top-p on remaining tokens. + Returns masked logits (masked positions set to -inf). + """ + B, V = logits_np.shape + out = logits_np.copy() + + for i in range(B): + row = out[i] + # --- top-k --- + if k is not None: + ki = int(k[i]) + if ki < V: + threshold = np.partition(row, -ki)[-ki] + row[row < threshold] = -np.inf + # Handle duplicates at threshold: keep exactly ki + kept = np.sum(row > -np.inf) + if kept > ki: + at_thresh = np.where(row == threshold)[0] + excess = kept - ki + row[at_thresh[:excess]] = -np.inf + + # --- top-p on surviving tokens --- + if p is not None: + pi = float(p[i]) + alive = row > -np.inf + if alive.sum() > 0 and pi < 1.0: + alive_logits = row[alive] + probs = np.exp(alive_logits - alive_logits.max()) + probs /= probs.sum() + sorted_idx = np.argsort(-probs) + cum = np.cumsum(probs[sorted_idx]) + # Keep tokens until cumulative prob >= p, always keep at least 1 + cutoff = np.searchsorted(cum, pi, side="left") + 1 + if cutoff < len(sorted_idx): + remove_local = sorted_idx[cutoff:] + alive_positions = np.where(alive)[0] + row[alive_positions[remove_local]] = -np.inf + + out[i] = row + return out + + +# --------------------------------------------------------------------------- +# Top-K precision tests +# --------------------------------------------------------------------------- + + +class TestTopKPrecision: + def test_exact_top_k_values(self): + """Kept values must exactly equal the original top-k values (bitwise).""" + B, V, K = 4, 1024, 10 + logits = _make_logits(B, V, seed=0) + original = logits.clone() + apply_top_k_top_p_triton(logits, k=paddle.full([B], K, dtype="int32"), p=None) + for i in range(B): + orig_topk = paddle.topk(original[i], K) + kept_vals = logits[i][logits[i] > float("-inf")] + kept_sorted = paddle.sort(kept_vals, descending=True) + orig_sorted = paddle.sort(orig_topk[0], descending=True) + np.testing.assert_array_equal( + kept_sorted.numpy()[:K], + orig_sorted.numpy(), + err_msg=f"row {i}: kept values differ from original top-k", + ) + + def test_exact_top_k_indices(self): + """Kept positions must correspond to the original top-k indices.""" + B, V, K = 4, 512, 8 + logits = _make_logits(B, V, seed=1) + original = logits.clone() + apply_top_k_top_p_triton(logits, k=paddle.full([B], K, dtype="int32"), p=None) + for i in range(B): + orig_topk_idx = set(paddle.topk(original[i], K)[1].numpy().tolist()) + kept_idx = set(np.where(logits[i].numpy() > -np.inf)[0].tolist()) + assert kept_idx.issubset( + orig_topk_idx + ), f"row {i}: kept indices {kept_idx - orig_topk_idx} not in original top-k" + + def test_masked_positions_are_neg_inf(self): + """Masked positions must be exactly -inf, not just a large negative.""" + B, V, K = 2, 256, 5 + logits = _make_logits(B, V, seed=2) + apply_top_k_top_p_triton(logits, k=paddle.full([B], K, dtype="int32"), p=None) + for i in range(B): + row = logits[i].numpy() + non_kept = row[np.isneginf(row)] + assert len(non_kept) >= V - K, f"row {i}: some masked values are not -inf" + + def test_per_row_different_k(self): + """Each row can have a different k value.""" + B, V = 4, 256 + ks = [1, 5, 50, 256] + logits = _make_logits(B, V, seed=3) + k = paddle.to_tensor(ks, dtype="int32") + apply_top_k_top_p_triton(logits, k=k, p=None) + for i in range(B): + non_masked = (logits[i] > float("-inf")).sum().item() + assert non_masked <= ks[i], f"row {i}: expected <= {ks[i]}, got {non_masked}" + assert non_masked > 0 + + def test_k_equals_1(self): + """k=1 should keep only the argmax.""" + B, V = 4, 512 + logits = _make_logits(B, V, seed=4) + original = logits.clone() + apply_top_k_top_p_triton(logits, k=paddle.full([B], 1, dtype="int32"), p=None) + for i in range(B): + kept = logits[i][logits[i] > float("-inf")] + assert kept.shape[0] == 1, f"row {i}: expected 1 kept, got {kept.shape[0]}" + assert kept[0].item() == original[i].max().item() + + def test_k_equals_vocab(self): + """k=vocab_size should be a no-op (all tokens kept).""" + B, V = 2, 128 + logits = _make_logits(B, V, seed=5) + original = logits.clone() + apply_top_k_top_p_triton(logits, k=paddle.full([B], V, dtype="int32"), p=None) + np.testing.assert_array_equal(logits.numpy(), original.numpy()) + + def test_duplicate_logit_values(self): + """When multiple tokens share the same logit at the k-boundary, count is still <= k.""" + B, V, K = 2, 64, 5 + logits = paddle.zeros([B, V], dtype="float32") + # Set top-5 to distinct values, rest share the 5th value + for i in range(B): + logits[i, :K] = paddle.to_tensor([10.0, 9.0, 8.0, 7.0, 6.0], dtype="float32") + logits[i, K:] = 6.0 # duplicates of the boundary value + apply_top_k_top_p_triton(logits, k=paddle.full([B], K, dtype="int32"), p=None) + for i in range(B): + non_masked = (logits[i] > float("-inf")).sum().item() + assert non_masked <= K, f"row {i}: expected <= {K}, got {non_masked}" + + +# --------------------------------------------------------------------------- +# Top-P precision tests +# --------------------------------------------------------------------------- + + +class TestTopPPrecision: + def test_top_p_cumulative_probability(self): + """Kept tokens' probabilities should sum to >= p.""" + B, V = 4, 256 + p_val = 0.9 + logits = _make_logits(B, V, seed=10) + original = logits.clone() + apply_top_k_top_p_triton(logits, k=None, p=paddle.full([B], p_val, dtype="float32")) + for i in range(B): + # Compute original probs via softmax + orig_probs = paddle.nn.functional.softmax(original[i], axis=-1).numpy() + kept_mask = logits[i].numpy() > -np.inf + kept_prob = orig_probs[kept_mask].sum() + assert kept_prob >= p_val - 0.05, f"row {i}: kept prob {kept_prob:.4f} < p={p_val}" + + def test_top_p_minimality(self): + """Removing any one kept token (except the largest) should drop cumsum below p.""" + B, V = 4, 512 + p_val = 0.8 + logits = _make_logits(B, V, seed=11) + original = logits.clone() + apply_top_k_top_p_triton(logits, k=None, p=paddle.full([B], p_val, dtype="float32")) + for i in range(B): + orig_probs = paddle.nn.functional.softmax(original[i], axis=-1).numpy() + kept_mask = logits[i].numpy() > -np.inf + kept_probs = orig_probs[kept_mask] + total = kept_probs.sum() + if len(kept_probs) <= 1: + continue + # The smallest kept token should be necessary (or near-necessary) + smallest_kept = kept_probs.min() + # Removing the smallest should bring total close to or below p + assert total - smallest_kept <= p_val + 0.05, ( + f"row {i}: removing smallest kept still leaves " f"{total - smallest_kept:.4f} > p={p_val}+tolerance" + ) + + def test_top_p_various_values(self): + """Test multiple p values: kept count should increase with p.""" + B, V = 1, 512 + logits_base = _make_logits(B, V, seed=12) + counts = [] + for p_val in [0.1, 0.5, 0.9, 1.0]: + logits = logits_base.clone() + apply_top_k_top_p_triton(logits, k=None, p=paddle.full([B], p_val, dtype="float32")) + cnt = (logits[0] > float("-inf")).sum().item() + counts.append(cnt) + # Monotonically non-decreasing + for j in range(len(counts) - 1): + assert counts[j] <= counts[j + 1], f"kept count not monotonic: p values -> counts {counts}" + # p=1.0 should keep all + assert counts[-1] == V + + def test_top_p_very_small(self): + """Very small p should keep very few tokens (often just 1).""" + B, V = 4, 1024 + logits = _make_logits(B, V, seed=13) + apply_top_k_top_p_triton(logits, k=None, p=paddle.full([B], 0.01, dtype="float32")) + for i in range(B): + non_masked = (logits[i] > float("-inf")).sum().item() + assert non_masked >= 1 + assert non_masked <= 20, f"row {i}: p=0.01 kept {non_masked} tokens" + + +# --------------------------------------------------------------------------- +# Combined top-k + top-p precision tests +# --------------------------------------------------------------------------- + + +class TestCombinedPrecision: + def test_combined_vs_sequential(self): + """Triton combined result should match top-k-then-top-p applied independently.""" + B, V, K = 4, 512, 50 + p_val = 0.9 + logits = _make_logits(B, V, seed=20) + original_np = logits.numpy().copy() + + # --- Reference: top-k first, then top-p --- + ref = _ref_top_k_top_p( + original_np, + k=np.full(B, K, dtype=np.int32), + p=np.full(B, p_val, dtype=np.float32), + ) + + # --- Triton --- + apply_top_k_top_p_triton( + logits, + k=paddle.full([B], K, dtype="int32"), + p=paddle.full([B], p_val, dtype="float32"), + ) + triton_np = logits.numpy() + + for i in range(B): + ref_kept = set(np.where(ref[i] > -np.inf)[0]) + tri_kept = set(np.where(triton_np[i] > -np.inf)[0]) + # Allow small difference due to softmax precision + sym_diff = ref_kept.symmetric_difference(tri_kept) + assert len(sym_diff) <= 3, f"row {i}: ref vs triton kept sets differ by {len(sym_diff)} tokens: {sym_diff}" + + def test_combined_top_p_further_reduces_top_k(self): + """With small p, combined should keep fewer tokens than top-k alone.""" + B, V, K = 4, 256, 50 + logits1 = _make_logits(B, V, seed=21) + logits2 = logits1.clone() + + # top-k only + apply_top_k_top_p_triton(logits1, k=paddle.full([B], K, dtype="int32"), p=None) + # top-k + top-p + apply_top_k_top_p_triton( + logits2, + k=paddle.full([B], K, dtype="int32"), + p=paddle.full([B], 0.5, dtype="float32"), + ) + for i in range(B): + cnt_k = (logits1[i] > float("-inf")).sum().item() + cnt_kp = (logits2[i] > float("-inf")).sum().item() + assert cnt_kp <= cnt_k, f"row {i}: combined ({cnt_kp}) should be <= top-k only ({cnt_k})" + + def test_combined_per_row_mixed_params(self): + """Different k and p per row.""" + B, V = 4, 512 + ks = [5, 20, 100, 512] + ps = [0.3, 0.5, 0.9, 1.0] + logits = _make_logits(B, V, seed=22) + k = paddle.to_tensor(ks, dtype="int32") + p = paddle.to_tensor(ps, dtype="float32") + apply_top_k_top_p_triton(logits, k=k, p=p) + for i in range(B): + non_masked = (logits[i] > float("-inf")).sum().item() + assert non_masked <= ks[i], f"row {i}: expected <= {ks[i]}, got {non_masked}" + assert non_masked >= 1 + + def test_kept_values_unchanged(self): + """Kept (non-masked) logit values must be bitwise identical to original.""" + B, V, K = 4, 256, 20 + logits = _make_logits(B, V, seed=23) + original = logits.clone() + apply_top_k_top_p_triton( + logits, + k=paddle.full([B], K, dtype="int32"), + p=paddle.full([B], 0.8, dtype="float32"), + ) + for i in range(B): + kept_mask = logits[i].numpy() > -np.inf + np.testing.assert_array_equal( + logits[i].numpy()[kept_mask], + original[i].numpy()[kept_mask], + err_msg=f"row {i}: kept logit values were modified", + ) + + +# --------------------------------------------------------------------------- +# Large vocab / batch stress tests +# --------------------------------------------------------------------------- + + +class TestLargeScale: + def test_large_vocab(self): + """Test with a realistic vocab size (32000).""" + B, V, K = 2, 32000, 50 + logits = _make_logits(B, V, seed=30) + original = logits.clone() + apply_top_k_top_p_triton( + logits, + k=paddle.full([B], K, dtype="int32"), + p=paddle.full([B], 0.9, dtype="float32"), + ) + for i in range(B): + non_masked = (logits[i] > float("-inf")).sum().item() + assert 1 <= non_masked <= K + # Verify kept values are unchanged + kept_mask = logits[i].numpy() > -np.inf + np.testing.assert_array_equal( + logits[i].numpy()[kept_mask], + original[i].numpy()[kept_mask], + ) + + def test_large_batch(self): + """Test with a large batch (128 rows).""" + B, V, K = 128, 1024, 10 + logits = _make_logits(B, V, seed=31) + apply_top_k_top_p_triton(logits, k=paddle.full([B], K, dtype="int32"), p=None) + for i in range(B): + non_masked = (logits[i] > float("-inf")).sum().item() + assert non_masked <= K + assert non_masked >= 1 + + +# --------------------------------------------------------------------------- +# Edge cases +# --------------------------------------------------------------------------- + + +class TestEdgeCases: + def test_empty_batch(self): + """Empty batch should return immediately.""" + logits = paddle.empty([0, 128], dtype="float32") + out = apply_top_k_top_p_triton(logits, k=None, p=None) + assert out.shape == [0, 128] + + def test_no_filtering(self): + """Both k and p as None should be a no-op.""" + B, V = 2, 64 + logits = _make_logits(B, V, seed=40) + original = logits.clone() + out = apply_top_k_top_p_triton(logits, k=None, p=None) + np.testing.assert_array_equal(out.numpy(), original.numpy()) + + def test_single_row(self): + """Batch size 1 should work correctly.""" + B, V, K = 1, 256, 3 + logits = _make_logits(B, V, seed=41) + apply_top_k_top_p_triton(logits, k=paddle.full([B], K, dtype="int32"), p=None) + assert (logits[0] > float("-inf")).sum().item() <= K + + def test_inplace_returns_same_tensor(self): + """Return value should be the same tensor object (in-place).""" + B, V = 2, 64 + logits = _make_logits(B, V, seed=42) + out = apply_top_k_top_p_triton(logits, k=paddle.full([B], 5, dtype="int32"), p=None) + assert out.data_ptr() == logits.data_ptr() + + +class TestReturnMask: + """Tests for return_mask=True — mask output from the Triton kernel.""" + + def test_mask_shape_and_dtype(self): + """Mask should be [B, V] bool.""" + B, V, K = 4, 256, 10 + logits = _make_logits(B, V, seed=50) + _, mask = apply_top_k_top_p_triton(logits, k=paddle.full([B], K, dtype="int32"), p=None, return_mask=True) + assert mask.shape == [B, V] + assert mask.dtype == paddle.bool + + def test_mask_matches_logits(self): + """Mask True positions should match logits != -inf exactly.""" + B, V, K = 8, 512, 5 + logits = _make_logits(B, V, seed=51) + out, mask = apply_top_k_top_p_triton(logits, k=paddle.full([B], K, dtype="int32"), p=None, return_mask=True) + expected = out != float("-inf") + np.testing.assert_array_equal(mask.numpy(), expected.numpy()) + + def test_mask_top_p(self): + """Mask with top-p should match logits != -inf.""" + B, V = 4, 256 + logits = _make_logits(B, V, seed=52) + p = paddle.full([B], 0.9, dtype="float32") + out, mask = apply_top_k_top_p_triton(logits, k=None, p=p, return_mask=True) + expected = out != float("-inf") + np.testing.assert_array_equal(mask.numpy(), expected.numpy()) + + def test_mask_combined(self): + """Mask with combined top-k + top-p should match logits != -inf.""" + B, V = 4, 256 + logits = _make_logits(B, V, seed=53) + k = paddle.to_tensor([5, 10, 20, 50], dtype="int32") + p = paddle.to_tensor([0.8, 0.9, 0.95, 1.0], dtype="float32") + out, mask = apply_top_k_top_p_triton(logits, k=k, p=p, return_mask=True) + expected = out != float("-inf") + np.testing.assert_array_equal(mask.numpy(), expected.numpy()) + + def test_mask_no_filtering(self): + """When no filtering is needed, mask should be all True.""" + B, V = 2, 64 + logits = _make_logits(B, V, seed=54) + out, mask = apply_top_k_top_p_triton(logits, k=None, p=None, return_mask=True) + assert mask.all().item() + + def test_mask_count_matches_top_k(self): + """Number of True values per row should be <= k.""" + B, V, K = 4, 512, 8 + logits = _make_logits(B, V, seed=55) + _, mask = apply_top_k_top_p_triton(logits, k=paddle.full([B], K, dtype="int32"), p=None, return_mask=True) + for i in range(B): + assert mask[i].astype("int32").sum().item() <= K