add triton backend sampler #7015
Conversation
Thanks for your contribution!
Pull request overview
This PR adds a Triton-based joint top-k/top-p truncation implementation to FastDeploy's sampling flow, switchable in the Sampler via an environment variable, with the goal of reducing extra operator overhead in the sampling stage.
Changes:
- New Triton kernel: applies an in-place top-k→top-p mask to the logits and can optionally return the candidate mask.
- In `sampler.py`, adds an `FD_SAMPLING_CLASS=triton` branch: mask the logits, apply softmax, then sample and produce the sampling_mask.
- Adds the corresponding unit test file, updates the env comment, and adds a flake8 per-file ignore rule for the new file.
Reviewed changes
Copilot reviewed 5 out of 5 changed files in this pull request and generated 12 comments.
| File | Description |
|---|---|
| fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py | Adds the Triton joint top-k/top-p truncation kernel and its cache management |
| fastdeploy/model_executor/layers/sample/sampler.py | Adds the Triton sampling path, mask/logZ computation, and random-sampling implementation |
| tests/layers/sample/test_top_k_top_p_triton.py | Adds correctness and edge-case tests for the Triton top-k/top-p kernel |
| fastdeploy/envs.py | Extends the FD_SAMPLING_CLASS comment to include the triton option |
| .flake8 | Adds a flake8 per-file ignore configuration for the new Triton file |
```python
# Cache lookup table entries on each device.
tables = _TRITON_TABLE_CACHE.get(logits.device)
if tables is None:
    normal_cdf_to_sigma_table = paddle.to_tensor(
        _NORMAL_CDF_TO_SIGMA_TABLE, dtype=logits.dtype, place=logits.place
    )
    percentile_to_std_table = paddle.to_tensor(_PERCENTILE_TO_STD_TABLE, dtype=logits.dtype, place=logits.place)
    _TRITON_TABLE_CACHE[logits.device] = (
```

The lookup-table cache is keyed by `logits.device`, but the tensors are created with `place=logits.place`. Use `logits.place` (or the current device id) as the `_TRITON_TABLE_CACHE` key instead, so differing device representations cannot cause duplicate cache entries or wrong table lookups.

Suggested change:

```python
# Cache lookup table entries on each device/place.
tables = _TRITON_TABLE_CACHE.get(logits.place)
if tables is None:
    normal_cdf_to_sigma_table = paddle.to_tensor(
        _NORMAL_CDF_TO_SIGMA_TABLE, dtype=logits.dtype, place=logits.place
    )
    percentile_to_std_table = paddle.to_tensor(
        _PERCENTILE_TO_STD_TABLE, dtype=logits.dtype, place=logits.place
    )
    _TRITON_TABLE_CACHE[logits.place] = (
```
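The reviewer's point can be sketched with a plain dict keyed by a stable string form of the place. The `get_tables` helper and `FakePlace` class below are hypothetical, for illustration only; they are not part of the PR:

```python
# Hypothetical sketch: key the cache by a stable string form of the place,
# so differing device-object identities cannot cause duplicate entries or
# wrong lookups.
_TABLE_CACHE = {}

def get_tables(place, build_fn):
    """Return cached tables for `place`, building them once per device."""
    key = str(place)  # stable across distinct tensor/place objects
    tables = _TABLE_CACHE.get(key)
    if tables is None:
        tables = build_fn(place)
        _TABLE_CACHE[key] = tables
    return tables

class FakePlace:
    """Stand-in for Paddle's place object (illustration only)."""
    def __init__(self, dev):
        self.dev = dev
    def __str__(self):
        return f"Place({self.dev})"

calls = []
build = lambda p: calls.append(p) or ("sigma_table", "std_table")
t1 = get_tables(FakePlace("gpu:0"), build)  # builds once
t2 = get_tables(FakePlace("gpu:0"), build)  # cache hit despite new object
```

With object-identity keys, the second `FakePlace("gpu:0")` would miss the cache; the string key makes both lookups hit the same entry.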
```python
# Triton path: use pre-computed mask from logits space (exact).
sampling_mask, logz_per_batch = _compute_sampling_mask_from_probs(probs, triton_mask)
```

On the Triton path, `sampling_mask` comes directly from `triton_mask` (top-k/top-p), but `probs` above has already been further pruned by `min_p_sampling()`; as a result, the keep_sampling_mask/logZ ignore the effect of min_p and no longer match the actual sampling candidate set. When min_p is active, either fall back to `_compute_sampling_mask`, or intersect `triton_mask` with the min_p-generated mask on the GPU before computing logZ.

Suggested change:

```python
# Triton path:
# - If min_p is enabled, fall back to _compute_sampling_mask so that
#   sampling_mask/logZ reflect both top-k/top-p and min_p pruning.
# - Otherwise, we can rely on the pre-computed Triton mask from logits space.
if sampling_metadata.min_p is not None or sampling_metadata.min_p_list is not None:
    sampling_mask, logz_per_batch = _compute_sampling_mask(
        probs,
        sampling_metadata.top_p,
        top_k=sampling_metadata.top_k,
        top_k_list=sampling_metadata.top_k_list,
    )
else:
    sampling_mask, logz_per_batch = _compute_sampling_mask_from_probs(probs, triton_mask)
```
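The intersection the comment proposes can be sketched in plain Python; `combine_candidate_masks` and `log_z` are hypothetical helper names, and the lists stand in for per-token GPU masks:

```python
import math

def combine_candidate_masks(topk_topp_mask, min_p_mask):
    """Candidate set = tokens kept by BOTH top-k/top-p and min_p pruning.
    Computing logZ over this intersection keeps it consistent with the
    distribution actually sampled from."""
    return [a and b for a, b in zip(topk_topp_mask, min_p_mask)]

def log_z(probs, mask):
    """Log of the normalizer over the kept candidates."""
    return math.log(sum(p for p, keep in zip(probs, mask) if keep))

probs = [0.5, 0.3, 0.15, 0.05]
topk_topp_mask = [True, True, True, False]  # e.g. top-k = 3
min_p_mask = [True, True, False, False]     # e.g. min_p prunes p < 0.2
mask = combine_candidate_masks(topk_topp_mask, min_p_mask)
```

Using only `topk_topp_mask` here would report logZ over three tokens while sampling actually draws from two.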
```python
# Set sampling class. "base", "base_non_truncated", "air", "rejection" and "triton" can be set currently.
"FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
```

The PR description is still template placeholder content and is missing at least the required information under "Motivation/Modifications/Usage/Accuracy Tests"; this hampers review and regression verification. Please add: why the Triton sampler is being introduced, how to enable it (`FD_SAMPLING_CLASS=triton`), and basic correctness/performance/accuracy results, or the reason they are not provided.
```python
if num_keep < num_duplicate_logit:
    duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n
    duplicate_count = tl.cumsum(duplicate_mask) + num_kept
    duplicate_keep_mask = (duplicate_count <= num_duplicate_logit) & duplicate_mask
```

The duplicate handling here uses `num_duplicate_logit` as the keep cap, but this branch has already computed `num_keep` (the number of duplicates that need to be kept). Continuing to use `num_duplicate_logit` over-retains boundary duplicates and violates the top-k/top-p truncation semantics. Change the `duplicate_keep_mask` threshold to `num_keep`, and make sure the `num_kept` accumulation stays consistent with the preceding logic.

Suggested change:

```python
# Only process duplicates while we still need to keep more of them.
if num_kept < num_keep:
    duplicate_mask = (tl.abs(logits_blk - duplicate_logit) < 1e-9) & mask_n
    duplicate_count = tl.cumsum(duplicate_mask) + num_kept
    # Cap the number of kept duplicates by num_keep (the desired keep count),
    # not by num_duplicate_logit (the total number of duplicates).
    duplicate_keep_mask = (duplicate_count <= num_keep) & duplicate_mask
```
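In pure Python, the intended boundary semantics look like this; `top_k_keep` is a hypothetical sketch of the logic, not the Triton kernel itself:

```python
def top_k_keep(logits, k):
    """Keep exactly k entries: everything strictly above the k-th largest
    value, plus only as many boundary duplicates as still needed (num_keep),
    never all of them (num_duplicate_logit)."""
    threshold = sorted(logits, reverse=True)[k - 1]
    keep = [x > threshold for x in logits]
    num_kept = sum(keep)
    num_keep = k - num_kept  # boundary duplicates we still need
    for i, x in enumerate(logits):
        if x == threshold and num_keep > 0:
            keep[i] = True
            num_keep -= 1
    return keep

# Three duplicates tie at the k-th value, but only one slot remains.
mask = top_k_keep([5.0, 3.0, 3.0, 3.0, 1.0], k=2)
```

Capping by the total duplicate count instead would keep all three ties and return four tokens for k=2.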
```python
else:
    p_ptr = logits  # Dummy pointer (won't be read)

num_sm = paddle.device.cuda.get_device_properties(logits.device.index).multi_processor_count
```

The way `num_sm` is obtained is inconsistent with other Triton code in the repo: this calls `paddle.device.cuda.get_device_properties(logits.device.index)`, whereas existing code typically calls `paddle.device.cuda.get_device_properties()` (relying on the current device); also, a Paddle Tensor does not necessarily have `device.index`. Align with the existing implementation: either read the current device with no argument, or extract the device id from `logits.place`.

Suggested change:

```python
# Use current CUDA device properties, consistent with other Triton code.
num_sm = paddle.device.cuda.get_device_properties().multi_processor_count
```
```
# E402: module level import not at top of file
per-file-ignores =
    __init__.py:F401,F403,E402
    fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py:E241,E121,E131,E266
```

Ignoring indentation/continuation rules like E121/E131 for the whole of top_k_top_p_triton.py in .flake8 will mask real indentation errors, not just table alignment. If this is only for column alignment in the constant tables, prefer local `# noqa: E241` comments on the table lines, or disable only the rules that are actually needed, rather than relaxing indentation checks for the entire file.

Suggested change:

```
    fastdeploy/model_executor/layers/sample/ops/top_k_top_p_triton.py:E241,E266
```
```python
seeds = topp_seed[: probs.shape[0]].reshape([-1]).tolist()
q = paddle.empty_like(probs)
for i, s in enumerate(seeds):
    paddle.seed(int(s))
    q[i] = -paddle.log(paddle.uniform([probs.shape[1]], dtype=probs.dtype, min=0.0, max=1.0).clip(min=1e-10))
```

In the `topp_seed` branch, converting the GPU tensor via `seeds.tolist()` and then looping per row with `paddle.seed` plus a single-row `uniform` introduces a D2H sync and a large number of kernel launches, contradicting the comment's promise of "avoiding CPU-GPU synchronization". Prefer Paddle's existing GPU sampling (e.g., call `paddle.tensor.top_p_sampling` with top_p=1 on the already-masked logits, or another vectorized sampler) to preserve throughput.
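The per-row `q = -log(uniform)` draw is the exponential-race (Gumbel-max style) sampling trick, which a vectorized implementation keeps entirely on-device. A pure-Python sketch of the trick itself, with a hypothetical `exponential_race_sample` helper:

```python
import math
import random

def exponential_race_sample(probs, rng):
    """Draw one index distributed as `probs` via the exponential race:
    argmax_i probs[i] / q_i, with q_i ~ Exp(1), is a sample from probs.
    Zero-probability (masked-out) entries can never win the race."""
    q = [-math.log(max(rng.random(), 1e-10)) for _ in probs]
    scores = [p / qi for p, qi in zip(probs, q)]
    return max(range(len(probs)), key=lambda i: scores[i])

rng = random.Random(0)
probs = [0.0, 0.7, 0.3, 0.0]  # masked-out entries have zero probability
idx = exponential_race_sample(probs, rng)
```

Because the noise draw is independent per element, the whole `q` tensor can be generated in one batched kernel instead of one launch per row.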
```python
# Set sampling class. "base", "base_non_truncated", "air", "rejection" and "triton" can be set currently.
"FD_SAMPLING_CLASS": lambda: os.getenv("FD_SAMPLING_CLASS", "base"),
```

The PR title does not follow the repository's "[CLASS] Title" convention (for example, [Feature] Add triton backend sampler). Please add an explicit prefix from the template's tag list and keep the title's capitalization/verb style consistent.
```python
buf_key = (logits.device, logits.dtype, vocab_size)
buffer = _TRITON_BUFFER_CACHE.get(buf_key)
if buffer is None or buffer.shape[0] < NUM_PROGRAMS:
    size = min(triton.next_power_of_2(NUM_PROGRAMS), num_sm)
    buffer = paddle.empty((size, vocab_size), dtype=logits.dtype)
```

The buffer cache key uses `logits.device`, but this function creates/places tensors with `logits.place`; mixing device and place easily causes cache-hit anomalies or incorrect cross-device reuse. Use `logits.place` (or a stable device id) consistently as `buf_key`, matching the table cache key later on.
```python
def reset_buffer_cache():
    _TRITON_BUFFER_CACHE.clear()
    _TRITON_TABLE_CACHE.clear()
    paddle.accelerator.empty_cache()
```

`reset_buffer_cache()` calls `paddle.accelerator.empty_cache()`, but the rest of the repo generally uses `paddle.device.cuda.empty_cache()` or `paddle.device.empty_cache()`. If `paddle.accelerator` does not exist in the target Paddle version, this raises immediately. Switch to the empty_cache API used globally.

Suggested change:

```python
    paddle.device.empty_cache()
```
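One defensive way to express this concern is a small compatibility helper that resolves the first empty_cache the installed Paddle actually provides; `resolve_empty_cache` and the `_NS` stub below are hypothetical, used only to illustrate the attribute-lookup idea:

```python
def resolve_empty_cache(paddle_mod):
    """Return the first available empty_cache callable, preferring the APIs
    used elsewhere in the repo; None if none exist on this build."""
    for path in ("device.cuda.empty_cache", "device.empty_cache",
                 "accelerator.empty_cache"):
        obj = paddle_mod
        for attr in path.split("."):
            obj = getattr(obj, attr, None)
            if obj is None:
                break
        if callable(obj):
            return obj
    return None

# Stub standing in for a Paddle build that lacks paddle.accelerator:
class _NS:
    pass

paddle_stub = _NS()
paddle_stub.device = _NS()
paddle_stub.device.empty_cache = lambda: "cleared"
fn = resolve_empty_cache(paddle_stub)
```

Resolving once at import time keeps the hot path free of repeated `getattr` chains while still failing soft on older Paddle versions.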
Codecov Report

❌ Patch coverage is

Additional details and impacted files:

```
@@ Coverage Diff @@
##           release/2.4    #7015   +/-   ##
==============================================
  Coverage             ?   55.77%
==============================================
  Files                ?      334
  Lines                ?    43146
  Branches             ?     6581
==============================================
  Hits                 ?    24066
  Misses               ?    17193
  Partials             ?     1887
```

Flags with carried forward coverage won't be shown. View full report in Codecov by Sentry.
Motivation
Modifications
Usage or Command
Accuracy Tests
Checklist
- Add at least a tag of [[FDConfig], [APIServer], [Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]] in the PR title.
- Run `pre-commit` before commit.
- Before submitting a PR to a `release` branch, make sure the PR has been submitted to the `develop` branch, then cherry-pick it to the `release` branch with the `[Cherry-Pick]` PR tag.