跳到主要内容

一个 IR 统治一切

你正在调试一个慢模型。Profiler 说"kernel X 花了 200ms",但你完全不知道 kernel X 到底做了什么。你翻遍 PyTorch 的 dispatcher,然后是 ATen,然后是 TorchInductor,然后是 Triton IR,最后到达 LLVM IR。五种不同的表示,五种不同的心智模型,五种不同的调试工具。

这就是现代 ML 编译的现实。TensorFlow 的 XLA 也类似:Python → Graph → XLA HLO → MLIR → LLVM IR。每一层都是为了解决真实问题而添加的,但累积的复杂度令人咋舌。

Morok 采用了不同的方案,借鉴自 Tinygrad从张量到机器码,只用一个 IR

┌──────────────────┐ ┌─────────────────┐ ┌───────────────┐
│ TensorFlow │ │ PyTorch │ │ Morok │
├──────────────────┤ ├─────────────────┤ ├───────────────┤
│ Python API │ │ Python API │ │ Rust/Python │
│ TF Graph │ │ FX Graph │ │ ↓ │
│ XLA HLO │ │ Inductor IR │ │ UOp IR │
│ MLIR dialects │ │ Triton IR │ │ ↓ │
│ LLVM IR │ │ LLVM/PTX │ │ Machine code │
│ Machine code │ │ Machine code │ │ │
├──────────────────┤ ├─────────────────┤ ├───────────────┤
│ 5 IRs │ │ 4 IRs │ │ 1 IR │
└──────────────────┘ └─────────────────┘ └───────────────┘

最简单的架构往往胜出。本章解释一个精心设计的 IR 如何替代整个编译器栈。


UOp:通用节点

UOp(微操作)是计算图中的节点。但与其他 IR 中的节点不同,UOp 能表示任何抽象层级的操作——从高层张量 reshape 到底层 CPU 指令。

核心洞察是:与其为"张量操作"、"循环结构"和"内存访问"维护各自独立的 IR,不如把它们放进同一个 enum:

pub enum Op {
// High-level tensor operations
Reshape { src: Arc<UOp>, new_shape: Arc<UOp> },
Permute { src: Arc<UOp>, axes: Vec<usize> },
ReduceAxis { src: Arc<UOp>, reduce_op: ReduceOp, axes: Vec<usize> },

// Loop-level control flow
Range { end: Arc<UOp>, axis_id: AxisId, axis_type: AxisType },
End { computation: Arc<UOp>, ranges: SmallVec<[Arc<UOp>; 4]> },

// Memory operations
Load { buffer: Arc<UOp>, index: Arc<UOp> },
Store { buffer: Arc<UOp>, index: Arc<UOp>, value: Arc<UOp>, ... },

// ALU operations (same as hardware)
Binary(BinaryOp, Arc<UOp>, Arc<UOp>), // Add, Mul, etc.
Unary(UnaryOp, Arc<UOp>), // Sqrt, Exp, etc.
}

这个 enum 包含约 80 个变体,按抽象层级组织:

类别示例表示什么
变换RESHAPE, PERMUTE, EXPAND, PAD张量形状变换
规约REDUCE_AXIS, REDUCE数学聚合运算
控制RANGE, END, IF, BARRIER循环和分支结构
内存LOAD, STORE, INDEX, BUFFER硬件内存访问
ALUADD, MUL, SQRT, EXP, WHERECPU/GPU 指令
高级WMMA, CONTRACT, UNROLLTensor core、向量化

打印 UOp 图时,你会看到它的树形结构:

[42] STORE : Void
├── [10] DEFINE_GLOBAL(0) : Ptr<Float32>
├── [35] INDEX : Ptr<Float32>
│ ├── [10] → (same as above)
│ └── [30] RANGE(axis=0, Reduce) : Index
│ └── [5] CONST(4) : Index
└── [40] REDUCE(Add) : Float32
├── [38] MUL : Float32
│ ├── [36] LOAD : Float32
│ └── [37] LOAD : Float32
└── [30] → (same RANGE as above)

注意到指向"same as above"的箭头了吗?这不仅仅是打印格式——它是一个叫做 hash consing 的基本属性。


Hash Consing:结构共享

在 Morok 中创建同一个表达式两次,你会得到同一个指针。不是值相等——而是同一个内存地址。

let a = UOp::binary(Add, x.clone(), y.clone());
let b = UOp::binary(Add, x.clone(), y.clone());

assert!(Arc::ptr_eq(&a, &b)); // Same pointer!

这通过全局缓存实现。构造 UOp 时先检查是否已有相同的:

pub fn new(op: Op, dtype: DType) -> Arc<Self> {
let key = UOpKey::new(&op, dtype);

// Check cache first
if let Some(existing) = CACHE.get(&key) {
return existing;
}

// Create new and cache it
let uop = Arc::new(UOp { op, dtype, ... });
CACHE.insert(key, uop.clone());
uop
}

这对 ML 工程师意味着什么?

  • 指针相等即语义相等。 检查两个子表达式是否相同,只需比较指针:Arc::ptr_eq(&a, &b)。无需遍历整棵树。

  • 模式匹配是 O(1) 的。 当优化器问"之前见过这个模式吗?"时,指针比较立刻给出答案。

  • 内存高效。 公共子表达式(例如 attention 和梯度图中的共享计算)只存储一次,不会被复制。

  • 线程安全。 不同线程中的相同计算会产生同一个对象——没有同步 bug。

树形打印展示了这一点:当你看到 [10] → (same as above) 时,那不是拷贝——而是从多处引用的同一个节点


显式循环:RANGE 操作

大多数 ML IR 将循环隐藏在操作内部。在 ONNX 中,一个规约看起来是这样:

ReduceSum(data, axes=[1], keepdims=0)

循环在哪里?它是隐式的——藏在运行时 ReduceSum 实现的某处。你看不到它,改不了它,也无法推理它。

Morok 用 RANGE 操作使循环显式化。同样的规约变成:

[REDUCE(Add)]
├── [LOAD]
│ └── [INDEX]
│ ├── [BUFFER]
│ ├── [RANGE(axis=0, Global)] # outer loop (parallelized)
│ │ └── [CONST(128)]
│ └── [RANGE(axis=1, Reduce)] # reduction loop
│ └── [CONST(64)]
└── [RANGE(axis=1, Reduce)] # same RANGE via hash consing

每个 RANGE 都有一个 AxisType,告诉代码生成器如何编译它:

AxisTypeCPUCUDA含义
Loopfor 循环for 循环顺序迭代;rangeify 默认
Global线程池blockIdx外层并行维度
Thread线程池CPU 并行
Warp(N/A)warp/wavefront子组并行
Local(N/A)threadIdx工作组并行
GroupReduce(N/A)共享内存两阶段规约
UpcastSIMD 向量寄存器 tile向量化
Reduce累加器Warp reduce规约维度
Unroll展开展开循环展开

AxisType 层次结构(Loop → Global/Thread → Warp → Local/GroupReduce → Upcast → Reduce → Unroll)映射到硬件执行模型——外层循环的优先级数值更小。AxisType::GlobalRANGE 在 CUDA 中变成 blockIdx.xAxisType::LocalRANGE 变成 threadIdx.x

为什么显式循环重要:

  • 优化是可见的。 你能看到哪些循环会被并行化、哪些会被展开、哪些会使用 SIMD。

  • 调度就是图重写。 改变循环顺序、分块或展开只是模式变换——不需要专门的"调度 pass"。

  • 每个阶段都是同一个 IR。 在张量层面代表"遍历 batch 维度"的 RANGE,就是在生成代码中变成 for (int i = 0; i < N; i++)同一个 RANGE


图重写:统一的变换机制

传统编译器有几十个专用 pass:常量折叠、死代码消除、循环展开、算子融合。每个 pass 都有自己的逻辑、数据结构和 bug。

Morok 只用一种机制:基于模式的图重写

patterns! {
// Identity folding: x + 0 → x
Add[x, @zero] ~> x,

// Constant folding: 3 + 4 → 7
Add(a @const(a_val), b @const(b_val))
=> eval_add(a_val, b_val).map(|r| UOp::const_(a.dtype(), r)),

// Self-folding: x / x → 1
Idiv(x, x) ~> UOp::one(x.dtype()),

// Dead code: if(true) { x } else { y } → x
Where(@true, t, _f) ~> t,
}

这个 DSL 表达力强:

  • [x, y] — 交换律。 尝试两种顺序(用于 ADDMUL 等)
  • (x, y) — 有序。 严格匹配这个顺序。
  • @zero, @one, @true — 语义常量。 对任何 DType 有效。
  • @const(val) — 提取值。 用于编译期计算。
  • x, x — 同一操作数。 检测指针相等。
  • ~> vs => — 不会失败 vs 可能失败的重写。

重写引擎自底向上应用模式直到没有更多匹配:

Original: Add(Mul(x, 1), 0)
After Mul: Add(x, 0) # Mul(x, 1) → x
After Add: x # Add(x, 0) → x

这一种机制处理了:

  • 代数化简 — 常量折叠、恒等消除
  • Rangeify 变换 — 变换操作 → 显式循环
  • Kernel 优化 — 向量化、展开、tensor core
  • 代码生成 — 降级到硬件原语

同样的模式,同样的引擎,不同阶段用不同的模式集。


完整示例:矩阵乘法之旅

追踪 C = A @ B(4×4 矩阵乘法)通过整个流水线。

阶段 1:张量构建

当你写 A.matmul(&B) 时,Morok 构建一个高层 UOp 图:

[REDUCE_AXIS(Add, axes=[2])]
├── [MUL]
│ ├── [EXPAND] # A: [4,4] → [4,4,4]
│ │ └── [BUFFER(A)]
│ └── [EXPAND] # B: [4,4] → [4,4,4]
│ └── [PERMUTE] # transpose for broadcasting
│ └── [BUFFER(B)]

这是纯数学表达:"扩展 A 和 B 以对齐维度,逐元素相乘,沿收缩轴求和。"

阶段 2:Rangeify

Rangeify pass 将变换操作(EXPANDPERMUTE)转换为带有 RANGE 循环的显式索引计算:

[STORE]
├── [DEFINE_GLOBAL(C)]
├── [INDEX]
│ ├── [DEFINE_GLOBAL(C)]
│ ├── [RANGE(i, Global)] # i ∈ [0, 4)
│ │ └── [CONST(4)]
│ └── [RANGE(j, Global)] # j ∈ [0, 4)
│ └── [CONST(4)]
└── [REDUCE(Add)]
├── [MUL]
│ ├── [LOAD(A)]
│ │ └── [INDEX]
│ │ ├── [RANGE(i)] # same i (hash consing)
│ │ └── [RANGE(k, Reduce)]
│ └── [LOAD(B)]
│ └── [INDEX]
│ ├── [RANGE(k)] # same k
│ └── [RANGE(j)] # same j
└── [RANGE(k, Reduce)] # k ∈ [0, 4)
└── [CONST(4)]

现在可以看到循环结构:ijGlobal(并行化),kReduce(累加)。

阶段 3:符号化简

模式重写清理冗余操作,折叠常量,简化索引算术。

阶段 4:代码生成

最终 IR 直接转换为循环:

// GPU kernel (conceptual)
__global__ void matmul(float* C, float* A, float* B) {
int i = blockIdx.x; // from RANGE(i, Global)
int j = blockIdx.y; // from RANGE(j, Global)
float acc = 0.0f;
for (int k = 0; k < 4; k++) { // from RANGE(k, Reduce)
acc += A[i*4 + k] * B[k*4 + j];
}
C[i*4 + j] = acc;
}

关键观察:每个阶段的结构都是可见的。没有神秘的融合 pass 把三个嵌套循环变成面目全非的东西。你在阶段 2 看到的 RANGE 结构,就是阶段 4 中变成循环的那些。


对比:其他 IR 的差异

不同的 IR 做了不同的取舍。以下是对比:

方面ONNXXLA HLOTritonMorok
定位模型交换格式后端优化GPU kernel DSL完整编译
算子~200 高层~100–150 高层Tile 操作~80 多层级
循环模型隐式隐式基于 Tile显式 RANGE
内存纯值纯值 → buffer显式指针显式 LOAD/STORE
优化专用 passMLIR 模式统一重写
目标运行时引擎CPU/GPU/TPU仅 GPUCPU/GPU

ONNX 最大化可移植性。ConvMatMul 等操作隐藏了所有实现细节。很适合模型交换,但看不到的东西没法优化。

XLA HLO 是函数式的、纯的——没有副作用,张量不可变。这使代数优化成为可能,但在代码生成前需要单独的"buffer 分配"阶段。从 HLO 到 LMHLO(基于 buffer)的转换是一个根本性的边界。

Triton 暴露的比 ONNX 多但比 Morok 少。你写"tile 级别"的代码——对数据块的操作——编译器处理线程级细节。内存是显式的(tl.loadtl.store),但 tile 内的并行化是隐式的。

Morok 暴露一切:循环是显式的(RANGE),内存是显式的(LOAD/STORE),并行化是显式的(AxisType)。这意味着需要学更多,但没有东西被隐藏。


为什么这很重要:实际好处

Morok 透明的 IR 对 ML 工程师有实际好处:

调试是直接的。 在任何阶段打印图:

println!("{}", tensor.uop().tree());

你会看到确切存在哪些操作、它们如何连接、计算在哪里发生。没有"kernel X"的谜团。

性能调优有据可循。 查看哪些循环被并行化了:

[RANGE(batch, Global)] # parallelized across GPU blocks
[RANGE(channel, Local)] # parallelized within blocks
[RANGE(pixel, Loop)] # sequential — might be slow!

如果某些东西应该并行但没有,你能看到。

心智模型很简单。 只有一个 IR、一种变换机制、一组操作。你不需要同时学习 XLA HLO MLIR Triton LLVM。只需要 UOp。

优化是可组合的。 想添加自定义重写?加一个模式:

patterns! {
// Your custom optimization
MyPattern(x, y) ~> better_version(x, y),
}

它和常量折叠、融合以及其他所有优化使用同一个引擎。


更深层的洞察

Morok/Tinygrad 证明了编译器的复杂度通常是偶然的,而非本质的。TensorFlow 和 PyTorch 中的多层 IR 栈是有机积累的——每一层都解决了实际问题,但组合起来的系统比任何单个部分都更难理解。

一个设计精良的 IR、一种变换机制和有原则的组合,可以替代数千行专用 pass。这是 Unix 哲学在编译器中的应用:做好一件事,然后组合。

代价是显式性——你会看到其他 IR 隐藏的循环、内存访问和并行化提示。但可见性是特性,不是缺陷。当你的模型运行缓慢时,你想看到原因,而不是寄希望于编译器自己搞定。

这就是 Morok 的赌注:透明的复杂度胜过隐藏的复杂度。