
[Speculative Decoding] Support mtp expert-parallel and support different modality deploy#7018

Open
freeliuzc wants to merge 2 commits into PaddlePaddle:develop from freeliuzc:merge_support_diff_mm_deploy

Conversation


@freeliuzc freeliuzc commented Mar 25, 2026

Motivation

  1. Fix the MTP idle-run issue in gpu_model_runner
  2. Support maximum-performance deployment for different modalities
  3. Support correct timing instrumentation on Decoder nodes

💡 If this PR is a Cherry Pick, the PR title needs to follow the format by adding the [Cherry-Pick] label at the very beginning and appending the original PR ID at the end. For example, [Cherry-Pick][CI] Add check trigger and logic(#5191)


Modifications

Usage or Command

Accuracy Tests

Checklist

  • Add at least one tag in the PR title.
    • Tag list: [[FDConfig],[APIServer],[Engine], [Scheduler], [PD Disaggregation], [Executor], [Graph Optimization], [Speculative Decoding], [RL], [Models], [Quantization], [Loader], [OP], [KVCache], [DataProcessor], [BugFix], [Docs], [CI], [Optimization], [Feature], [Benchmark], [Others], [XPU], [HPU], [GCU], [DCU], [Iluvatar], [Metax]]
    • You can add new tags based on the PR content, but the semantics must be clear.
  • Format your code; run pre-commit before committing.
  • Add unit tests. If no unit tests are added, explain why in this PR.
  • Provide accuracy results.
  • If the current PR is submitting to the release branch, make sure the PR has been submitted to the develop branch, then cherry-pick it to the release branch with the [Cherry-Pick] PR tag.

Copilot AI review requested due to automatic review settings March 25, 2026 12:46
@paddle-bot

paddle-bot bot commented Mar 25, 2026

Thanks for your contribution!

Contributor

Copilot AI left a comment


Pull request overview

This PR adds a "deployment modality" (text/mixed) configuration to the FastDeploy inference pipeline, supporting text-only deployments at minimal resource cost and maximal performance while preserving multimodal model capability. It also fixes a forward issue caused by idle worker runs in the MTP + EP scenario, and fills in the missing timing fields for decode nodes.

Changes:

  • Add a DeployModality enum threaded from the engine CLI through the worker launch arguments into FDConfig, distinguishing text-only from mixed deployments.
  • Skip some of the multimodal attention-mask-offset buffers and logic in the input batch / MTP inference paths according to the deploy modality.
  • Fix MTP + expert-parallel handling of empty batches/empty outputs by executing an empty forward, and add decoder-node metrics timestamps.
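The DeployModality plumbing described above can be sketched as follows. This is a hedged illustration: DeployModality and the deploy_modality field appear in the PR summary, but the FDConfigSketch class and parse_deploy_modality helper are assumptions for demonstration, not the actual FastDeploy code.

```python
from dataclasses import dataclass
from enum import Enum


class DeployModality(Enum):
    """Deployment modality: text-only vs. mixed (text + multimodal)."""
    TEXT = "text"
    MIXED = "mixed"


@dataclass
class FDConfigSketch:
    """Minimal stand-in for the relevant slice of FDConfig."""
    deploy_modality: DeployModality = DeployModality.MIXED


def parse_deploy_modality(raw: str) -> DeployModality:
    # The CLI passes a plain string (e.g. --deploy-modality text);
    # convert it to the enum once, at config-construction time,
    # so downstream code compares enum members, not strings.
    return DeployModality(raw)


cfg = FDConfigSketch(deploy_modality=parse_deploy_modality("text"))
print(cfg.deploy_modality is DeployModality.TEXT)  # True
```

Converting at the config boundary keeps every later check (`cfg.deploy_modality is DeployModality.TEXT`) type-safe and avoids the string-comparison pitfall flagged in the review comment below.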

Reviewed changes

Copilot reviewed 10 out of 10 changed files in this pull request and generated 1 comment.

File — Description
fastdeploy/config.py — Add the DeployModality enum and a new deploy_modality field on FDConfig.
fastdeploy/engine/args_utils.py — Add --deploy-modality to the engine CLI and parse it into DeployModality when creating FDConfig.
fastdeploy/engine/engine.py — Pass --deploy_modality through when launching workers.
fastdeploy/engine/common_engine.py — Pass --deploy_modality through when launching workers (common-engine path).
fastdeploy/worker/worker_process.py — Add --deploy_modality to the worker CLI and write it into FDConfig.deploy_modality.
fastdeploy/worker/input_batch.py — Skip creating/swapping/resetting the attn_mask_offsets* buffers under text-only deployment.
fastdeploy/spec_decode/mtp.py — Introduce a use_attn_mask_offset switch controlling whether MTP uses the attention mask offset.
fastdeploy/worker/gpu_model_runner.py — Execute an empty-input forward for empty outputs/batches under MTP + EP, avoiding the idle-run issue.
fastdeploy/engine/request.py — Add decoder-node timestamp fields and an update method to RequestMetrics.
fastdeploy/engine/sched/resource_manager_v1.py — Update metrics when the D node receives a prefilled request, and set the decoder start timestamp.
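The buffer-skip change in input_batch.py can be sketched as guarding buffer allocation and every later access behind a single flag. The class and attribute names below are illustrative stand-ins based on this summary, not the real input_batch.py API:

```python
class InputBatchSketch:
    """Illustrative sketch: attn-mask-offset buffers exist only for
    mixed (multimodal-capable) deployments, so text-only deployments
    skip both the allocation and every later touch of them."""

    def __init__(self, max_num_seqs: int, use_attn_mask_offset: bool):
        self.use_attn_mask_offset = use_attn_mask_offset
        # Allocate the buffer only when multimodal requests can occur;
        # under text-only deployment it stays None and is never read.
        self.attn_mask_offsets = (
            [0] * max_num_seqs if use_attn_mask_offset else None
        )

    def reset(self) -> None:
        # Every access is guarded by the same flag, so the text-only
        # path never dereferences the unallocated buffer.
        if self.use_attn_mask_offset:
            for i in range(len(self.attn_mask_offsets)):
                self.attn_mask_offsets[i] = 0


text_only = InputBatchSketch(8, use_attn_mask_offset=False)
text_only.reset()  # safe no-op: the buffer was never allocated
print(text_only.attn_mask_offsets is None)  # True
```

The key invariant is that allocation and access use the same flag; the review comment below is precisely about a site where that flag was computed incorrectly.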

self.num_main_model_layers = self.model_config.num_hidden_layers
self.local_rank = local_rank
self.device_id = device_id
self.use_attn_mask_offset = self.enable_mm and self.fd_config.deploy_modality != "text"

Copilot AI Mar 25, 2026


Here, self.fd_config.deploy_modality is a DeployModality enum object in FDConfig, but comparing it against the string "text" is always True. As a result, use_attn_mask_offset remains enabled under text-only deployment, which then accesses the unallocated attn_mask_offsets* tensors and triggers a runtime error. Suggest comparing against DeployModality.TEXT (or DeployModality.TEXT.value) to stay consistent with the check in worker/input_batch.py.

Suggested change
- self.use_attn_mask_offset = self.enable_mm and self.fd_config.deploy_modality != "text"
+ self.use_attn_mask_offset = (
+     self.enable_mm and self.fd_config.deploy_modality.value != "text"
+ )
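The pitfall above is easy to reproduce in isolation: a plain Enum member never compares equal to its value string, so `!= "text"` is vacuously True. A minimal, self-contained demonstration (DeployModality here is redeclared locally for illustration):

```python
from enum import Enum


class DeployModality(Enum):
    TEXT = "text"
    MIXED = "mixed"


m = DeployModality.TEXT
print(m != "text")               # True  -- enum member never equals a plain string
print(m.value != "text")         # False -- compare the underlying value instead
print(m != DeployModality.TEXT)  # False -- or, better, compare enum to enum
```

Comparing `.value` works, but the enum-to-enum comparison (`m != DeployModality.TEXT`) is the more robust fix, since it survives a later rename of the value string.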

@codecov-commenter

Codecov Report

❌ Patch coverage is 64.44444% with 16 lines in your changes missing coverage. Please review.
⚠️ Please upload report for BASE (develop@7a6c287). Learn more about missing BASE report.

Files with missing lines Patch % Lines
fastdeploy/worker/gpu_model_runner.py 25.00% 6 Missing ⚠️
fastdeploy/worker/input_batch.py 40.00% 4 Missing and 2 partials ⚠️
fastdeploy/config.py 83.33% 2 Missing ⚠️
fastdeploy/spec_decode/mtp.py 60.00% 0 Missing and 2 partials ⚠️
Additional details and impacted files
@@            Coverage Diff             @@
##             develop    #7018   +/-   ##
==========================================
  Coverage           ?   73.60%           
==========================================
  Files              ?      399           
  Lines              ?    56432           
  Branches           ?     8930           
==========================================
  Hits               ?    41537           
  Misses             ?    11933           
  Partials           ?     2962           
Flag Coverage Δ
GPU 73.60% <64.44%> (?)


