add triton backend sampler#7015

Open
yuanlehome wants to merge 2 commits into PaddlePaddle:release/2.4 from yuanlehome:add_triton_backend_sampler

Conversation

@yuanlehome
Collaborator

Motivation

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code and run pre-commit before committing.
  • Add unit tests. If no unit tests are added, please explain the reason in this PR.
  • Provide accuracy results.
  • If the current PR targets a release branch, make sure it has first been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings March 25, 2026 11:46
@paddle-bot

paddle-bot bot commented Mar 25, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR adds a Triton-based joint top-k/top-p truncation implementation to FastDeploy's sampling pipeline, switchable in the Sampler via an environment variable, with the goal of reducing extra operator overhead during the sampling stage.

Changes:

  • Adds a Triton kernel that applies a top-k→top-p in-place mask to the logits and can optionally return the candidate mask.
  • Adds an FD_SAMPLING_CLASS=triton branch in sampler.py: the logits are masked, then softmaxed, followed by sampling and sampling_mask production.
  • Adds a corresponding unit-test file, and updates the env comment and the flake8 per-file ignore rule for the new file.

Reviewed changes

Copilot reviewed 5 out of 5 changed files in this pull request and generated 12 comments.

File Description
fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py — Adds the Triton top-k/top-p joint truncation kernel and its cache management
fastdeploy/model_executor/layers/sample/sampler.py — Adds the Triton sampling path, mask/logZ computation, and random-sampling implementation
tests/layers/sample/test_top_k_top_p_triton.py — Adds correctness and edge-case tests for the Triton top-k/top-p op
fastdeploy/envs.py — Extends the FD_SAMPLING_CLASS comment with the triton option
.flake8 — Adds a flake8 per-file ignore entry for the new Triton file

Comment on lines +844 to +851
# Cache lookup table entries on each device.
tables = _TRITON_TABLE_CACHE.get(logits.device)
if tables is None:
    normal_cdf_to_sigma_table = paddle.to_tensor(
        _NORMAL_CDF_TO_SIGMA_TABLE, dtype=logits.dtype, place=logits.place
    )
    percentile_to_std_table = paddle.to_tensor(_PERCENTILE_TO_STD_TABLE, dtype=logits.dtype, place=logits.place)
    _TRITON_TABLE_CACHE[logits.device] = (

Copilot AI Mar 25, 2026


The lookup-table cache also uses logits.device as its key, while the tensors are created with place=logits.place. Consider using logits.place (or the current device id) as the key of _TRITON_TABLE_CACHE, so that differing device representations cannot cause duplicate cache entries or fetching the wrong table.

Suggested change
# Cache lookup table entries on each device.
tables = _TRITON_TABLE_CACHE.get(logits.device)
if tables is None:
    normal_cdf_to_sigma_table = paddle.to_tensor(
        _NORMAL_CDF_TO_SIGMA_TABLE, dtype=logits.dtype, place=logits.place
    )
    percentile_to_std_table = paddle.to_tensor(_PERCENTILE_TO_STD_TABLE, dtype=logits.dtype, place=logits.place)
    _TRITON_TABLE_CACHE[logits.device] = (
# Cache lookup table entries on each device/place.
tables = _TRITON_TABLE_CACHE.get(logits.place)
if tables is None:
    normal_cdf_to_sigma_table = paddle.to_tensor(
        _NORMAL_CDF_TO_SIGMA_TABLE, dtype=logits.dtype, place=logits.place
    )
    percentile_to_std_table = paddle.to_tensor(
        _PERCENTILE_TO_STD_TABLE, dtype=logits.dtype, place=logits.place
    )
    _TRITON_TABLE_CACHE[logits.place] = (

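The device-vs-place mismatch described above can also be avoided by normalizing the cache key once, up front. A minimal, Paddle-free sketch of that pattern (the `normalize_place` helper and cache names are hypothetical, not from the PR):

```python
# Sketch: key a per-device cache by one canonical string so that two
# different representations of the same device (e.g. a Place object vs.
# a plain device string) cannot create duplicate entries.
_TABLE_CACHE = {}

def normalize_place(place) -> str:
    # Hypothetical: collapse whatever object describes the device into
    # a stable string such as "gpu:0".
    return str(place).lower().replace("place(", "").rstrip(")")

def get_tables(place, build):
    """Return cached tables for `place`, building them once per device."""
    key = normalize_place(place)
    tables = _TABLE_CACHE.get(key)
    if tables is None:
        tables = build()
        _TABLE_CACHE[key] = tables
    return tables
```

With this, `get_tables("Place(gpu:0)", ...)` and `get_tables("gpu:0", ...)` hit the same cache entry, which is exactly the behavior the review comment is asking for.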
Comment on lines +747 to +748
# Triton path: use pre-computed mask from logits space (exact).
sampling_mask, logz_per_batch = _compute_sampling_mask_from_probs(probs, triton_mask)

Copilot AI Mar 25, 2026


In the Triton path, sampling_mask comes directly from triton_mask (top-k/top-p), but probs above has already been further pruned by min_p_sampling(); as a result, keep_sampling_mask/logZ ignore the effect of min_p and disagree with the actual sampling candidate set. Consider falling back to _compute_sampling_mask when min_p is active, or intersecting triton_mask with the min_p mask on the GPU before computing logZ.

Suggested change
# Triton path: use pre-computed mask from logits space (exact).
sampling_mask, logz_per_batch = _compute_sampling_mask_from_probs(probs, triton_mask)
# Triton path:
# - If min_p is enabled, fall back to _compute_sampling_mask so that
#   sampling_mask/logZ reflect both top-k/top-p and min_p pruning.
# - Otherwise, we can rely on the pre-computed Triton mask from logits space.
if sampling_metadata.min_p is not None or sampling_metadata.min_p_list is not None:
    sampling_mask, logz_per_batch = _compute_sampling_mask(
        probs,
        sampling_metadata.top_p,
        top_k=sampling_metadata.top_k,
        top_k_list=sampling_metadata.top_k_list,
    )
else:
    sampling_mask, logz_per_batch = _compute_sampling_mask_from_probs(probs, triton_mask)

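The intersect-then-logZ idea can be illustrated without Paddle. A toy sketch (pure Python; `combined_logz` and all names are hypothetical) that intersects a top-k/top-p mask with a min_p mask before computing the log-normalizer:

```python
import math

def combined_logz(logits, topk_topp_mask, min_p_mask):
    """Intersect the two candidate masks, then compute
    logZ = log(sum(exp(logit) over kept entries))."""
    keep = [a and b for a, b in zip(topk_topp_mask, min_p_mask)]
    kept = [x for x, k in zip(logits, keep) if k]
    # log-sum-exp with max-shift for numerical stability
    m = max(kept)
    logz = m + math.log(sum(math.exp(x - m) for x in kept))
    return keep, logz
```

If logZ were computed from topk_topp_mask alone, entries dropped only by min_p would still inflate the normalizer, which is the inconsistency the comment points out.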
Comment on lines +53 to 54
# Set sampling class. "base", "base_non_truncated", "air", "rejection" and "triton" can be set currently.
"FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),

Copilot AI Mar 25, 2026


The PR description is still the placeholder template and is missing the required information under Motivation/Modifications/Usage/Accuracy Tests, which hampers review and regression verification. Please add: why the Triton sampler is introduced, how to enable it (FD_SAMPLING_CLASS=triton), and basic correctness/performance/accuracy results, or the reason none are provided.

Comment on lines +750 to +753
if num_keep < num_duplicate_logit:
    duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n
    duplicate_count = tl.cumsum(duplicate_mask) + num_kept
    duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask

Copilot AI Mar 25, 2026


The duplicate handling here uses num_duplicate_logit as the keep cap, even though this branch has already computed num_keep (the number of duplicates that must be kept). Using num_duplicate_logit over-keeps boundary duplicates and violates the top-k/top-p truncation semantics. Consider changing the threshold in duplicate_keep_mask to num_keep, and make sure the accumulation of num_kept stays consistent with the preceding logic.

Suggested change
if num_keep < num_duplicate_logit:
    duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n
    duplicate_count = tl.cumsum(duplicate_mask) + num_kept
    duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask
# Only process duplicates while we still need to keep more of them.
if num_kept < num_keep:
    duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n
    duplicate_count = tl.cumsum(duplicate_mask) + num_kept
    # Cap the number of kept duplicates by num_keep (the desired keep count),
    # not by num_duplicate_logit (the total number of duplicates).
    duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask

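The over-keep bug can be reproduced in plain Python. A toy model of the tie-handling (all names hypothetical) showing why the cap must be the number of duplicates still needed, not the total duplicate count:

```python
def topk_with_ties(values, k):
    """Keep exactly k entries; when the k-th value is duplicated,
    keep only as many duplicates as still fit (first-come order)."""
    order = sorted(values, reverse=True)
    threshold = order[k - 1]
    num_above = sum(v > threshold for v in values)
    num_keep = k - num_above          # duplicates of `threshold` we may keep
    kept, dup_kept = [], 0
    for v in values:
        if v > threshold:
            kept.append(v)
        elif v == threshold and dup_kept < num_keep:  # cap by num_keep
            kept.append(v)
            dup_kept += 1
    return kept
```

With values [5, 3, 3, 3, 1] and k=2, capping by num_keep returns exactly two entries; capping by the total duplicate count (three 3s) would return four, which is the violation of top-k semantics flagged above.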
else:
    p_ptr = logits  # Dummy pointer (won't be read)

num_sm = paddle.device.cuda.get_device_properties(logits.device.index).multi_processor_count

Copilot AI Mar 25, 2026


The way num_sm is obtained is inconsistent with the rest of the repository's Triton code: this calls paddle.device.cuda.get_device_properties(logits.device.index), whereas existing code usually calls paddle.device.cuda.get_device_properties() without arguments (relying on the current device), and a Paddle Tensor does not necessarily expose device.index. Consider aligning with the existing implementation: either read the current device with no argument, or extract the device id from logits.place.

Suggested change
num_sm = paddle.device.cuda.get_device_properties(logits.device.index).multi_processor_count
# Use current CUDA device properties, consistent with other Triton code.
num_sm = paddle.device.cuda.get_device_properties().multi_processor_count

# E402: module level import not at top of file
per-file-ignores =
    __init__.py:F401,F403,E402
    fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py:E241,E121,E131,E266

Copilot AI Mar 25, 2026


Ignoring indentation/alignment rules such as E121/E131 for the whole of top_k_top_p_triton.py in .flake8 will mask real indentation errors, not just table alignment. If the goal is only column alignment in the constant tables, prefer local # noqa: E241 comments on those table lines, or disable only the strictly necessary rules, instead of relaxing indentation checks for the entire file.

Suggested change
fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py:E241,E121,E131,E266
fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py:E241,E266

Comment on lines +103 to +107
seeds = topp_seed[: probs.shape[0]].reshape([-1]).tolist()
q = paddle.empty_like(probs)
for i, s in enumerate(seeds):
    paddle.seed(int(s))
    q[i] = -paddle.log(paddle.uniform([probs.shape[1]], dtype=probs.dtype, min=0.0, max=1.0).clip(min=1e-10))

Copilot AI Mar 25, 2026


In the topp_seed branch, converting the GPU tensor via seeds.tolist() and then looping per row with paddle.seed plus a single-row uniform introduces a device-to-host sync and a large number of kernel launches, contradicting the comment's claim of avoiding CPU-GPU synchronization. Where possible, use Paddle's existing GPU sampling (e.g. call paddle.tensor.top_p_sampling with top_p=1 on the already-masked logits, or another vectorized sampler) to preserve throughput.

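The `q = -log(uniform)` construction in the snippet is the exponential-race (Gumbel-max style) trick: with q_i ~ Exp(1), the index argmax(p_i / q_i) is distributed according to the probabilities p, which is what makes a fully vectorized, loop-free implementation possible. A Paddle-free sketch of the underlying sampler (function name hypothetical):

```python
import math
import random

def sample_categorical(probs, rng):
    """Exponential-race sampling: draw q_i = -log(u_i) ~ Exp(1),
    then argmax(p_i / q_i) is a sample from the categorical `probs`."""
    q = [-math.log(max(rng.random(), 1e-10)) for _ in probs]
    ratios = [p / qi for p, qi in zip(probs, q)]
    return max(range(len(probs)), key=ratios.__getitem__)
```

In a GPU implementation, q is one batched uniform/log kernel and the argmax is one reduction, so no per-row reseeding loop is needed.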
Comment on lines +53 to 54
# Set sampling class. "base", "base_non_truncated", "air", "rejection" and "triton" can be set currently.
"FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),

Copilot AI Mar 25, 2026


The PR title does not follow the repository's "[CLASS] Title" convention (e.g. [Feature] Add triton backend sampler). Please add an explicit tag prefix from the template's tag list and keep the capitalization/verb style of the title consistent.

Comment on lines +828 to +832
buf_key = (logits.device, logits.dtype, vocab_size)
buffer = _TRITON_BUFFER_CACHE.get(buf_key)
if buffer is None or buffer.shape[0] < NUM_PROGRAMS:
    size = min(triton.next_power_of_2(NUM_PROGRAMS), num_sm)
    buffer = paddle.empty((size, vocab_size), dtype=logits.dtype)

Copilot AI Mar 25, 2026


The buffer cache key uses logits.device, while this function creates and places tensors with logits.place; mixing device and place can lead to spurious cache hits or misses and to incorrect cross-device reuse. Consider consistently using logits.place (or a stable device id) as buf_key, matching the table cache key used later.

def reset_buffer_cache():
    _TRITON_BUFFER_CACHE.clear()
    _TRITON_TABLE_CACHE.clear()
    paddle.accelerator.empty_cache()

Copilot AI Mar 25, 2026


reset_buffer_cache() calls paddle.accelerator.empty_cache(), but elsewhere the repository consistently uses paddle.device.cuda.empty_cache() or paddle.device.empty_cache(). If paddle.accelerator does not exist in the targeted Paddle version, this call raises immediately. Consider switching to the empty_cache API used everywhere else.

Suggested change
paddle.accelerator.empty_cache()
paddle.device.empty_cache()

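If supporting several Paddle versions is a concern, a defensive lookup over the candidate entry points avoids the hard failure. A sketch without Paddle (the stand-in module object and `call_empty_cache` name are hypothetical; in real code the argument would be the paddle module):

```python
from types import SimpleNamespace

def call_empty_cache(paddle_like):
    """Try the commonly available empty_cache entry points in order,
    so a missing submodule (e.g. no `accelerator`) does not raise."""
    for path in ("device.cuda", "device", "accelerator"):
        mod = paddle_like
        for part in path.split("."):
            mod = getattr(mod, part, None)
            if mod is None:
                break
        fn = getattr(mod, "empty_cache", None) if mod is not None else None
        if callable(fn):
            return fn()
    raise AttributeError("no empty_cache entry point found")
```

That said, simply using the one API the repository already standardizes on, as the comment suggests, is the simpler fix.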
@codecov-commenter

Codecov Report

❌ Patch coverage is 11.24361% with 521 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (release/2.4@4516c58). Learn more about missing BASE report.

Files with missing lines Patch % Lines
...l_executor/layers/sample/ops/top_k_top_p_triton.py 9.92% 481 Missing ⚠️
fastdeploy/model_executor/layers/sample/sampler.py 24.52% 33 Missing and 7 partials ⚠️
Additional details and impacted files
@@              Coverage Diff               @@
##             release/2.4    #7015   +/-   ##
==============================================
  Coverage               ?   55.77%           
==============================================
  Files                  ?      334           
  Lines                  ?    43146           
  Branches               ?     6581           
==============================================
  Hits                   ?    24066           
  Misses                 ?    17193           
  Partials               ?     1887           
Flag Coverage Δ
GPU 55.77% <11.24%> (?)

Flags with carried forward coverage won't be shown. Click here to find out more.

☔ View full report in Codecov by Sentry.
