CDLM: Consistency Diffusion Language Models for Faster Sampling (MLSys 2026)

一句话总结:通过 block-wise causal mask fine-tuning + consistency-guided 蒸馏,把 bidirectional Diffusion LM 蒸成 block-causal student,同时解决 KV cache 不兼容和 denoising 步数过多两个瓶颈;在 Dream-7B 和 LLaDA-8B 上实现 3.6-14.5× 端到端 latency 降低,准确率与 baseline 持平。

问题

Diffusion Language Model(DLM)以 parallel generation 作为对 AR LLM 的替代,但开源 DLM 比 AR 还慢,两个根因:

  1. Bidirectional attention 不兼容 KV caching:每步 denoising 都重算所有位置的 attention
  2. Denoising 步数多:高保真度需要与 sequence length 相当的 refinement 步数

现有加速方法分两派:training-free(approximate caching、confidence-threshold 并行采样)和 training-based(fine-tune 成 block-wise causal 架构)。前者精度下降,后者只解决了 caching 没解决 step 数。

核心方法

CDLM 用 fine-tuning 同时拿下两个瓶颈:

1. Block-wise causal attention mask(缓解 KV cache):

  • Teacher:原 DLM,full bidirectional attention
  • Student:block-wise causal mask——attend 到 prompt + 已完成的 block + 当前 decoding block,与 AR 类似可做 block-level KV cache 和 early stopping
  • Block size B=32,generation length

2. Consistency-guided distillation(减少步数): 从 teacher 用 low-confidence remasking 策略(每步 finalize 一个 top-confidence token)离线跑 trajectory 和 hidden state buffer ,学生 fine-tune 三个 loss 组合:

  • Distillation loss:在 state 的 newly unmasked 位置()用 student 预测对齐 teacher logits(通过 lm_head 从 stored hidden 重建),forward KL。给 student multi-token finalization 的主 anchor
  • Consistency loss:在同一 block 内比较 student 在 state 和 block-completion state 上的预测,仅对 still-masked 位置 施加 KL; 是 stop-gradient target(consistency model 风格)。鼓励 student 在 trajectory 上做稳定的多步跳跃
  • DLM loss:保留标准 masked denoising 目标防止能力退化

3. 推理

  • block-wise decoding + block-causal mask + prompt 和已完成 block 的 KV cache
  • 块内 confidence-threshold 并行 finalize(,参考 Fast-dLLM)
  • block 内产生 <endoftext> 就 early stop
  • 不用 inter-block parallelism 等引入 hyperparameter 的 trick

训练成本:Dream-7B 约 8h、LLaDA-8B 约 16h,4× A100 80GB,LoRA attach 到 attention 和 MLP。

关键结果

跨 GSM8K / MATH / HumanEval / MBPP 评估,B=32、、greedy、

CDLM-Dream

  • GSM8K-CoT:latency 11.2× lower(2.1s vs 23.5s),TPS 12.6×,steps 5.8×,score 78.8 vs 79.1
  • HumanEval-Instruct:latency 6.1×,steps 5.2×,score 50.0 vs 48.2
  • MBPP-Instruct:latency 14.5×,TPS 20.9×,score 53.0 vs 51.8

CDLM-LLaDA

  • GSM8K:latency 8.6× (3.3s vs 28.3s),score 73.9 vs 77.1
  • HumanEval:latency 5.9×,steps 7.9×,score 40.2 vs 37.8

综合:refinement steps 减少 3.4-7.9×,latency 降低 3.6-14.5×,且 TPS 超过同等大小的 AR 模型(Qwen2.5-7B、Llama-3.1-8B)。

相关