Mirage: A Multi-Level Superoptimizer for Tensor Programs (OSDI 2025)

一句话总结:schedule-only(TVM/Ansor)与 kernel-only(TASO)优化器无法联合代数+调度+自定义 kernel;Mirage 用 µGraph 统一 kernel/block/thread 三层,抽象表达式剪枝 + LAX 上有限域概率等价验证,自动发现含 FlashAttention 类优化,A100/H100 上 最高 3.3× 于 SOTA,端到端 LLM 0.9–1.9×

问题与动机

DNN tensor program 优化需同时做代数变换、schedule 变换与发明新 kernel(如 FlashAttention 需 700+ 行 Triton)。Halide/TVM/Ansor 固定算法;TASO/PET 固定 kernel 库。跨 GPU 层次(device/shared/register)的联合搜索空间现有方法够不到。

关键观察 / 隐含假设

  • 观察 1:kernel+block 层访存代价主导 thread 层,可 block/thread 混合搜索(block 穷举小图 + thread 规则生成)。
    • 依赖假设:LAX fragment(matmul/conv/div/有限 exp)覆盖主流 DNN 子图。
    • 可能失效场景:非 LAX 算子需切分,可能丢优化机会。
  • 观察 2:抽象表达式剪枝可在保证一定最优性前提下大幅缩搜索空间。
    • 证据强度:中——有理论保证叙述,依赖 rank 等抽象精度。
  • 假设 1:LAX 程序等价性可用有限域随机测试 + PIT 推广,概率误差可任意小。
    • 可能失效场景:浮点语义与有限域差异需部署前额外验证(论文针对 tensor 代数)。

核心方法

µGraph:kernel graph(device mem)→ block graph(shared mem,imap/omap/fmap,for-loop)→ thread graph(register)。

生成:expression-guided 穷举 kernel/block 候选 + abstract expression 剪枝。

验证:probabilistic equivalence(LAX PIT)。

优化:layout、执行序、内存规划后 codegen。

自动发现 FlashAttention/FlashDecoding 及 2.2× 更优变体;RMSNorm+MatMul 融合 1.9× 等。

设计取舍

  • 取舍 1:搜索可达数小时(RMSNorm 表 5),换部署前一次优化;非 interactive compile。
  • 取舍 2:graph-defined kernel 全局↔shared 搬运对轻算子可能亏(nTrans vs TensorRT)。
  • 边界条件:≤5 kernel ops、≤11 block ops(默认);A100/H100。

实验与结果

  • 六微基准(GQA、RMSNorm、LoRA 等):相对最佳 baseline 最高 3.3×
  • GQA:自动 grid 维度满 SM;device memory 访问最多 减少。
  • 端到端 Chameleon/LLaMA-3/LoRA/nGPT:0.9–1.9× latency。
  • 搜索时间:单 LAX 程序最长 ~4 小时(一次性)。

Critical Analysis

论证链条

FlashAttention 手工案例 → µGraph 表达力 → 搜索+验证 → 微基准与 E2E 提升,论证充分。等价性为概率保证,生产需接受极低错误率或加回归测试。

假设压力测试

搜索空间仍随 op 数指数;更大子图可能不可行。与 cuBLAS/cuDNN 黑盒 kernel 的互操作边界需用户切 LAX。H100 FP8/新指令集扩展工作量未讨论。

实验可信度

Baseline 含 TensorRT-LLM、FlashAttention 等强对手;微基准代表 LLM building blocks。E2E 0.9× 个别模型说明并非全胜。

系统性缺陷

编译时间、调试可读性、失败时诊断;论文未讨论 multi-GPU 自动并行。

局限与 Future Work

  • 局限 1:LAX 划分与浮点语义 gap。
  • Future work 1:与 PyTorch 2 compile 路径的深度集成与缓存策略。
  • Future work 2:更大子图(>11 block ops)的近似搜索与验证 tradeoff。

相关