PRISM: Parametrically Refactoring Inference for Speculative Sampling Draft Models (MLSys 2026)

一句话总结:把 speculative decoding 的 draft model 按 draft step 拆成多个 processing module,每步只激活一个 module(类似 MoE 在 auto-regressive 步上的条件计算),参数量随 step 叠加但每步激活不变;在 SGLang 上相比高度优化的 baseline 引擎把 decoding throughput 提升 >2.6×,acceptance length 超 EAGLE-2/HASS/EAGLE-3。

问题

Speculative-Decoding 用小 draft model 预测 verify 大 target model 的 token。EAGLE 系列用 target 的 hidden state 作为额外输入提 acceptance rate。近期趋势是 draft model 越做越大(Table 1:从 0.68% 涨到 19.88% of target size),EAGLE-3、Scylla、Meta 的工作都在做 vertical stacking 或 MoE FFN。核心两难:drafter 越大 acceptance rate 越高,但 draft overhead 也越大,每步 forward pass 成本线性涨。

另一个观察:draft step 难度非均匀——越靠后的 step acceptance rate 急剧下降(Figure 1 LLaMA-3-8B)。统一架构让所有 step 共享参数,浪费了做 step-wise specialization 的机会。

核心方法

PRISM 核心思想:在 auto-regressive step 维度做 conditional computing。把 draft model 拆成 M 个 processing module(每个含 fusion layer + transformer layer),用 surjection 把 K 个 draft step 分配到 M 个 module。每步只激活一个 module。

关键架构细节(Figure 3):

  • Prefill:用 module 处理所有已接受 token,输出 KV cache 和最后隐状态
  • Decode step :用 module ,fuse 上一步 token embedding 和 hidden state,处理并输出下一 token。
  • KV cache 跨 module 共享(不同 step 的 module 虽然参数不同,但 KV cache 顺接)。
  • 与 tree-based verification 和 stochastic sampling 完全兼容。

两大好处

  • Adaptive representation complexity:越靠后的 step 经过越深的级联计算(cascade of different parameter sets),匹配任务难度递增。
  • Decoupling capacity from cost:总参数量可以大幅增加(表达能力上升),但每步激活参数保持不变(推理成本恒定)。

训练(两阶段):

  • Warm-up:只用 1 个 module 训,loss = CEL(draft, ground truth) + λ·MSE(draft hidden, target hidden)。
  • Phase 2:复制 warm-up 的 transformer 权重 M 份,去掉 MSE loss,按 step-wise switch module 训,用 HASS 风格的 3-step context alignment 对抗 exposure bias。

训练效率:因为每个 draft step 的前向/反向只过一个 sub-network,相比 monolithic drafter 同参数量但每步穿透整网,训练 cost 更低。

集成 SGLang:大多数 speculative decoding 论文在 PyTorch 评估,高估加速(缺 CUDA graph、continuous batching)。PRISM 集成进 SGLang,配合 CUDA graph、continuous batching、其他优化。

关键结果

  • Target 模型:LLaMA-2-7B 和 LLaMA-3-8B。
  • Benchmarks:MT-bench、HumanEval、GSM8K、Alpaca、CNN/DM、Natural Questions。
  • NVIDIA A800:LLaMA-3-8B 上 PRISM 的 acceptance length 和 TPS 全面超 Vanilla、Standard、EAGLE-2、HASS,例如 MT-bench T=0 时 AL=4.29 (HASS 3.93)、TPS=201.51 (HASS 180.50)。
  • 相比 SGLang 原生高度优化的 baseline 引擎,decoding throughput 提升 >2.6×
  • Scaling law 验证:PRISM 随训练数据增长(100K → 800K samples)scale 更好,证明”模型容量扩 + 每步激活不变”这条新 scaling 路线有效。

相关