Morok
⚠️ Pre-alpha software. APIs are unstable and may change without notice. Not recommended for production use. 🚧
Morok is a Rust-based ML compiler inspired by Tinygrad. It features lazy tensor evaluation with UOp-based IR, pattern-driven optimization, and multi-backend code generation.
Highlights
| Feature | Description |
|---|---|
| Declarative Optimization | patterns! DSL for graph rewrites with Z3-verified correctness |
| Lazy Evaluation | Tensors build computation graphs, compiled only at realize() |
| CUDA Support | Unified memory, D2D copy, LRU buffer caching |
| Provenance Tracking | #[track_caller] traces every UOp to source location |
| 80+ IR Operations | Arithmetic, memory, control flow, WMMA tensor cores |
| 20+ Optimizations | Constant folding, tensor cores, vectorization, loop unrolling |
Quick Example
#![allow(unused)]
fn main() {
use morok_tensor::Tensor;
// Build lazy computation graph
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
let b = Tensor::from_slice(&[4.0, 5.0, 6.0], &[3])?;
let c = (a + b).sum();
// Compile and execute
let result = c.realize()?;
}
License
MIT
Hands-On: From Tensors to Models
This chapter teaches Morok through progressive examples. You'll start with basic tensor operations and build up to a working neural network classifier.
What you'll learn:
- Creating and manipulating tensors
- Shape operations (reshape, transpose, broadcast)
- Matrix multiplication
- Building reusable layers
- Composing a complete model
Prerequisites:
- Basic Rust knowledge
- Add morok_tensor to your Cargo.toml
Key pattern: Morok uses lazy evaluation. Operations build a computation graph without executing. Call realize() to compile and run everything at once.
Example 1: Hello Tensor
Let's create tensors, perform operations, and get results.
use morok_tensor::Tensor;
fn main() -> Result<(), Box<dyn std::error::Error>> {
// Create tensors from slices
let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
let b = Tensor::from_slice(&[10.0f32, 20.0, 30.0, 40.0]);
// Lazy operations (no execution yet)
let sum = &a + &b;
let scaled = &sum * &Tensor::from_slice(&[0.1f32]);
// Execute and get results
let result = scaled.realize()?;
let data = result.to_ndarray::<f32>()?;
println!("Result: {:?}", data);
// Output: [1.1, 2.2, 3.3, 4.4]
Ok(())
}
What's happening:
- Tensor::from_slice() creates a tensor from a Rust slice. The f32 suffix tells Rust the element type.
- &a + &b doesn't compute anything yet. It returns a new Tensor that represents the addition. The & borrows the tensors so we can reuse them.
- realize() is where the magic happens. Morok:
  - Analyzes the computation graph
  - Fuses operations where possible
  - Generates optimized code
  - Executes on the target device
- to_ndarray() extracts the result as an ndarray::ArrayD for inspection.
Try this: remove the realize() call (and the lines that read result). The program still builds the graph, but nothing is computed: without realize(), no code is generated or executed.
Example 2: Shape Gymnastics
Neural networks constantly reshape data. Let's master the basics.
#![allow(unused)]
fn main() {
fn shape_example() -> Result<(), Box<dyn std::error::Error>> {
// Create a 1D tensor with 6 elements
let data = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
println!("Original shape: {:?}", data.shape()); // [6]
// Reshape to a 2x3 matrix
let matrix = data.try_reshape(&[2, 3])?;
println!("Matrix shape: {:?}", matrix.shape()); // [2, 3]
// [[1, 2, 3],
// [4, 5, 6]]
// Transpose to 3x2
let transposed = matrix.try_transpose(0, 1)?;
println!("Transposed shape: {:?}", transposed.shape()); // [3, 2]
// [[1, 4],
// [2, 5],
// [3, 6]]
// Broadcasting: add a row vector to every row
// [3, 2] + [1, 2] → [3, 2]
let bias = Tensor::from_slice(&[100.0f32, 200.0])
.try_reshape(&[1, 2])?;
let biased = &transposed + &bias;
let result = biased.realize()?;
println!("{:?}", result.to_ndarray::<f32>()?);
// [[101, 204],
// [102, 205],
// [103, 206]]
Ok(())
}
}
Key operations:
| Operation | What it does |
|---|---|
| try_reshape(&[2, 3]) | Change shape (same total elements) |
| try_reshape(&[-1, 3]) | Infer dimension from total size |
| try_transpose(0, 1) | Swap dimensions 0 and 1 |
| try_squeeze(dim) | Remove dimension of size 1 |
| try_unsqueeze(dim) | Add dimension of size 1 |
Broadcasting rules (same as NumPy/PyTorch):
- Shapes align from the right
- Each dimension must match or be 1
- Dimensions of size 1 are "stretched" to match
[3, 2] + [1, 2] → [3, 2]  ✓ (1 broadcasts to 3)
[3, 2] + [2]    → [3, 2]  ✓ (implicit [1, 2])
[3, 2] + [3]    → error   ✗ (2 ≠ 3)
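The alignment rule above can be written out in a few lines of plain Rust. This is an illustrative sketch of the semantics, not Morok's implementation; `broadcast_shape` is a hypothetical helper.

```rust
// Illustrative only: a plain-Rust sketch of the broadcasting rule above.
// `broadcast_shape` is a hypothetical helper, not part of morok_tensor.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let n = a.len().max(b.len());
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        // Align from the right: missing leading dimensions act as 1.
        let da = if i < n - a.len() { 1 } else { a[i - (n - a.len())] };
        let db = if i < n - b.len() { 1 } else { b[i - (n - b.len())] };
        match (da, db) {
            (x, y) if x == y => out.push(x),
            (1, y) => out.push(y),
            (x, 1) => out.push(x),
            _ => return None, // incompatible dimensions
        }
    }
    Some(out)
}

fn main() {
    assert_eq!(broadcast_shape(&[3, 2], &[1, 2]), Some(vec![3, 2]));
    assert_eq!(broadcast_shape(&[3, 2], &[2]), Some(vec![3, 2]));
    assert_eq!(broadcast_shape(&[3, 2], &[3]), None); // 2 vs 3: error
}
```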
Example 3: Matrix Multiply
Matrix multiplication is the workhorse of neural networks. Every layer uses it.
#![allow(unused)]
fn main() {
fn matmul_example() -> Result<(), Box<dyn std::error::Error>> {
// Input: 4 samples, 3 features each → shape [4, 3]
let input = Tensor::from_slice(&[
1.0f32, 2.0, 3.0, // sample 0
4.0, 5.0, 6.0, // sample 1
7.0, 8.0, 9.0, // sample 2
10.0, 11.0, 12.0, // sample 3
]).try_reshape(&[4, 3])?;
// Weights: 3 inputs → 2 outputs → shape [3, 2]
let weights = Tensor::from_slice(&[
0.1f32, 0.2, // feature 0 → outputs
0.3, 0.4, // feature 1 → outputs
0.5, 0.6, // feature 2 → outputs
]).try_reshape(&[3, 2])?;
// Matrix multiply: [4, 3] @ [3, 2] → [4, 2]
let output = input.dot(&weights)?;
let result = output.realize()?;
println!("Output shape: {:?}", result.shape()); // [4, 2]
println!("{:?}", result.to_ndarray::<f32>()?);
// Each row: weighted sum of that sample's features
Ok(())
}
}
Shape rules for dot():
| Left | Right | Result |
|---|---|---|
| [M, K] | [K, N] | [M, N] |
| [K] | [K, N] | [N] (vector-matrix) |
| [M, K] | [K] | [M] (matrix-vector) |
| [B, M, K] | [B, K, N] | [B, M, N] (batched) |
The inner dimensions must match (the K). Think of it as: "for each row of left, dot product with each column of right."
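That row-by-column rule can be checked against a naive reference implementation, using the same numbers as the example above. This is plain illustrative Rust, not Morok's kernel.

```rust
// A naive reference matmul illustrating the [M, K] @ [K, N] → [M, N] rule.
// Illustrative only; Morok generates fused, vectorized code instead.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut c = vec![0.0f32; m * n];
    for i in 0..m {
        for j in 0..n {
            for p in 0..k {
                // "for each row of left, dot product with each column of right"
                c[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    c
}

fn main() {
    let input = [1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0, 11.0, 12.0]; // [4, 3]
    let weights = [0.1f32, 0.2, 0.3, 0.4, 0.5, 0.6]; // [3, 2]
    let out = matmul(&input, &weights, 4, 3, 2); // [4, 2]
    assert_eq!(out.len(), 8);
    // sample 0: [1, 2, 3] · [0.1, 0.3, 0.5] = 2.2 and [0.2, 0.4, 0.6] = 2.8
    assert!((out[0] - 2.2).abs() < 1e-4);
    assert!((out[1] - 2.8).abs() < 1e-4);
}
```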
Example 4: Building a Linear Layer
A linear layer computes y = x @ W.T + b. Let's build one from scratch.
#![allow(unused)]
fn main() {
use morok_tensor::{Tensor, Error};
struct Linear {
weight: Tensor, // shape: [out_features, in_features]
bias: Tensor, // shape: [out_features]
}
impl Linear {
fn new(in_features: usize, out_features: usize) -> Self {
// Simple initialization (real code would use proper random init)
let weight_data: Vec<f32> = (0..in_features * out_features)
.map(|i| (i as f32 * 0.1).sin() * 0.1)
.collect();
let bias_data = vec![0.0f32; out_features];
Self {
weight: Tensor::from_slice(&weight_data)
.try_reshape(&[out_features as isize, in_features as isize])
.expect("reshape failed"),
bias: Tensor::from_slice(&bias_data),
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor, Error> {
// y = x @ W.T + b
let weight_t = self.weight.try_transpose(0, 1)?;
let out = x.dot(&weight_t)?;
Ok(&out + &self.bias)
}
}
fn linear_example() -> Result<(), Box<dyn std::error::Error>> {
// Create a layer: 4 inputs → 2 outputs
let layer = Linear::new(4, 2);
// Single sample with 4 features
let input = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
// Forward pass
let output = layer.forward(&input)?;
let result = output.realize()?;
println!("Output: {:?}", result.to_ndarray::<f32>()?);
Ok(())
}
}
Why transpose the weights?
PyTorch convention stores weights as [out_features, in_features]. For a layer mapping 4 → 2:
- Weight shape: [2, 4]
- Input shape: [4] or [batch, 4]
- We need: input @ weight.T = [batch, 4] @ [4, 2] = [batch, 2]
This convention makes it easy to read the weight matrix: row i contains all weights feeding into output i.
Example 5: MNIST Classifier
Let's build a complete neural network that could classify handwritten digits.
#![allow(unused)]
fn main() {
/// Two-layer neural network for MNIST
/// Architecture: 784 (28×28 pixels) → 128 (hidden) → 10 (digits)
struct MnistNet {
fc1: Linear,
fc2: Linear,
}
impl MnistNet {
fn new() -> Self {
Self {
fc1: Linear::new(784, 128),
fc2: Linear::new(128, 10),
}
}
fn forward(&self, x: &Tensor) -> Result<Tensor, Error> {
// Layer 1: linear + ReLU activation
let x = self.fc1.forward(x)?;
let x = x.relu()?;
// Layer 2: linear (no activation → raw logits)
self.fc2.forward(&x)
}
fn predict(&self, x: &Tensor) -> Result<Tensor, Error> {
let logits = self.forward(x)?;
// Convert logits to probabilities
logits.softmax(-1)
}
}
fn mnist_example() -> Result<(), Box<dyn std::error::Error>> {
let model = MnistNet::new();
// Simulate a 28×28 grayscale image (flattened to 784)
let fake_image: Vec<f32> = (0..784)
.map(|i| (i as f32) / 784.0)
.collect();
let input = Tensor::from_slice(&fake_image)
.try_reshape(&[1, 784])?; // batch size 1
// Forward pass
let logits = model.forward(&input)?;
let probs = logits.softmax(-1)?;
// Get results
let probs_result = probs.realize()?;
println!("Probabilities: {:?}", probs_result.to_ndarray::<f32>()?);
// Get predicted class
let prediction = logits.argmax(Some(-1))?;
let pred_result = prediction.realize()?;
println!("Predicted digit: {:?}", pred_result.to_ndarray::<i32>()?);
Ok(())
}
}
Key concepts:
- ReLU activation: x.relu() returns max(0, x). It introduces non-linearity; without it, stacking linear layers would collapse into one big linear layer.
- Logits vs probabilities: the raw output of the last layer (logits) can be any real number. softmax() converts logits to probabilities that sum to 1.
- argmax: returns the index of the maximum value, i.e. the predicted class.
- Batch dimension: we use shape [1, 784] for a single image. For 32 images, use [32, 784]. The model handles batches automatically.
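The activation math above can be written out in plain Rust as a reference. These are standard definitions, not Morok's implementations:

```rust
// Plain-Rust reference math for ReLU, softmax, and argmax; not Morok's API.
fn relu(x: f32) -> f32 {
    x.max(0.0) // max(0, x): negative inputs become 0
}

fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the max before exponentiating for numerical stability.
    let m = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - m).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|e| e / sum).collect()
}

fn argmax(xs: &[f32]) -> usize {
    xs.iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    assert_eq!(relu(-1.5), 0.0);
    let probs = softmax(&[1.0, 2.0, 0.5]);
    let total: f32 = probs.iter().sum();
    assert!((total - 1.0).abs() < 1e-6); // probabilities sum to 1
    assert_eq!(argmax(&probs), 1); // largest logit wins
}
```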
Example 6: Under the Hood
Want to see what Morok generates? Here's how to inspect the IR and generated code.
#![allow(unused)]
fn main() {
fn inspect_compilation() -> Result<(), Box<dyn std::error::Error>> {
let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0]);
let c = &a + &b;
// Print the computation graph (before compilation)
println!("=== IR Graph ===");
println!("{}", c.uop().tree());
// Compile and execute
let result = c.realize()?;
// Inspect generated kernels
println!("\n=== Generated Kernels ===");
for (i, kernel) in result.kernels().iter().enumerate() {
println!("Kernel {}: {}", i, kernel.name);
println!("Backend: {}", kernel.backend);
println!("Code:\n{}\n", kernel.code);
}
Ok(())
}
}
What you'll see:
- IR Graph: the UOp tree shows operations like BUFFER, LOAD, ADD, STORE. This is Morok's intermediate representation before optimization.
- Generated Code: the actual LLVM IR or GPU code that runs. Notice how Morok fuses the loads and the add into a single kernel, with no intermediate buffers.
Debugging tip: If something seems slow or wrong, print the IR tree. Look for:
- Unexpected operations (redundant reshapes, extra copies)
- Missing fusion (separate kernels where one would do)
- Shape mismatches (often the root cause of errors)
Summary
You've learned the core patterns for using Morok:
| Task | Code |
|---|---|
| Create tensor | Tensor::from_slice(&[1.0f32, 2.0]) |
| Arithmetic | &a + &b, &a * &b, -&a |
| Reshape | t.try_reshape(&[2, 3])? |
| Transpose | t.try_transpose(0, 1)? |
| Matrix multiply | a.dot(&b)? |
| Activation | t.relu()?, t.softmax(-1)? |
| Execute | t.realize()? |
| Extract data | result.to_ndarray::<f32>()? |
The lazy evaluation pattern:
- Build your computation graph with operations
- Call realize() once at the end
- Morok optimizes and executes everything together
Next steps:
- Op Bestiary: reference for IR operations
- Execution Pipeline: how compilation works
- Optimization System: pattern-based rewrites
From Tensor to Machine Code
In most ML frameworks, computation happens immediately. Write a + b in PyTorch and it runs now: the GPU crunches numbers before you can even inspect the result. This eager execution is simple to understand, but it leaves optimization opportunities on the table. How can a compiler optimize a computation it hasn't seen yet?
Morok takes the opposite approach: lazy evaluation. When you write a.try_add(&b)?, nothing computes. Morok builds a graph describing what to compute, not when. The magic happens when you call realize(): that single method triggers the entire compilation pipeline, from high-level tensor operations down to JIT-compiled machine code.
This chapter traces that journey.
tensor.realize()
        │
        ▼
┌─────────────────────────────────────────────────────────┐
│ LAZY GRAPH                                              │
│ Tensor ops build UOp DAG (no computation yet)           │
└─────────────────────────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────────────────────────┐
│ RANGEIFY                                                │
│ Movement ops → explicit RANGE loops                     │
└─────────────────────────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────────────────────────┐
│ KERNEL SPLITTING                                        │
│ Split at STORE boundaries → multiple KERNELs            │
└─────────────────────────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────────────────────────┐
│ OPTIMIZATION & CODEGEN                                  │
│ Heuristics/beam → LLVM IR → JIT compile                 │
└─────────────────────────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────────────────────────┐
│ EXECUTION                                               │
│ Parallel kernel launch → result buffer                  │
└─────────────────────────────────────────────────────────┘
Each box is a distinct phase. Letβs walk through them.
Lazy Evaluation: Building the Graph
A Tensor in Morok is surprisingly lightweight:
#![allow(unused)]
fn main() {
pub struct Tensor {
entry: Arc<TensorEntry>, // Computation graph
buffer: Option<Arc<Buffer>>, // Materialized data (if any)
}
}
The entry holds a TensorEntry containing the UOp graph: the computation this tensor represents. The buffer is optional: lazy tensors don't have one; only realized tensors do.
Three Ways to Create Tensors
1. Input tensors (buffer allocated immediately):
#![allow(unused)]
fn main() {
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
// `a.buffer` = Some(Arc<Buffer>) with actual data
}
When you create a tensor from data, Morok allocates device memory and copies your bytes. The UOp graph contains a BUFFER node pointing to this allocation.
2. Lazy operations (no buffer, only graph):
#![allow(unused)]
fn main() {
let b = a.try_add(&a)?; // b.buffer = None
let c = b.try_mul(&a)?; // c.buffer = None
}
Arithmetic operations don't compute anything. They build a UOp graph: Binary(Add, a.uop, a.uop). The tensor exists purely as a description of future work.
3. Movement operations (shares the original buffer):
#![allow(unused)]
fn main() {
let d = a.try_reshape(&[1, 3])?; // d.buffer = same as a.buffer
}
Reshape, permute, and similar operations create new views of existing data. The buffer is shared; only the UOp graph changes to describe the new indexing.
The Global Registry
Morok maintains three global maps (lock-free, thread-safe):
| Map | Key → Value | Purpose |
|---|---|---|
| TENSORS | tensor_id → Weak<TensorEntry> | Track all tensors for graph substitution |
| BUFFERS | uop_id → Arc<Buffer> | Find buffers during scheduling |
| UOP_TO_TENSOR | uop_id → tensor_id | Secondary index for lookups |
This registry enables a critical feature: global graph substitution. When an optimization transforms a UOp, all tensors referencing that UOp automatically see the updated version. No stale references, no manual updates.
Hash Consing in Action
Because UOps use hash consing (content-based deduplication), identical computations share memory:
#![allow(unused)]
fn main() {
let x = a.try_add(&b)?;
let y = a.try_add(&b)?;
// x.uop() and y.uop() point to the SAME Arc<UOp>
}
This matters for caching: when we compile kernels, we cache by UOp ID. Hash consing means identical computations automatically hit the cache, even if constructed separately.
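The idea behind hash consing can be sketched with a toy interner. This is illustrative only; Morok's real interner (and its Arc-based UOps) will differ in detail:

```rust
use std::collections::HashMap;
use std::rc::Rc;

// A toy hash-consing table: structurally identical expressions share one
// allocation. Illustrates the idea; Morok's real interner differs.
#[derive(Hash, PartialEq, Eq, Clone)]
enum Expr {
    Const(i64),
    Add(Rc<Expr>, Rc<Expr>),
}

struct Interner {
    table: HashMap<Expr, Rc<Expr>>,
}

impl Interner {
    fn intern(&mut self, e: Expr) -> Rc<Expr> {
        // Same structure → same key → same shared Rc.
        self.table.entry(e.clone()).or_insert_with(|| Rc::new(e)).clone()
    }
}

fn main() {
    let mut interner = Interner { table: HashMap::new() };
    let a = interner.intern(Expr::Const(1));
    let b = interner.intern(Expr::Const(2));
    let x = interner.intern(Expr::Add(a.clone(), b.clone()));
    let y = interner.intern(Expr::Add(a, b));
    // Two separately constructed but identical expressions share memory,
    // so a cache keyed by node identity hits automatically.
    assert!(Rc::ptr_eq(&x, &y));
}
```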
Rangeify: Making Loops Explicit
When you write tensor.reshape([2, 3]).expand([4, 2, 3]).sum(axis=0), those movement operations (reshape, expand) are high-level descriptions. To generate actual loops, we need explicit iteration structure.
Rangeify transforms movement operations into RANGE loops and INDEX arithmetic. The entry point is rangeify() in schedule/src/rangeify/transforms.rs.
The 8-Pass Pipeline
Rangeify isn't a single transformation; it's eight coordinated passes:
| Pass | Purpose |
|---|---|
| 1. Range Assignment | Create RANGE UOps for each tensor dimension |
| 2. Early Rewrites | Remove DETACH, clean up trivial RESHAPE |
| 3. Split Large Reductions | Two-stage reduce for huge arrays (ratio > 32768) |
| 4. Core Rangeify | ReduceAxis → REDUCE, bufferization, movement removal |
| 5. Buffer Folding | Constant propagation through buffer expressions |
| 6. Dead Axis Removal | Filter ranges that don't affect the output |
| 7. Cost-Based Buffer Removal | Inline buffers when profitable (PContig optimization) |
| 8. Reduction Simplification | Lift range-independent code out of reductions |
Each pass uses pattern-based rewriting (see the Pattern-Based Optimization chapter). Patterns fire until no more match, then the next pass begins.
Before and After
Consider this tensor expression:
Before: BUFFER.reshape([2, 3]).expand([4, 2, 3]).sum(axis=0)
After rangeify, movement ops become explicit index computations:
After:
STORE
├── INDEX[RANGE(0..2), RANGE(0..3)]
└── REDUCE(Add)
    ├── LOAD
    │   └── INDEX[RANGE(0..4), RANGE(0..2), RANGE(0..3)]
    └── RANGE(0..4, Reduce)
The EXPAND became a RANGE(0..4) that doesn't affect the buffer index (broadcasting). The RESHAPE became different index arithmetic. The SUM became REDUCE(Add) with the first range marked as Reduce type.
Movement β Index Arithmetic
Each movement operation has a specific transformation:
| Operation | Transformation |
|---|---|
| RESHAPE | Flatten/unflatten index expressions |
| PERMUTE | Reorder dimensions in INDEX |
| EXPAND | Index becomes 0 (or range doesn't affect index) |
| PAD | WHERE(in_bounds, LOAD, pad_value) |
| SHRINK | Offset adjustment in INDEX |
| FLIP | size - 1 - index |
After rangeify, there are no more movement ops, just arithmetic operations on indices.
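The RESHAPE row above ("flatten/unflatten index expressions") is ordinary row-major index arithmetic, which can be sketched directly. This is an illustration of the math, not Morok's code:

```rust
// RESHAPE as pure index arithmetic (illustrative, not Morok's code):
// a multi-dimensional index maps to a flat offset and back.
fn flatten(idx: &[usize], shape: &[usize]) -> usize {
    // Row-major: offset = ((i0 * d1) + i1) * d2 + i2 ...
    idx.iter().zip(shape).fold(0, |acc, (&i, &d)| acc * d + i)
}

fn unflatten(mut flat: usize, shape: &[usize]) -> Vec<usize> {
    let mut idx = vec![0; shape.len()];
    for (i, &d) in shape.iter().enumerate().rev() {
        idx[i] = flat % d;
        flat /= d;
    }
    idx
}

fn main() {
    // Reshaping [6] → [2, 3]: flat element 5 lives at logical index [1, 2].
    assert_eq!(unflatten(5, &[2, 3]), vec![1, 2]);
    assert_eq!(flatten(&[1, 2], &[2, 3]), 5);
}
```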
Kernel Splitting: Finding the Boundaries
A computation graph might have multiple outputs, or intermediate values that need materialization. Kernel splitting identifies these boundaries and creates separate kernels.
The entry point is run_kernel_split_pipeline() in schedule/src/rangeify/kernel.rs.
Two-Phase Transformation
Phase 1: BUFFERIZE → STORE
BUFFERIZE nodes mark where values should materialize. Phase 1 converts them to explicit STORE operations:
Before: BUFFERIZE(computation, ranges)
After: END(STORE(buffer, INDEX(...), computation), ranges)
The END wrapper captures which ranges scope this store. Buffers are allocated and assigned IDs during this phase.
Phase 2: STORE → KERNEL
Each STORE becomes its own kernel:
Before: END(STORE(...), ranges)
After: KERNEL(SINK(STORE(...)), ranges, buffer_list)
The KERNEL node wraps everything: the computation (as a SINK), the iteration ranges, and the list of buffers this kernel reads and writes.
Tracking Dependencies
When one kernel's output feeds another kernel's input, we need dependency tracking:
- fix_assign() maps each buffer_id to the kernel that writes it
- When kernel B reads a buffer written by kernel A, B depends on A
- resolve_kernel_dependencies() builds the dependency graph
Dependencies appear as AFTER nodes in the IR, ensuring kernels execute in valid order.
Buffer Renumbering
Each kernel sees buffers in a specific order (outputs first, then inputs). renumber_define_globals() remaps buffer IDs to match this ordering:
Original: buffer_3, buffer_1, buffer_7
Kernel view: buffer_0 (output), buffer_1, buffer_2 (inputs)
This simplifies code generation: buffer N is always argument N.
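The remapping above is just "outputs first, then inputs, renumbered from zero". A minimal sketch, using the same buffer IDs as the example (`renumber` is a hypothetical helper, not Morok's function):

```rust
// Sketch of per-kernel buffer renumbering: outputs first, then inputs,
// so buffer N is always argument N. Hypothetical helper, not Morok's code.
fn renumber(outputs: &[u64], inputs: &[u64]) -> Vec<(u64, usize)> {
    outputs
        .iter()
        .chain(inputs.iter())
        .enumerate()
        .map(|(new_id, &old_id)| (old_id, new_id))
        .collect()
}

fn main() {
    // Original buffers 3 (output), then 1 and 7 (inputs),
    // become arguments 0, 1, 2 from the kernel's point of view.
    assert_eq!(renumber(&[3], &[1, 7]), vec![(3, 0), (1, 1), (7, 2)]);
}
```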
Schedule Creation: Preparing for Execution
Once kernels are split, we need to schedule them: determine execution order, allocate buffers, and prepare for compilation.
create_schedule() in tensor/src/schedule.rs produces a Vec<ScheduleItem>:
#![allow(unused)]
fn main() {
pub struct ScheduleItem {
pub kernel: Arc<UOp>, // KERNEL wrapper
pub ast: Arc<UOp>, // Inner computation (for codegen)
pub buffers: Vec<Buffer>, // Device buffers
pub dependencies: Vec<u64>, // Producer kernel IDs
pub fixedvars: HashMap<String, i64>, // Bound iteration variables
}
}
Buffer Allocation Strategy
- Input buffers: Already allocated (from Tensor::from_slice)
- Intermediate buffers: Allocated during scheduling (for kernel outputs that feed other kernels)
- Output buffer: Allocated and registered with the final tensor
Parallel Group Analysis
Not all kernels need sequential execution. Independent kernels can run in parallel:
Kernel A (writes buf0)
Kernel B (writes buf1)      ── no dependency ── can run in parallel
Kernel C (reads buf0, buf1) ── depends on A and B
The scheduler uses Kahn's algorithm to find parallel groups:
- Build the kernel dependency DAG
- Find all kernels with no incoming edges → Group 1
- Remove Group 1, repeat → Group 2, etc.
Each group's kernels execute in parallel, then the next group starts.
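The layered grouping described above can be sketched in a few lines. This is an illustration of the algorithm, not the scheduler's actual code:

```rust
// Sketch of the layered Kahn's-algorithm grouping described above.
// deps[k] lists the kernels that must finish before kernel k may run.
fn parallel_groups(deps: &[Vec<usize>]) -> Vec<Vec<usize>> {
    let n = deps.len();
    let mut done = vec![false; n];
    let mut groups = Vec::new();
    loop {
        // A kernel is ready once all of its producers have executed.
        let group: Vec<usize> = (0..n)
            .filter(|&k| !done[k] && deps[k].iter().all(|&d| done[d]))
            .collect();
        if group.is_empty() {
            break; // everything scheduled (a DAG guarantees no cycles)
        }
        for &k in &group {
            done[k] = true;
        }
        groups.push(group);
    }
    groups
}

fn main() {
    // Kernel 2 reads the buffers written by kernels 0 and 1,
    // mirroring the A/B/C example above.
    let deps = vec![vec![], vec![], vec![0, 1]];
    assert_eq!(parallel_groups(&deps), vec![vec![0, 1], vec![2]]);
}
```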
Code Generation: From UOp to LLVM IR
With kernels scheduled, we generate actual code. Morok currently supports the LLVM backend:
| Backend | Compile Speed | Output Quality | Use Case |
|---|---|---|---|
| LLVM | Slower | Highly optimized | Production |
The Renderer trait abstracts code generation:
#![allow(unused)]
fn main() {
pub trait Renderer {
fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel>;
}
}
LLVM CPU Renderer
The LLVM renderer (codegen/src/llvm/cpu/) traverses the UOp graph and emits LLVM IR:
define void @kernel_0(ptr %args, ptr %vars) {
entry:
%buf0 = load ptr, ptr %args
%buf1 = load ptr, ptr getelementptr(ptr, ptr %args, i64 1)
; ... loop nest ...
br label %loop_0
loop_0:
%i = phi i64 [ 0, %entry ], [ %i.next, %loop_0 ]
; ... computation ...
%i.next = add i64 %i, 1
%cond = icmp slt i64 %i.next, 128
br i1 %cond, label %loop_0, label %exit
exit:
ret void
}
The generated kernel takes two arguments:
- args: Array of buffer pointers
- vars: Array of symbolic variable values (for dynamic shapes)
Post-Optimization Passes
Before code generation, 13+ pattern-based passes clean up the IR:
| Pass | Purpose |
|---|---|
| pm_add_loads | Wrap INDEX operations in LOAD |
| pre_expand | Convert UNROLL/UPCAST ranges to explicit operations |
| devectorize | Group contiguous memory accesses |
| pm_reduce_devectorize | Handle vector reductions (K-vec, bool, horizontal) |
| pm_fma_decomposition | Convert a*b+c to fused multiply-add |
| bool_storage_patterns | Convert bool → uint8 for memory operations |
These passes transform the optimized AST into a form suitable for code generation. The result is clean, vectorized code with proper memory access patterns.
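The `pm_fma_decomposition` row can be illustrated with Rust's own fused multiply-add. This shows what the rewrite targets at the arithmetic level; the real pass operates on UOps, not Rust functions:

```rust
// What pm_fma_decomposition targets, shown with Rust's fused multiply-add.
// Illustrative only; the real pass rewrites IR nodes, not Rust functions.
fn unfused(a: f32, b: f32, c: f32) -> f32 {
    a * b + c // multiply, round, then add, round again
}

fn fused(a: f32, b: f32, c: f32) -> f32 {
    a.mul_add(b, c) // one hardware instruction, one rounding step
}

fn main() {
    // For exactly representable values both forms agree.
    assert_eq!(unfused(2.0, 3.0, 1.0), 7.0);
    assert_eq!(fused(2.0, 3.0, 1.0), 7.0);
}
```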
Execution: Running the Kernels
Code generation produces LLVM IR strings. Execution involves JIT compilation and kernel launch.
The ExecutionPlan
prepare_execution_plan() builds an ExecutionPlan:
#![allow(unused)]
fn main() {
pub struct ExecutionPlan {
kernels: Vec<PreparedKernel>, // Compiled kernels
parallel_groups: Vec<ParallelGroup>,
buffers: Vec<Buffer>,
output_buffer_idx: usize,
}
}
The plan is reusable: compile once, execute many times with different data.
JIT Compilation
The LLVM runtime (runtime/src/llvm.rs) compiles IR to machine code:
- Parse the LLVM IR string into a module
- Verify the module is well-formed
- Optimize with LLVM's O3 pass pipeline
- JIT compile to native machine code
- Cache by (AST ID, device) for reuse
#![allow(unused)]
fn main() {
// Simplified JIT flow
let module = Module::parse_ir(context, ir_string)?;
module.verify()?;
pass_manager.run(&module); // O3 optimization
let function = execution_engine.get_function::<KernelFn>(&name)?;
// Cache: (ast_id, device) β function
}
Parallel Execution
With kernels compiled, execution follows the parallel groups:
#![allow(unused)]
fn main() {
for group in &plan.parallel_groups {
if group.kernel_indices.len() == 1 {
// Single kernel: direct call
execute_kernel(&kernels[group.kernel_indices[0]]);
} else {
// Multiple kernels: parallel execution
rayon::scope(|s| {
for &idx in &group.kernel_indices {
s.spawn(move |_| execute_kernel(&kernels[idx]));
}
});
}
}
}
Independent kernels run in parallel using Rayon's work-stealing scheduler.
Kernel Caching
Hash consing makes kernel caching highly effective:
- Key: (UOp ID, device string)
- Storage: Lock-free HashMap (papaya crate)
- Hit rate: High, because identical computations share UOp IDs
When you compute the same expression twice, the second call hits the cache: no recompilation.
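A toy version of such a cache makes the hit/miss behavior concrete. This sketch assumes a made-up `KernelCache` type and a string stand-in for compiled code; it is not Morok's (papaya-based, lock-free) implementation:

```rust
use std::collections::HashMap;

// Toy kernel cache keyed by (UOp id, device), illustrating the idea only.
struct KernelCache {
    map: HashMap<(u64, String), String>,
    compiles: usize, // counts how often the expensive "JIT" path ran
}

impl KernelCache {
    fn get_or_compile(&mut self, ast_id: u64, device: &str) -> String {
        let key = (ast_id, device.to_string());
        if let Some(code) = self.map.get(&key) {
            return code.clone(); // cache hit: no recompilation
        }
        self.compiles += 1; // stand-in for the expensive JIT step
        let code = format!("kernel_{}_{}", ast_id, device);
        self.map.insert(key, code.clone());
        code
    }
}

fn main() {
    let mut cache = KernelCache { map: HashMap::new(), compiles: 0 };
    cache.get_or_compile(42, "cpu");
    cache.get_or_compile(42, "cpu"); // identical UOp id → cache hit
    assert_eq!(cache.compiles, 1);
}
```

Hash consing is what makes the key trivial: identical computations already share a UOp ID, so no structural graph hashing is needed at lookup time.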
Worked Example: Matrix Multiply
Let's trace C = A @ B through the entire pipeline. Assume 4×4 matrices.
Stage 1: Lazy Graph Construction
#![allow(unused)]
fn main() {
let a = Tensor::from_slice(&a_data, &[4, 4])?; // Input buffer allocated
let b = Tensor::from_slice(&b_data, &[4, 4])?; // Input buffer allocated
let c = a.matmul(&b)?; // Graph built, no computation
}
At this point, c is a lazy tensor with this UOp graph:
REDUCE_AXIS(Add, axis=2)
└── MUL
    ├── EXPAND(A, [4, 4, 4])   ← A: [4, 4] → [4, 1, 4] → [4, 4, 4]
    └── EXPAND(B, [4, 4, 4])   ← B: [4, 4] → [1, 4, 4] → [4, 4, 4]
Stage 2: Rangeify
Movement ops become explicit loops:
STORE
├── BUFFER(C)
├── INDEX[RANGE(i, 0..4), RANGE(j, 0..4)]
└── REDUCE(Add)
    ├── MUL
    │   ├── LOAD(A)
    │   │   └── INDEX[RANGE(i), RANGE(k, 0..4, Reduce)]
    │   └── LOAD(B)
    │       └── INDEX[RANGE(k), RANGE(j)]
    └── RANGE(k, Reduce)
The i and j ranges are output dimensions. The k range is the reduction (contracted) dimension.
Stage 3: Kernel Splitting
Single STORE → single KERNEL:
KERNEL
├── SINK(STORE(...))
├── ranges: [i: 0..4, j: 0..4]
└── buffers: [C (output), A (input), B (input)]
Stage 4: Schedule
One ScheduleItem with:
- kernel: The KERNEL UOp
- ast: The inner SINK/STORE
- buffers: [C, A, B]
- dependencies: [] (no prior kernels)
Stage 5: Optimization
Heuristic optimizer applies:
- Vectorization: UPCAST j dimension by 4
- Loop ordering: Ensure good cache behavior
Stage 6: Code Generation
Generated LLVM IR (simplified):
define void @matmul(ptr %args, ptr %vars) {
entry:
%C = load ptr, ptr %args
%A = load ptr, ptr getelementptr(ptr, ptr %args, i64 1)
%B = load ptr, ptr getelementptr(ptr, ptr %args, i64 2)
br label %loop_i
loop_i:
%i = phi i64 [ 0, %entry ], [ %i.next, %loop_i.end ]
br label %loop_j
loop_j:
%j = phi i64 [ 0, %loop_i ], [ %j.next, %loop_k.end ]
%acc = ... ; initialize accumulator
br label %loop_k
loop_k:
%k = phi i64 [ 0, %loop_j ], [ %k.next, %loop_k ]
%a_val = load float, ptr ... ; A[i, k]
%b_val = load float, ptr ... ; B[k, j]
%prod = fmul float %a_val, %b_val
%acc.new = fadd float %acc, %prod
%k.next = add i64 %k, 1
%k.cond = icmp slt i64 %k.next, 4
br i1 %k.cond, label %loop_k, label %loop_k.end
loop_k.end:
store float %acc.new, ptr ... ; C[i, j]
; ... continue j, i loops
}
Stage 7: Execution
- JIT compile the LLVM IR
- Execute: kernel([C_ptr, A_ptr, B_ptr], [])
- Result is in the C buffer
Total: one function call, result ready.
Comparison: How Other Frameworks Execute
| Aspect | PyTorch | JAX | TVM | Morok |
|---|---|---|---|---|
| Evaluation | Eager (immediate) | Traced (jit decorator) | Lazy (te.compute) | Lazy (realize) |
| Graph capture | torch.compile | jax.jit trace | Explicit schedule | Implicit via ops |
| Compilation | TorchInductor | XLA backend | Auto-scheduler | Pattern + beam |
| Caching | Per-graph hash | Per-trace | Per-schedule | Per-AST (hash consing) |
| Parallelism | DataParallel/DDP | pmap/pjit | Parallel schedule | Parallel groups |
PyTorch: Eager by default, torch.compile for optimization. TorchInductor generates Triton or C++ code.
JAX: Functional transformations (jit, grad, vmap) trace computations. XLA compiles to optimized kernels.
TVM: Explicit separation of computation and schedule. Auto-scheduler searches for good schedules.
Morok: Fully lazy; nothing executes until realize(). Hash consing provides automatic caching. Pattern-based optimization with optional beam search for production quality.
The Deeper Insight
The pipeline embodies several design principles:
Lazy evaluation enables global optimization. By deferring computation, we see the entire graph before generating code. No local decision limits global optimization.
Explicit loops enable hardware-specific scheduling. Movement ops are convenient abstractions, but GPUs need loops. Rangeify bridges the gap.
Hash consing makes caching automatic. Identical computations share pointers, so cache keys are trivial. No complex graph hashing needed.
Separation of concerns keeps each stage simple. Rangeify doesn't know about LLVM. Code generation doesn't know about tensor semantics. Each stage does one thing well.
The result: a compilation pipeline that's both powerful and maintainable. From tensor.realize() to machine code, every step is visible, debuggable, and extensible.
Path of the UOp: The 22-Stage Codegen Pipeline
A UOp starts as a high-level tensor expression. By the time it reaches the hardware, it has been transformed through 22 distinct stages, each with a specific purpose, each building on the last. This chapter traces that journey.
The pipeline is a proven design for tensor compilation. Understanding it means understanding how tensor expressions become machine code.
How to Read This Chapter
If you're not a compiler engineer, this chapter might seem intimidating. Here's what you need to understand before diving in.
Key Concepts
UOp (Micro-Operation)
- Think of it as a node in a flowchart representing one computation
- Example: ADD(a, b) means "add a and b"
Pattern
- A find-and-replace rule for code structures (not text)
- Example: "If you see ADD(x, 0), replace with x"
- Patterns fire repeatedly until no more matches (fixpoint)
Range
- A loop iteration: RANGE(0..10) means "for i from 0 to 10"
AxisType
- What kind of loop is this?
- Global: Parallel across GPU blocks / CPU threads
- Local: Parallel within a workgroup
- Reduce: Accumulator (sum, max, etc.)
- Loop: Sequential iteration
Stage
- One transformation pass through the code
- Patterns fire until fixpoint, then move to the next stage
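The "pattern fires until no more matches" idea can be sketched with a hand-written rewriter for the ADD(x, 0) example above. Morok's real patterns! DSL is declarative; this toy version only illustrates bottom-up rewriting to a fixpoint:

```rust
// A toy bottom-up rewriter for the "ADD(x, 0) → x" pattern from above.
// Real patterns! rules are declarative; this hand-written version just
// illustrates rewriting until nothing more matches.
#[derive(Clone, PartialEq, Debug)]
enum UOp {
    Const(i64),
    Add(Box<UOp>, Box<UOp>),
}

fn simplify(u: UOp) -> UOp {
    match u {
        UOp::Add(a, b) => {
            // Rewrite children first (bottom-up), then try the pattern here.
            let (a, b) = (simplify(*a), simplify(*b));
            if b == UOp::Const(0) {
                a // x + 0 → x
            } else if a == UOp::Const(0) {
                b // 0 + x → x
            } else {
                UOp::Add(Box::new(a), Box::new(b))
            }
        }
        other => other,
    }
}

fn main() {
    // (5 + 0) + 0 collapses all the way to 5.
    let e = UOp::Add(
        Box::new(UOp::Add(Box::new(UOp::Const(5)), Box::new(UOp::Const(0)))),
        Box::new(UOp::Const(0)),
    );
    assert_eq!(simplify(e), UOp::Const(5));
}
```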
Reading Strategy
- First pass: Read just the "What This Does" and "Why This Matters" sections
- Second pass: Look at the diagrams and examples
- Third pass (if you want details): Read the pattern descriptions
Questions to Ask
For each stage, ask:
- What does this stage accomplish? (High-level goal)
- Why do we need this stage? (Motivation)
- What would go wrong without it? (Consequences)
Overview
The 22 stages fall into four phases:
Tensor Expression
        │
        ▼
┌─────────────────────────────────────┐
│ RANGEIFY (Stages 1-7)               │
│ Movement ops → Explicit loops       │
│                                     │
│ [Make iteration explicit,           │
│  optimize ranges]                   │
└─────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────┐
│ EXPANDER (Stages 8-10)              │
│ UNROLL/UPCAST → Explicit vectors    │
│                                     │
│ [Expand optimization primitives]    │
└─────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────┐
│ DEVECTORIZER (Stages 11-15)         │
│ Vector ops → Scalar code            │
│                                     │
│ [Lower to hardware-specific ops]    │
└─────────────────────────────────────┘
        │
        ▼
┌─────────────────────────────────────┐
│ LINEARIZER (Stages 16-22)           │
│ IR → Linear instruction sequence    │
│                                     │
│ [Serialize to executable code]      │
└─────────────────────────────────────┘
        │
        ▼
Machine Code
Each stage applies pattern-based rewrites. Patterns fire until fixpoint, then the next stage begins.
Phase 1: Rangeify
Goal: Transform high-level movement operations into explicit loop structures and optimize ranges.
Stage 1: Early Movement Ops
Stage at a Glance
Goal: Clean up movement operations before range assignment
Key Patterns: Movement on INDEX, movement through wrappers, nested INDEX simplification
Impact: Prevents missed optimizations later in the pipeline
What This Does: This stage cleans up movement operations by pushing index manipulations into places where they're actually needed. Think of it as organizing your desk before filing papers: move instructions closer to where the data is used.
Why This Matters: Movement operations (RESHAPE, PERMUTE, etc.) are convenient abstractions, but the hardware needs concrete index calculations. By cleaning them up early, we ensure patterns in later stages can match correctly.
Pattern: pm_mops + pm_syntactic_sugar (bottom-up)
| Pattern | Transformation | Visual | Location |
|---|---|---|---|
| Movement on INDEX | Apply movement to index expressions | INDEX(PERMUTE(arr), [i, j]) → INDEX(arr, [j, i]) | movement_op_patterns() |
| Movement through AFTER | Move RESHAPE through timing wrapper (Tinygrad-specific) | AFTER(RESHAPE(x, arg), [dep1, dep2]) → RESHAPE(AFTER(x, [dep2]), arg) | Tinygrad only |
| Movement through END | Unwrap movement from END wrapper (Tinygrad-specific) | END(RESHAPE(x), ranges) → END(x, ranges) | Tinygrad only |
| Nested INDEX simplification | Remove redundant nested INDEX (Morok) | INDEX(INDEX(ptr, [i]), [i]) → INDEX(ptr, [i]) | movement_op_patterns() |
| Nested INDEX concat | Flatten nested INDEX for PtrDType | INDEX(INDEX(ptr, i), j) → INDEX(ptr, i, j) | pm_syntactic_sugar |
Why bottom-up? Child nodes must be clean before parents can match. Movement ops nest deeply; cleaning from bottom prevents missed patterns.
Note: Tinygrad and Morok have different approaches here. Tinygrad moves movement ops through wrappers (AFTER, END) because it re-applies movement ops during bufferization. Morok removes movement ops entirely by transforming indices during bufferization, so AFTER/END patterns are not needed.
Morok: movement_op_patterns() in rangeify/patterns.rs
Stage 2: Load Collapse
Stage at a Glance
- **Goal:** Eliminate REDUCE operations by detecting range-independent computation
- **Key Patterns:** Bounded sum, gated load collapse, general reduce elimination
- **Impact:** Converts loop iterations to arithmetic operations
What This Does: Eliminates REDUCE operations by recognizing when the computation can be done without iteration. Uses range-independent computation detection and symbolic simplification.
Why This Matters: Reducing iterations to arithmetic operations eliminates loop overhead. Instead of running a loop 1000 times, compute the answer directly.
Pattern: pm_load_collapse
// Before: Sum with bounds check
sum(1 for k in 0..64 if k >= length)
// After: Compute count directly (NO LOOP!)
count = clamp(64 - length, 0, 64)
The mechanism works by:
- Identifying subexpressions that don't depend on the REDUCE range
- Creating DEFINE_VAR for those subexpressions (treats as loop-invariant)
- Substituting the range with DEFINE_VAR and running symbolic simplification
- If the simplified expression has no more ranges, the REDUCE is eliminated
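The bounded-sum example above can be checked directly. A small sketch of the equivalence this stage exploits (function names are illustrative, not Morok APIs):

```rust
// Loop form: count every k in 0..64 with k >= length.
fn count_loop(length: i64) -> i64 {
    (0i64..64).filter(|&k| k >= length).count() as i64
}

// Closed form the rewrite produces: clamp(64 - length, 0, 64). No loop.
fn count_closed(length: i64) -> i64 {
    (64 - length).clamp(0, 64)
}

fn main() {
    // The two forms agree for every length, including out-of-range values.
    for length in -10..100 {
        assert_eq!(count_loop(length), count_closed(length));
    }
}
```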
Note: WHERE movement through INDEX (pm_move_where_on_load in Stage 8) is a separate optimization that places conditionals before loads to skip memory accesses, but it doesn't eliminate REDUCE operations.
Morok: pm_load_collapse() in rangeify/patterns.rs
Stage 3: Split Ranges
Stage at a Glance
- **Goal:** Enable better optimization through divmod decomposition
- **Key Patterns:** Split ranges with modulo, flatten ranges
- **Impact:** Inner ranges can vectorize, outer can parallelize
What This Does: Handles modulo patterns by splitting a range into outer and inner components.
Why This Matters: Splitting ranges is like dividing a large task among team members. If you have 12 items and each person does 4, you get 3 people × 4 items. Inner loops (one person's 4 items) can be fast; outer loops (3 people) can run in parallel.
Pattern: pm_split_ranges + pm_flatten_range
// Before: one loop with modulo (slow)
RANGE(end=12) % 4
// After: split into outer × inner
RANGE(end=3) * 4 + RANGE(end=4)
//    ↑ outer (parallel)   ↑ inner (sequential)
This enables:
- Inner ranges can vectorize (SIMD)
- Outer ranges can parallelize (GPU blocks / CPU threads)
pm_flatten_range merges nested ranges on REDUCE/STORE/END when beneficial.
Context: Requires dictionary context (ctx={}) to track substitutions at SINK.
Note: The split only applies when end % mod == 0 (divisibility check).
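The divisibility check can be verified with a short sketch: splitting RANGE(end=12) into 3 × 4 covers every index exactly once, and the inner component resolves the modulo statically (the variable names are illustrative):

```rust
// Every index in 0..end is covered exactly once by outer*m + inner,
// provided end is divisible by m (the stage's divisibility check).
fn main() {
    let (end, m) = (12i64, 4i64);
    assert_eq!(end % m, 0); // split only applies when divisible
    let mut seen = vec![false; end as usize];
    for outer in 0..end / m {
        for inner in 0..m {
            let idx = outer * m + inner;
            assert_eq!(idx % m, inner); // the modulo is now known statically
            seen[idx as usize] = true;
        }
    }
    assert!(seen.iter().all(|&s| s)); // full coverage, no gaps
}
```

If `end % m != 0`, the last outer iteration would run past `end`, which is why the stage skips the rewrite in that case.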
Morok: pm_split_ranges() + pm_flatten_range() in rangeify/transforms.rs
Stage 4: Initial Symbolic
Stage at a Glance
- **Goal:** Simplify expressions using algebra rules
- **Key Patterns:** Constant folding, identity removal, div-mod recombine
- **Impact:** Eliminates expensive operations, reduces code size
What This Does: Applies 100+ constant folding and algebraic simplification rules.
Why This Matters: Computers are fast at simple math. Dividing and taking remainders are slow operations. This stage uses algebra rules to eliminate slow operations whenever possible.
Pattern: sym + pm_flatten_range
Constant folding:
ADD(CONST(2), CONST(3)) → CONST(5)
MUL(x, CONST(1)) → x
ADD(x, CONST(0)) → x
Div-mod recombination:
(x / c) * c + (x % c) → x
Why? The left-hand side computes the same value as plain x, but spends three operations doing it. This pattern finds and removes the redundancy (common in stride calculations).
Boolean algebra:
x AND x → x
x OR FALSE → x
NOT(NOT(x)) → x
Additional categories:
- Identity removal (self-folding, redundant operations)
- Comparison simplification
- Cast optimization
- GEP pushing (move address calculations through ALUs)
- Where folding (combine WHERE with same conditions)
- Reduce mul chain (move multiplications outside reduce)
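A minimal sketch of constant folding on a toy expression tree may help make the rules concrete. The real pass is pattern-driven over the UOp graph; this `Expr` enum and `fold` function are illustrative only:

```rust
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Const(i64),
    Var(&'static str),
    Add(Box<Expr>, Box<Expr>),
    Mul(Box<Expr>, Box<Expr>),
}

// Fold children first, then try each rule at the current node.
fn fold(e: Expr) -> Expr {
    use Expr::*;
    match e {
        Add(a, b) => match (fold(*a), fold(*b)) {
            (Const(x), Const(y)) => Const(x + y),   // ADD(CONST, CONST) → CONST
            (x, Const(0)) | (Const(0), x) => x,     // ADD(x, 0) → x
            (x, y) => Add(Box::new(x), Box::new(y)),
        },
        Mul(a, b) => match (fold(*a), fold(*b)) {
            (Const(x), Const(y)) => Const(x * y),   // MUL(CONST, CONST) → CONST
            (x, Const(1)) | (Const(1), x) => x,     // MUL(x, 1) → x
            (x, y) => Mul(Box::new(x), Box::new(y)),
        },
        other => other, // Const and Var are already simplified
    }
}

fn main() {
    use Expr::*;
    // (2 + 3) * x folds to 5 * x
    let e = Mul(Box::new(Add(Box::new(Const(2)), Box::new(Const(3)))), Box::new(Var("x")));
    assert_eq!(fold(e), Mul(Box::new(Const(5)), Box::new(Var("x"))));
    // x + 0 folds to x
    assert_eq!(fold(Add(Box::new(Var("x")), Box::new(Const(0)))), Var("x"));
}
```

The bottom-up recursion mirrors why these rules compose: folding `2 + 3` into `5` is what lets a parent rule see a constant operand.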
Morok: symbolic_patterns() in symbolic/patterns.rs
Stage 5: Simplify Ranges
Stage at a Glance
- **Goal:** Merge adjacent ranges to reduce loop overhead
- **Key Patterns:** Range merging with cost analysis
- **Impact:** Fewer loops = less overhead
What This Does: Merges adjacent ranges when profitable.
Why This Matters: Merging ranges is like combining multiple small trips into one big one. Instead of going to the store 4 times for 4 items, go once for all 4 items. Saves the overhead of starting and stopping.
Pattern: pm_simplify_ranges
// Before: two separate ranges
RANGE(0..4), RANGE(0..8)
// After: merged (if compatible)
RANGE(0..32)
Merge criteria:
- Axis types must be compatible (both output, both reduce, etc.)
- REDUCE scope must remain consistent
- Cost-based: Accept only if divmod operation count does not increase
The compiler only merges if it saves operations. Merging might require division/modulo to recalculate indices. If that costs more than it saves, merge is skipped.
Morok: simplify_merge_adjacent() in rangeify/transforms.rs
Stage 6: Split Store (CPU-only)
Stage at a Glance
- **Goal:** Avoid branch misprediction by splitting conditional stores
- **Key Patterns:** Split store ranges at comparison boundaries
- **Impact:** More predictable CPU execution
What This Does: Splits store ranges at conditional boundaries when there are CMPLT(range, const) comparisons in the store's consumer map.
Why This Matters: Branch misprediction slows down CPUs. Instead of one loop with an if statement that the CPU can't predict, we create two loops without conditionals. Each loop does predictable work, so the CPU stays fast.
Pattern: pm_split_store
// Before: store guarded by a conditional (branch misprediction risk)
for i in 0..100:
    if i < 50:
        output[i] = data[i]
    else:
        output[i] = 0
// After: two unconditional loops over disjoint ranges (predictable)
for i in 0..50:      // First segment
    output[i] = data[i]
for i in 50..100:    // Second segment
    output[i] = 0
The transformation finds constant comparison points in the store's consumer map and creates disjoint ranges for each segment.
Skipped for GPU devices; they handle conditionals differently.
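The equivalence of the two forms is easy to check. A sketch in plain Rust (the function names and the `i < 50` condition are illustrative, matching the example above):

```rust
// Branchy form: one loop with a data-independent conditional inside.
fn with_branch(data: &[i64]) -> Vec<i64> {
    let mut out = vec![0i64; data.len()];
    for i in 0..data.len() {
        if i < 50 {
            out[i] = data[i];
        }
    }
    out
}

// Split form: one unconditional loop over the taken segment;
// the i >= 50 segment stores nothing (out stays 0).
fn split(data: &[i64]) -> Vec<i64> {
    let mut out = vec![0i64; data.len()];
    for i in 0..50 {
        out[i] = data[i];
    }
    out
}

fn main() {
    let data: Vec<i64> = (0..100).collect();
    assert_eq!(with_branch(&data), split(&data));
}
```

Because the condition depends only on the loop index, the compiler can resolve it at compile time and emit loops whose bounds encode the condition.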
Morok: pm_split_store() in rangeify/transforms.rs
Stage 7: Apply Opts
Stage at a Glance
- **Goal:** Find optimal combination of vectorization, unrolling, memory usage
- **Key Algorithm:** Beam search or heuristics
- **Impact:** Can significantly improve performance
What This Does: The optimization searchβeither beam search or heuristicβexplores different combinations of optimization actions.
Why This Matters: The compiler tries different combinations of optimizations (vectorize here? unroll there?) and picks the fastest. Finding the right combination can make code 10x faster.
Function: apply_opts(sink, renderer)
Optimization actions:
| Action | Effect |
|---|---|
| TC | Enable tensor core usage |
| UPCAST | Vectorize a dimension |
| LOCAL | Use local/shared memory |
| UNROLL | Unroll a loop dimension |
| GROUP | Group operations for cache |
| GROUPTOP | Group for reduce ops |
| THREAD | Thread-based parallelism |
| NOLOCALS | Disable local memory usage |
| SWAP | Swap range assignments |
| PADTO | Pad for alignment |
Optimization Search Explained:
The compiler searches for the best combination:
- Heuristic mode (BEAM=0): Fast hand-coded optimization patterns, no compilation
- Beam search (BEAM≥1): Compiles and runs candidates to measure actual performance

Optimization Search:
├── Heuristic mode (BEAM=0): Hand-coded optimizations
└── Beam search (BEAM≥1):
    ├── Generate all possible actions (193 combinations)
    ├── Apply to all top-K candidates in parallel
    ├── Filter based on constraints
    ├── Compile and run each candidate → Measure actual time
    └── Pick fastest
Note: NOLOCALS is a constraint that sets dont_use_locals = True, preventing further LOCAL actions and affecting shared memory usage decisions.
Morok: optimizer/mod.rs, optimizer/opts.rs
Phase 2: Expander
Goal: Transform optimization primitives (UNROLL/UPCAST) into explicit operations.
Stage 8: Post-Opt Symbolic
Stage at a Glance
- **Goal:** Symbolic simplification after optimization
- **Key Patterns:** WHERE movement, constant folding
- **Impact:** Enables better load combining and vectorization
What This Does: Symbolic simplification after optimization, plus WHERE movement.
Why This Matters: WHERE operations are like if statements. This stage moves if checks from after a load to before the load. Hardware can skip loading when the condition is false, saving memory bandwidth.
Pattern: sym + pm_move_where_on_load
// Before: WHERE guards a load
WHERE(valid, LOAD(index), alt)
// After: validity moved to INDEX
LOAD(INDEX(ptr, idx, valid=valid), alt)
Moving validity into INDEX enables better load combining and vectorization.
Note: This pattern only matches when the alternative value is 0. The transformation involves complex clause analysis: duplicate detection, range dependency checks, and data-dependent load verification.
Note: The Morok implementation uses gate= instead of valid= (the Index struct has a gate field). The concept is identical.
Morok: pm_move_where_on_load() in symbolic/patterns.rs
Stage 9: Expander
Stage at a Glance
- **Goal:** Convert UNROLL/UPCAST to explicit operations
- **Key Concepts:** UNROLL, CONTRACT, pattern order
- **Impact:** Makes vectorization explicit and ready for hardware
What This Does: Transforms UNROLL/UPCAST optimization primitives into explicit operations.
Why This Matters: UPCAST and UNROLL mark intent, i.e. what we want to do. This stage makes that intent explicit so the hardware can actually do it.
Pattern: sym + pm_pre_expander + pm_group_for_reduce + expander
⚠️ Important: Pattern Precedence
The patterns are combined and run to fixpoint. The order affects which pattern is tried first when multiple could match:
1. `sym` first (symbolic simplification)
2. `pm_pre_expander` second (converts UPCAST/UNROLL ranges)
3. `pm_group_for_reduce` third (handles GROUP_REDUCE axis)
4. `expander` last (main expansion)
Wrong precedence can cause incorrect vectorization or reduction scoping.
UNROLL and CONTRACT:
UNROLL and CONTRACT work together:
UNROLL: "Take this one thing and make N copies for different positions"
Example: x β [x_0, x_1, x_2, x_3]
CONTRACT: "Take these N things and combine them back"
Example: [a, b, c, d] β one vector containing all four
Together: UPCAST marks intent to vectorize → UNROLL expands → CONTRACT combines.
UPCAST range → VECTORIZE:
// Before: UPCAST marks vectorization intent
RANGE(end=4, UPCAST)
    ↓ [pm_pre_expander]
// Step 1: Convert to UNROLL with constant indices
UNROLL(VCONST([0, 1, 2, 3]))
    ↓ [expander]
// Step 2: Expand operations with UNROLL sources
// Operations now have unrolled sources
    ↓ [CONTRACT or implicit]
// After: explicit VECTORIZE
VECTORIZE(op[0], op[1], op[2], op[3])
UNROLL range → repeated operations:
When we say "operations duplicated," it sounds like copy-paste. But that's not what happens. The compiler creates a single SIMD instruction that processes all N elements together. Think of a SIMD register as a box holding 4 numbers; adding two such boxes performs all four element-wise additions in one instruction.
// Before: UPCAST marks vectorization intent
RANGE(end=3, UPCAST)
    ↓ [pm_pre_expander]
// Step 1: Convert to UNROLL
UNROLL(VCONST([0, 1, 2]))
    ↓ [expander]
// Step 2: Operations expand to handle all positions
// After: operations processed together (not duplicated)
UNROLL([op_at_0, op_at_1, op_at_2])
UNROLL/END/CONTRACT interaction:
Before: END(STORE(...), [RANGE(UPCAST)])
    ↓ [pm_pre_expander]
Step 1: END(STORE(...), [UNROLL(VCONST([0,1,2,3]))])
    ↓ [expander]
Step 2: END(CONTRACT(STORE(...×4)), [])
Broadcast through AFTER/END:
// Broadcast VECTORIZE (all elements identical)
AFTER(VECTORIZE([x, x, x, x]), deps) → VECTORIZE([AFTER(x, deps), AFTER(x, deps), ...])
GROUP_REDUCE Handling (pm_group_for_reduce):
GROUP_REDUCE is a special axis type for tensor core reductions:
// Before: REDUCE with GROUP_REDUCE ranges
REDUCE(src, [range(GROUP_REDUCE)])
    ↓ [pm_group_for_reduce]
// After: Shared memory reduction pattern
1. Track upstream LOCAL ranges
2. BUFFERIZE result with group ranges (AddrSpace.LOCAL)
3. INDEX into buffer with transformed ranges
4. Final REDUCE with axes (range_id+100, AxisType.REDUCE)
This enables efficient tensor core accumulation via shared memory.
Morok: expand.rs
Stage 10: Add Local Buffers
Stage at a Glance
- **Goal:** Prepare buffers for fast memory (shared / L1)
- **Key Patterns:** Bufferize with locals, extract hints
- **Impact:** Frequently-accessed data stays in fast memory
What This Does: Prepares buffers for local memory usage and applies codegen-specific cleanups.
Why This Matters: Local buffers = fast memory close to the compute unit:
- GPU: Shared memory (LDS), 100x faster than global memory
- CPU: L1 cache, 10x faster than main memory
The compiler moves frequently-accessed data to local buffers, similar to keeping important files on your desktop instead of a network drive.
Pattern: pm_add_buffers_local + rangeify_codegen
| Transform | Purpose |
|---|---|
| bufferize_to_store | Convert BUFFERIZE with allow_locals=true |
| get_contiguous | Extract optimization hints from CONTIGUOUS |
| NOOP removal | Clean up no-op operations |
| Strip arg from STORE | Remove redundant arguments |
| Fix broadcast dtype | Ensure consistent types in broadcasts |
Morok: rangeify/kernel.rs
Phase 3: Devectorizer
Goal: Lower from hardware-agnostic vectors to hardware-specific instructions.
Stage 11: Remove Reduce
Stage at a Glance
- **Goal:** Convert declarative REDUCE to imperative accumulation
- **Key Patterns:** Reduce to accumulator, horizontal reduction
- **Impact:** Maps to hardware reduction instructions
What This Does: Converts high-level REDUCE to accumulator pattern.
Why This Matters: A declarative "sum these values" needs to become imperative instructions: initialize accumulator, loop, add each value.
Pattern: pm_reduce + gep_pushing
// Before: declarative reduction
REDUCE(Add, values, range)
// After: imperative accumulation
acc = DEFINE_REG(0.0)
for i in range:
acc = ADD(acc, values[i])
Horizontal reduction:
Before we loop through a reduction dimension, we first combine neighboring values. This creates larger reductions that map better to hardware instructions.
Before: [a, b, c, d, e, f, g, h] // 8 values
    ↓ [Horizontal reduction]
Step 1: [a+e, b+f, c+g, d+h] // 4 partial sums
    ↓ [Accumulator pattern]
After: acc = acc + (a+e) + (b+f) + (c+g) + (d+h)
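A quick sketch confirming that the halved-then-accumulated form matches a plain sequential sum (the 8-element array is illustrative):

```rust
// Combine the two halves element-wise first, then accumulate the
// partial sums; the result matches a plain sequential reduction.
fn main() {
    let v = [1i64, 2, 3, 4, 5, 6, 7, 8];
    // Step 1: horizontal reduction, 4 partial sums
    let partial: Vec<i64> = (0..4).map(|i| v[i] + v[i + 4]).collect();
    assert_eq!(partial, vec![6, 8, 10, 12]);
    // Step 2: accumulator pattern over the partials
    let mut acc = 0i64;
    for p in &partial {
        acc += p;
    }
    assert_eq!(acc, v.iter().sum::<i64>());
}
```

This works because addition is associative and commutative; the same trick applies to any reduction whose operator has those properties.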
GEP pushing pushes GEP (get element pointer) operations through ALUs for better vectorization:
GEP(ADD(ptr_a, ptr_b), idx) → ADD(GEP(ptr_a, idx), GEP(ptr_b, idx))
Why? Enables SIMD on the two GEPs (can be computed in parallel).
WMMA Tensor Core Fusion:
// Fuse tensor core accumulation inline
WMMA(a, b, c) + add → WMMA(a, b, c + add)
This pattern enables efficient FMA-style accumulation on NVIDIA tensor cores.
Morok: devectorize.rs
Stage 12: Add GPU Dims
Stage at a Glance
- **Goal:** Map abstract ranges to GPU thread indices
- **Key Patterns:** Range to SPECIAL replacement
- **Impact:** Enables parallel execution on GPU
What This Does: Replaces ranges with GPU thread indices.
Why This Matters: GPUs have hard limits: max 1024 threads per block, max 48KB shared memory. If your computation needs 2000 threads, the compiler must split it into multiple blocks. Dimension limiting handles this automatically.
Pattern: pm_add_gpudims
// Before: abstract range
RANGE(end=256, Global)
// After: GPU-specific
SPECIAL(gidx0) // global thread index
Mapping:
| Range Type | GPU Equivalent |
|---|---|
| Global, THREAD | gidx (global index) |
| Local, WARP, GROUP_REDUCE | lidx (local/workgroup index) |
| Reduce | Loop (no mapping) |
Dimension Limiting:
GPUs have hardware limits (e.g., max 1024 threads per block). When ranges exceed these limits, the compiler:
- Groups adjacent dimensions: [256, 256, 256] with max [256, 256] → [65536, 256]
- Splits large dimensions: [2048] with max [1024] → [2, 1024]
- Reconstructs indices via divmod
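The divmod reconstruction in the split case can be sketched directly (the 2048/1024 figures mirror the example above; the variable names are illustrative):

```rust
// Split a 2048-wide range under a 1024-thread limit and reconstruct
// the original index via divmod, as the dimension limiter does.
fn main() {
    let (total, max_threads) = (2048usize, 1024usize);
    let blocks = total / max_threads; // sketch assumes even divisibility
    assert_eq!(blocks, 2);
    for original in 0..total {
        let (block, thread) = (original / max_threads, original % max_threads);
        assert!(block < blocks && thread < max_threads); // limits respected
        assert_eq!(block * max_threads + thread, original); // index recovered
    }
}
```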
Store Masking:
Global stores that don't use all local dimensions are masked:
// If STORE doesn't use lidx1, mask it:
STORE(INDEX(...), value) → STORE(INDEX(..., gate=(lidx1 == 0)), value)
This ensures stores only execute when unused local indices are 0.
Morok: gpudims.rs
Stage 13: Add Loads
Stage at a Glance
- **Goal:** Wrap INDEX operations in explicit LOAD
- **Key Patterns:** Add LOAD, remove redundant loads
- **Impact:** Makes memory operations explicit for codegen
What This Does: Wraps INDEX operations in explicit LOAD.
Why This Matters: Index operations compute addresses. LOAD actually reads memory. Making this explicit helps the code generator understand what memory accesses are needed.
Pattern: pm_add_loads
// Before: bare index
INDEX(ptr, i)
// After: explicit load
LOAD(INDEX(ptr, i))
Also removes redundant loads from stores (write-only access).
Note: Not all INDEX operations get wrapped in LOAD. Pointer types (already addresses) and image textures (special hardware) use different access methods.
Morok: devectorize.rs
Stage 14: Devectorize
Stage at a Glance
- **Goal:** Convert abstract vectors to match hardware capabilities
- **Key Phases:** 4 coordinated passes
- **Impact:** Vectors work with actual hardware width
What This Does: Handles the transition from abstract vectors to hardware operations.
Why This Matters: Hardware only supports specific vector widths, so abstract vectors must be lowered to operations the device can actually execute. Devectorize uses 4 conceptual phases within a single graph_rewrite:
- Phase 1: Create PTRCAT to group consecutive pointer accesses, devectorize ALU/WMMA/buffers, expand vector INDEX → GEP(PTRCAT)
- Phase 2: Move GEP through LOAD/STORE
- Phase 3: Distribute PTRCAT through LOAD/STORE, creating CAT(LOADs), fix image buffers
- Phase 4: Split CAT(LOADs) into smaller chunks matching hardware width
PTRCAT Construction:
PTRCAT groups consecutive pointer accesses:
- Generate individual indexes for each vector element
- Extract (valid, root_src) → [offsets] mapping
- Group consecutive offsets by validity and source
- Create PTRCAT from grouped pointers
- Return with GEP permutation for correct element order
This reduces memory bus transactions.
Device-Specific Fold Lengths:
| Device | Fold Lengths | Notes |
|---|---|---|
| DSP | 128, 64, 32, 16, 8, 4 | Large vectors for DSP SIMD |
| GPU (float4) | 4, 2 | Standard GPU vectorization |
| GPU (half + ALLOW_HALF8) | 8, 4, 2 | Half precision with env var |
| GPU (AMX) | 16, 8, 4, 2 | Apple AMX support |
| Image | 4 | Fixed for image textures |
| Default | 1 | Scalar fallback |
Environment Variable: DEVECTORIZE
- `0`: Skip `devectorize` only (keep `correct_load_store`)
- `1`: Full devectorization (default)
- `≥2`: Skip both `devectorize` and `correct_load_store`
Pattern: devectorize + load_store_folding + correct_load_store + load_store_indexing
Split vectorized ALUs:
// If hardware doesn't support vec4 add
ADD(vec4_a, vec4_b) → [ADD(a[0], b[0]), ADD(a[1], b[1]), ...]
Load/store chunk splitting: Match hardware memory width.
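As a sketch of chunk splitting under the fold-length table above, a greedy split into the largest supported widths might look like this (`split_chunks` is an illustrative name, not a Morok function, and the real pass also considers alignment and validity):

```rust
// Greedy sketch: split an access width into the largest chunks the
// device supports, trying fold lengths in descending order.
fn split_chunks(mut len: usize, folds: &[usize]) -> Vec<usize> {
    let mut chunks = Vec::new();
    while len > 0 {
        // fall back to scalar (width 1) when no fold length fits
        let w = folds.iter().copied().find(|&f| f <= len).unwrap_or(1);
        chunks.push(w);
        len -= w;
    }
    chunks
}

fn main() {
    // A float4-style GPU (fold lengths 4, 2) handling a 7-wide access:
    assert_eq!(split_chunks(7, &[4, 2]), vec![4, 2, 1]);
    // A DSP with wide SIMD takes one 64-wide chunk plus an 8-wide tail:
    assert_eq!(split_chunks(72, &[128, 64, 32, 16, 8, 4]), vec![64, 8]);
}
```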
Image fixup: Special handling for image tensor buffers.
Morok: devectorize.rs
Stage 15: Lower Index Dtype
Stage at a Glance
- **Goal:** Convert abstract Index type to concrete integers
- **Key Patterns:** Operation-specific lowering based on value bounds
- **Impact:** Indices use hardware-native integer types (i32 or i64)
What This Does: Converts abstract Index type to concrete integers.
Why This Matters: The Index type is abstract; hardware doesn't have it. We need to convert to i32 or i64, which the hardware actually supports.
Pattern: pm_lower_index_dtype
// Before: abstract index type
idx: Index
// After: concrete type
idx: i32 // or i64, based on bounds
Operation-Specific Lowering:
Index type lowering is NOT a single cast; each operation type has specific patterns:
| Operation | Before | After |
|---|---|---|
| Binary ops | ADD(Index, Index) | ADD(i32, i32) with casts |
| CONST | CONST(5): Index | CONST(5): i32 |
| WHERE | WHERE(c, Index, Index) | WHERE(c, i32, i32) |
| RANGE | RANGE(end: Index) | RANGE(end: i32) with cast |
| SPECIAL | SPECIAL(gidx) | Always i32 (GPU indices are 32-bit) |
| DEFINE_VAR | DEFINE_VAR: Index | i32 if bounds fit, else i64 |
| VECTORIZE | VECTORIZE(Index...) | Cast each to concrete scalar |
| CAST cleanup | CAST(i32, Index) | Just i32 (remove redundant cast) |
The select_concrete_dtype() function determines i32 vs i64 using vmin/vmax bounds analysis:
dtype = i32 if bounds fit in [-2^31, 2^31-1] else i64
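The bounds check itself is small; a sketch of the decision (the function name echoes `select_concrete_dtype`, but the signature here is illustrative):

```rust
// Pick i32 when the [vmin, vmax] interval fits, otherwise i64.
fn select_concrete_dtype(vmin: i128, vmax: i128) -> &'static str {
    if vmin >= i32::MIN as i128 && vmax <= i32::MAX as i128 {
        "i32"
    } else {
        "i64"
    }
}

fn main() {
    assert_eq!(select_concrete_dtype(0, 1023), "i32");        // small loop bound
    assert_eq!(select_concrete_dtype(0, 1i128 << 40), "i64"); // huge buffer index
    assert_eq!(select_concrete_dtype(-5, 100), "i32");        // negative lower bound still fits
}
```

Using vmin/vmax from bounds analysis (rather than the declared type) is what lets most indices stay in the cheaper 32-bit form.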
Morok: symbolic/index_lowering.rs
Phase 4: Linearizer
Goal: Convert the DAG to a linear instruction sequence.
Stage 16: Post-Index Symbolic
Stage at a Glance
- **Goal:** Full symbolic simplification after index lowering
- **Key Patterns:** All symbolic rules (140+)
- **Impact:** Final cleanup before serialization
What This Does: Full symbolic simplification after index lowering.
Why This Matters: Now that indices are concrete integers (i32/i64), arithmetic can fully simplify. This is the last chance to clean up expressions before linearization.
Pattern: symbolic
Includes GEP pushing patterns, which move address calculations through arithmetic:
Before: GEP(ADD(arr_a, arr_b), idx)
    ↓ [Push GEP through ADD]
After: ADD(GEP(arr_a, idx), GEP(arr_b, idx))
Why? Enables parallel computation of GEPs and may enable downstream vectorization. (Note: the pattern only applies when neither the GEP's dtype nor the ALU's dtype is a pointer.)
Stage 17: Pre-Matcher (Optional)
Stage at a Glance
- **Goal:** Backend-specific patterns before decomposition
- **Key Patterns:** Renderer-specific
- **Impact:** Hardware-specific optimizations
What This Does: Renderer-specific patterns applied before decomposition.
Why This Matters: Each backend can add its own patterns. For example, DSP backends use this to replace generic patterns with DSP-specific SIMD intrinsics. This allows hardware-specific optimizations without changing the generic pipeline.
Pattern: renderer.pre_matcher
Most backends (CPU, GPU) don't need this. Only specialized hardware uses it.
Note: Morok does not currently implement this stage. The Renderer trait has only a decompositor() method. This is a future enhancement for DSP and other specialized backends.
Stage 18: Decompositions
Stage at a Glance
- **Goal:** Rewrite operations the target doesn't support
- **Key Patterns:** Power-of-2, transcendental approximations
- **Impact:** Maps high-level ops to hardware instructions
What This Does: Late rewrites for operations the target doesnβt support.
Why This Matters: Hardware doesn't have every operation. For example, most CPUs don't have a direct sin instruction. We approximate it with operations that do exist (addition, multiplication, etc.).
Pattern: symbolic_simple + get_late_rewrite_patterns
| Pattern | Example | When Used |
|---|---|---|
| MOD → AND | x % 8 → x & 7 | Power-of-2 divisor |
| MUL → SHL | x * 16 → x << 4 | Power-of-2 multiplier |
| DIV → SHR | x // 8 → x >> 3 | Power-of-2 divisor |
| FDIV → MUL | x / 2.0 → x * 0.5 | Float constant divisor |
| NEG | x * -1 → NEG(x) | When NEG supported |
| MULACC | a * b + c → MULACC(a, b, c) | When FMA supported |
| Fast integer division | x // 7 → (x * M) >> S | Non-power-of-2 divisor |
| De Morgan's laws | (!x) & (!y) → !(x \| y) | Boolean simplification |
| Comparison negations | !(x < c) → (c-1) < x | Integer comparisons |
Transcendental function approximations (SIN, EXP, LOG, etc.) are implemented via the decompositor() pathway (see ir/src/decompositions/transcendentals.rs).
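The power-of-2 rows in the table can be checked exhaustively over a small range. Note the sketch restricts itself to non-negative integers: for negative signed values, `>>` and `&` do not match truncating `/` and `%`, which is why the real patterns also consult value bounds:

```rust
// Power-of-2 strength reductions, verified for non-negative integers.
fn main() {
    for x in 0..1_000i64 {
        assert_eq!(x % 8, x & 7);   // MOD → AND
        assert_eq!(x * 16, x << 4); // MUL → SHL
        assert_eq!(x / 8, x >> 3);  // DIV → SHR
        // FDIV → MUL is exact here because 0.5 is a power of two.
        assert_eq!((x as f64) / 2.0, (x as f64) * 0.5);
    }
    // Counterexample showing why sign matters for DIV → SHR:
    assert_eq!(-1i64 / 8, 0);   // truncating division
    assert_eq!(-1i64 >> 3, -1); // arithmetic shift rounds toward -inf
}
```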
Morok: optimizer/mod.rs
Stage 19: Final Rewrite
Stage at a Glance
- **Goal:** Prepare for linearization
- **Key Patterns:** CONST vectorization, GEP resolution, END splitting
- **Impact:** Clean representation ready for linearization
What This Does: Prepare for linearization.
Why This Matters: Some patterns are easier to apply after decomposition. This stage does final cleanup before converting to a linear sequence.
Pattern: pm_decomp + pm_render + extra_matcher + pm_split_ends
CONST vectorization:
// Make vector constants explicit
CONST(1.0) used as vec4 → VECTORIZE(1.0, 1.0, 1.0, 1.0)
CAT to VECTORIZE (via gep_pushing in symbolic):
CAT(a, b, c, d) → VECTORIZE(a, b, c, d)
CAT cannot be rendered directly; explicit VECTORIZE is required for codegen.
GEP resolution: Convert remaining GEP operations.
Split multi-range ENDs:
// Before: END closing multiple ranges
END(op, [range_a, range_b])
// After: nested single ENDs
END(END(op, range_a), range_b)
extra_matcher: Each backend can add its own final patterns. This allows hardware-specific optimizations without changing the generic pipeline.
Morok: devectorize.rs, linearize/mod.rs, optimizer/mod.rs
Stage 20: Add Control Flow
Stage at a Glance
- **Goal:** Build control flow graph and add range dependencies
- **Key Concept:** Three relationship types (nested, dependent, independent)
- **Impact:** Correct instruction ordering
What This Does: Builds the control flow graph and adds range dependencies.
Why This Matters: Operations must execute in a valid order. If a load uses a RANGE's value, the RANGE must come first. This stage tracks and enforces these dependencies.
Pattern: pm_add_control_flow (bottom-up)
// Analyze which END operations depend on which
END(computation, [RANGE_A]) and END(other_computation, [RANGE_B]) are siblings
→ Creates edge: RANGE_B.src += END(computation)
// Add explicit dependency
RANGE_B waits for RANGE_A to complete
Three relationship types:
| Relationship | Example | Meaning |
|---|---|---|
| Nested | RANGE_A inside RANGE_B | A must complete before B starts |
| Dependent | LOAD_A uses RANGE_A | RANGE_A must precede LOAD_A |
| Independent | RANGE_X and RANGE_Y don't interact | Can run in parallel |
Bottom-up traversal ensures dependencies flow correctly from leaves to roots.
Morok: schedule/src/linearize/mod.rs
Stage 21: Linearize
Stage at a Glance
- **Goal:** Convert DAG to linear instruction sequence
- **Key Algorithm:** Priority-aware topological sort
- **Impact:** Valid execution order
What This Does: Converts the DAG to a linear instruction sequence via priority-aware topological sort.
Why This Matters: The graph structure doesn't specify execution order. We need to flatten it while respecting dependencies. Priorities ensure sensible ordering (definitions before uses, loads before computation, stores after).
Function: linearize(sink)
| Operation | Priority | Why |
|---|---|---|
| DEFINE_GLOBAL | -20 | Arguments must be defined first |
| DEFINE_VAR | -19 | Variables must be defined first |
| DEFINE_LOCAL | -18 | Allocations first |
| DEFINE_REG | -17 | Registers first |
| CONST | -10 | Constants early for reuse |
| END | -5 | Closes ranges |
| LOAD | -1 | Loads before use |
| STORE | +1 | Stores after computation |
| RANGE | +5 | Ranges open before use |
Lower priority = earlier in sequence. This ensures:
- Definitions come first
- Loads happen before computation
- Stores happen last
- Ranges open before their contents, close after
Run_count ordering: Operations are sorted primarily by execution frequency (run_count), then by priority. Operations with lower execution frequency (outside inner loops) are scheduled first, while operations in inner loops (higher run_count) are scheduled later. Example: A CONST executed 100 times appears before a CONST executed 1M times.
run_count Calculation:
run_count = prod(int(r.vmax) + 1 for r in u.ranges)
This computes how many times an operation executes based on its enclosing ranges.
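The two-level ordering (run_count first, then priority) can be sketched with a plain sort; the op names and numbers below are illustrative, not taken from a real kernel:

```rust
// Sort by (run_count, priority): operations outside inner loops come
// first; within a frequency class, lower priority numbers come earlier.
fn main() {
    let mut ops = vec![
        ("STORE", 1i32, 100u64),          // inner loop, late
        ("LOAD", -1, 100),                // inner loop, before compute
        ("CONST_outer", -10, 1),          // hoisted: runs once
        ("DEFINE_GLOBAL", -20, 1),        // kernel argument
        ("CONST_inner", -10, 100),        // lives inside the loop
    ];
    ops.sort_by_key(|&(_, priority, run_count)| (run_count, priority));
    let order: Vec<&str> = ops.iter().map(|&(name, _, _)| name).collect();
    assert_eq!(
        order,
        ["DEFINE_GLOBAL", "CONST_outer", "CONST_inner", "LOAD", "STORE"]
    );
}
```

This matches the text above: the once-executed CONST is emitted before the one inside the loop, and within the loop body the LOAD precedes the STORE.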
Morok: schedule/src/linearize/mod.rs
Stage 22: Cleanup IF/ENDIF
Stage at a Glance
- **Goal:** Final cleanup of linear instruction list
- **Key Transformation:** Gated INDEX → IF/STORE/ENDIF
- **Impact:** Handles hardware without predicated stores
What This Does: Final cleanup of the linear instruction list.
Why This Matters: Some hardware (modern GPUs) supports "predicated stores": write to memory only if a condition is true. Older hardware doesn't, so for those targets we wrap the store in an IF statement. This stage ONLY runs when hardware lacks predicated store support.
Pattern: pm_linearize_cleanups (via line_rewrite, not graph_rewrite)
// Gated INDEX in STORE becomes conditional store
STORE(INDEX(ptr, idx, valid=cond), value)
→ IF(cond) { STORE(INDEX(ptr, idx), value) } ENDIF
Note: This stage uses line_rewrite instead of graph_rewrite because it operates on the already-linearized instruction list rather than a DAG.
At this point, the instruction list is ready for code generation.
Morok: schedule/src/linearize/mod.rs (predicated stores path)
Worked Example: Tracing Through All 22 Stages
Let's trace c = a + b (where a and b are [100, 100] tensors) through the pipeline.
Initial Tensor Graph
[ADD]
├── [BUFFER(a)] : Float32
└── [BUFFER(b)] : Float32
After Stage 1: Early Movement Ops
(No change: no movement ops in this example)
After Stage 2: Load Collapse
(No change: no reductions in this example)
After Stage 3: Split Ranges
(No change: no modulo operations)
After Stage 4: Initial Symbolic
(No change: no simplification needed)
After Stage 5: Simplify Ranges
(No change: no adjacent ranges yet)
After Stage 6: Split Store
(Not applicable: GPU backend)
After Stage 7: Apply Opts
Optimization actions applied:
- UPCAST j dimension by 4 (vectorization)
- LOCAL for input buffers (if beneficial)
After Stage 8: Post-Opt Symbolic
No changes: symbolic already clean.
After Stage 9: Expander
UPCAST → UNROLL → CONTRACT:
[VECTORIZE]
└── [ADD]
    ├── [LOAD(a)]
    │   └── [INDEX]
    │       ├── [BUFFER(a)]
    │       ├── [RANGE(i, Global, 0..100)]
    │       └── [UNROLL(VCONST([0,1,2,3]))]  // Converted from RANGE(j, UPCAST)
    └── [LOAD(b)]
        └── [INDEX]
            ├── [BUFFER(b)]
            ├── [RANGE(i)]                   // Same RANGE via hash consing
            └── [UNROLL(VCONST([0,1,2,3]))]  // Same UNROLL via hash consing
After Stage 10: Add Local Buffers
(If LOCAL opt was chosen)
After Stage 11: Remove Reduce
(No changeβno reductions)
After Stage 12: Add GPU Dims
[SPECIAL(gidx0)] : Index // replaces RANGE(i)
After Stage 13: Add Loads
(No change: loads already present)
After Stage 14: Devectorize
Vector split to match hardware width:
[VECTORIZE] : <4 x Float32>
├── [ADD(a[0], b[0])]
├── [ADD(a[1], b[1])]
├── [ADD(a[2], b[2])]
└── [ADD(a[3], b[3])]
After Stage 15: Lower Index Dtype
[SPECIAL(gidx0)] : i32 // concrete type
After Stage 16: Post-Index Symbolic
No changes needed.
After Stage 17: Pre-Matcher
(No patterns for standard backends)
After Stage 18: Decompositions
No decompositions needed: all ops supported.
After Stage 19: Final Rewrite
No changes needed.
After Stage 20: Add Control Flow
Dependencies tracked: no issues.
After Stage 21: Linearize
Linear instruction sequence (simplified):
1. DEFINE_GLOBAL(0) // Output buffer c
2. DEFINE_GLOBAL(1) // Input buffer a
3. DEFINE_GLOBAL(2) // Input buffer b
4. RANGE(i, 0..100, Global) // gidx0
5. RANGE(j, 0..25, Loop) // Unrolled /4
6. LOAD(a, i, j*4+0) // Vector load
7. LOAD(b, i, j*4+0) // Vector load
8. ADD(vec_a, vec_b) // Vector add
9. STORE(c, i, j*4+0, result)
10. END(RANGE(j))
11. END(RANGE(i))
After Stage 22: Cleanup IF/ENDIF
No changes needed: no gated stores.
Result: Ready for code generation! The LLVM/CUDA/other backend will compile this to actual machine code.
Pattern Application Strategy
Each stage uses one of two rewrite strategies:
Top-down (default): Process parents before children. Use when transformations create new matchable subterms.
Bottom-up: Process children before parents. Use when child state affects parent matching (stages 1, 20).
Both iterate to fixpoint: patterns fire until no more match.
Debugging the Pipeline
When a kernel produces wrong results, the bug lives in one of these 22 stages. Use environment variables to extract IR at each stage:
# See IR after each transformation
MOROK_DEBUG=ir cargo test failing_test
Quick Reference
| Symptom | Likely Stages | What to Check |
|---|---|---|
| Wrong values in output | 4, 9, 11, 18 | Symbolic simplification, expansion, reduce removal, decompositions |
| Slow performance | 7, 9, 14, 21 | Optimization, expansion, devectorization, linearization |
| Crashes/panics | 11, 12 | Reduce, GPU dims |
| Wrong loop count | 3, 5, 12 | Split ranges, simplify ranges, GPU dims |
| Missing vectorization | 9, 14 | Expander, devectorize |
Common Issues
- Stage 3-4: Range splitting/symbolic may lose constraints
- Stage 9: Expansion order affects vectorization correctness
- Stage 11: Accumulator initialization must match reduction identity
- Stage 14: Hardware width mismatchβcheck vector fold length
- Stage 18: Missing decompositionβcheck supported_ops list for backend
- Stage 21: Priority bugs cause data racesβverify dependencies
Summary
The 22-stage pipeline transforms tensor expressions into machine code through systematic refinement:
- Stages 1-7: Make iteration explicit, optimize ranges
- Stages 8-10: Expand optimization primitives
- Stages 11-15: Lower to hardware-specific operations
- Stages 16-22: Serialize to executable instructions
Each stage has a single responsibility. Each builds on the last. The result: high-level tensor code runs at near-optimal speed on diverse hardware.
Tinygrad vs Morok: Architectural Differences
This chapter describes the βidealβ 22-stage pipeline based on Tinygradβs implementation. Morok now closely follows this design with minimal differences.
Remaining Architectural Differences
| Stage | Tinygrad | Morok | Notes |
|---|---|---|---|
| 1: Early Movement Ops | Moves movement ops through AFTER/END wrappers | Removes movement ops during bufferization | Both approaches achieve functional equivalence; Morokβs is cleaner |
Aligned Stages (Previously Different)
The following stages previously differed but have since been aligned with Tinygrad:
| Stage | What Changed |
|---|---|
| 15: Index Dtype Lowering | Morok now has pm_lower_index_dtype() with full pattern coverage: Binary ops, CONST, WHERE, VECTORIZE, SPECIAL, DEFINE_VAR, RANGE, CAST cleanup |
| 18: Decompositions | Added: fast_division_patterns(), pm_div_to_shr(), pm_fdiv_to_mul(), pm_comparison_negations(), De Morganβs laws |
| 19: Final Rewrite | pm_render() moved from codegen to Stage 19 in schedule pipeline |
Tinygrad-Only Patterns
Morok intentionally does not implement these Tinygrad-specific patterns:
| Pattern | Purpose | Why Morok Doesnβt Need It |
|---|---|---|
| to_bufferview | Avoid disk buffer copies for DISK/TINYFS devices | Morok doesn't support DISK/TINYFS; in-memory backends don't need this |
| AFTER/END movement patterns | Move movement ops through timing wrappers | Morok removes movement ops during bufferization instead |
Morok Enhancements
Morok has some patterns/enhancements not in Tinygrad:
| Enhancement | Location | Purpose |
|---|---|---|
| Nested INDEX flattening with identical indices | movement_op_patterns() | Removes redundant INDEX(INDEX(ptr, [i]), [i]) |
| CAT β VECTORIZE | pm_render | Converts CAT to explicit VECTORIZE (canβt render CAT directly) |
| PTRCAT([x]) unwrap | pm_render | Removes single-element PTRCAT wrappers |
| GEP through CAST/BITCAST | gep_pushing_patterns() | Pushes GEP through type casts for better optimization |
| Image dtype guard | pm_add_loads() | Skips LOAD wrapping for Image dtype (handled in codegen) |
Glossary
| Term | Simple Definition | Example |
|---|---|---|
| Accumulator | Variable holding running total | acc = acc + value (in reduction) |
| Axis | One dimension of a tensor | Shape [100, 200] has 2 axes |
| AxisType | How a loop executes | Global=parallel, Reduce=accumulate |
| Buffer | Allocated memory holding data | A tensorβs data lives in a buffer |
| Bufferize | Store result in memory instead of computing on-demand | Materialize intermediate value |
| CONTRACT | Combine multiple values into one vector | [a, b, c, d] β vec4(a,b,c,d) |
| Devectorize | Split vectors to match hardware | vec8 β vec4, vec4 |
| Divmod | Division and remainder operations | x // 7, x % 7 |
| Fixpoint | When applying patterns no longer changes anything | Patterns fire until fixpoint |
| GEP | Get Element Pointerβcompute address from indices | arr[i][j] β base + i*stride + j |
| Hash consing | Reuse identical expressions | ADD(x, 0) + ADD(x, 0) shares memory |
| Index | Integer type for array indices | i32 or i64, depending on device |
| Load | Read from memory | value = arr[i] |
| Pattern | Find-and-replace rule for code | ADD(x, 0) β x |
| Predicated store | Write to memory conditionally | Write if valid else skip |
| Range | Loop iteration specification | for i in 0..100 |
| Reduction | Combine many values into one | Sum, max, min |
| Store | Write to memory | arr[i] = value |
| Symbolic | Simplify using algebra rules | (x/4)*4 β x (when x%4=0) |
| Tensor core | Hardware for fast matrix multiply | e.g., NVIDIA Tensor Cores |
| Topological sort | Order nodes respecting dependencies | A before B if B uses Aβs result |
| UNROLL | Expand one op into multiple positions | x β [x_0, x_1, x_2, x_3] |
| UPCAST | Mark intent to vectorize | RANGE(0..4, UPCAST) |
| Vectorize | Process multiple values together | SIMD: add 4 numbers at once |
| WHERE | Conditional selection | WHERE(cond, x, y) = x if cond else y |
One IR to Rule Them All
Youβre debugging a slow model. The profiler says βkernel X takes 200msβ but you have no idea what kernel X actually does. You trace through PyTorchβs dispatcher, then ATen, then TorchInductor, then Triton IR, and finally land in LLVM IR. Five different representations, five different mental models, five different debugging tools.
This is the reality of modern ML compilation. TensorFlowβs XLA has a similar story: Python β Graph β XLA HLO β MLIR β LLVM IR. Each layer was added to solve a real problem, but the accumulated complexity is staggering.
Morok takes a different approach, borrowed from Tinygrad: one IR from tensors to machine code.
ββββββββββββββββββββ βββββββββββββββββββ βββββββββββββββββ
β TensorFlow β β PyTorch β β Morok β
ββββββββββββββββββββ€ βββββββββββββββββββ€ βββββββββββββββββ€
β Python API β β Python API β β Rust/Python β
β TF Graph β β FX Graph β β β β
β XLA HLO β β Inductor IR β β UOp IR β
β MLIR dialects β β Triton IR β β β β
β LLVM IR β β LLVM/PTX β β Machine code β
β Machine code β β Machine code β β β
ββββββββββββββββββββ€ βββββββββββββββββββ€ βββββββββββββββββ€
β 5 IRs β β 4 IRs β β 1 IR β
ββββββββββββββββββββ βββββββββββββββββββ βββββββββββββββββ
The simplest architecture often wins. This chapter explains how one carefully designed IR can replace an entire compiler stack.
UOp: The Universal Node
A UOp (micro-operation) is a node in a computation graph. But unlike nodes in other IRs, a UOp can represent operations at any abstraction levelβfrom high-level tensor reshapes down to individual CPU instructions.
Hereβs the key insight: instead of having separate IRs for βtensor operationsβ and βloop structuresβ and βmemory accessesβ, we put them all in one enum:
#![allow(unused)]
fn main() {
pub enum Op {
// High-level tensor operations
Reshape { src: Arc<UOp>, new_shape: Arc<UOp> },
Permute { src: Arc<UOp>, axes: Vec<usize> },
ReduceAxis { src: Arc<UOp>, reduce_op: ReduceOp, axes: Vec<usize> },
// Loop-level control flow
Range { end: Arc<UOp>, axis_id: AxisId, axis_type: AxisType },
End { computation: Arc<UOp>, ranges: SmallVec<[Arc<UOp>; 4]> },
// Memory operations
Load { buffer: Arc<UOp>, index: Arc<UOp> },
Store { buffer: Arc<UOp>, index: Arc<UOp>, value: Arc<UOp>, ... },
// ALU operations (same as hardware)
Binary(BinaryOp, Arc<UOp>, Arc<UOp>), // Add, Mul, etc.
Unary(UnaryOp, Arc<UOp>), // Sqrt, Exp, etc.
}
}
The enum has ~80 variants organized by abstraction level:
| Category | Examples | What It Represents |
|---|---|---|
| Movement | RESHAPE, PERMUTE, EXPAND, PAD | Tensor shape transformations |
| Reduction | REDUCE_AXIS, REDUCE | Mathematical aggregations |
| Control | RANGE, END, IF, BARRIER | Loop and branch structure |
| Memory | LOAD, STORE, INDEX, BUFFER | Hardware memory access |
| ALU | ADD, MUL, SQRT, EXP, WHERE | CPU/GPU instructions |
| Advanced | WMMA, CONTRACT, UNROLL | Tensor cores, vectorization |
When you print a UOp graph, you see its tree structure:
[42] STORE : Void
βββ [10] DEFINE_GLOBAL(0) : Ptr<Float32>
βββ [35] INDEX : Ptr<Float32>
β βββ [10] β (same as above)
β βββ [30] RANGE(axis=0, Reduce) : Index
β βββ [5] CONST(4) : Index
βββ [40] REDUCE(Add) : Float32
βββ [38] MUL : Float32
β βββ [36] LOAD : Float32
β βββ [37] LOAD : Float32
βββ [30] β (same RANGE as above)
Notice the arrows pointing to βsame as aboveβ? Thatβs not just pretty-printingβitβs a fundamental property called hash consing.
Hash Consing: Structural Sharing
When you create the same expression twice in Morok, you get the same pointer. Not equal valuesβthe same memory address.
#![allow(unused)]
fn main() {
let a = UOp::binary(Add, x.clone(), y.clone());
let b = UOp::binary(Add, x.clone(), y.clone());
assert!(Arc::ptr_eq(&a, &b)); // Same pointer!
}
This works through a global cache. When constructing a UOp, we first check if an identical one exists:
#![allow(unused)]
fn main() {
pub fn new(op: Op, dtype: DType) -> Arc<Self> {
let key = UOpKey::new(&op, dtype);
// Check cache first
if let Some(existing) = CACHE.get(&key) {
return existing;
}
// Create new and cache it
let uop = Arc::new(UOp { op, dtype, ... });
CACHE.insert(key, uop.clone());
uop
}
}
Why does this matter for ML engineers?
- Pointer equality is semantic equality. To check if two subexpressions are identical, just compare pointers: Arc::ptr_eq(&a, &b). No tree traversal needed.
- Pattern matching is O(1). When the optimizer asks "have I seen this pattern before?", pointer comparison gives an instant answer.
- Memory efficiency. Common subexpressions (think: shared computations in attention, gradient graphs) are stored once, not duplicated.
- Thread safety. The same computation from different threads produces the same object, so there are no synchronization bugs.
The tree printout shows this: when you see [10] β (same as above), thatβs not a copyβitβs the same node referenced from multiple places.
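The mechanism can be modeled in a few lines of plain Rust. This is an illustrative sketch, not Morok's actual implementation: the `Expr` and `Interner` types below are hypothetical stand-ins for `UOp` and its global cache.

```rust
use std::collections::HashMap;
use std::rc::Rc;

// Toy expression type standing in for UOp (hypothetical, for illustration).
#[derive(Hash, PartialEq, Eq, Clone)]
enum Expr {
    Const(i64),
    Add(Rc<Expr>, Rc<Expr>),
}

// Minimal interner: structurally identical expressions share one allocation.
struct Interner {
    cache: HashMap<Expr, Rc<Expr>>,
}

impl Interner {
    fn intern(&mut self, e: Expr) -> Rc<Expr> {
        if let Some(existing) = self.cache.get(&e) {
            return existing.clone(); // cache hit: same pointer as before
        }
        let rc = Rc::new(e.clone());
        self.cache.insert(e, rc.clone());
        rc
    }
}

fn main() {
    let mut interner = Interner { cache: HashMap::new() };
    let x = interner.intern(Expr::Const(1));
    let y = interner.intern(Expr::Const(2));
    // Build the "same" expression twice.
    let a = interner.intern(Expr::Add(x.clone(), y.clone()));
    let b = interner.intern(Expr::Add(x, y));
    // Structural equality becomes pointer equality.
    assert!(Rc::ptr_eq(&a, &b));
}
```

Morok's real cache is global and thread-safe (hence `Arc`, not `Rc`), but the cache-then-allocate shape is the same.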
Explicit Loops: The RANGE Operation
Most ML IRs hide loops inside operations. In ONNX, a reduction looks like:
ReduceSum(data, axes=[1], keepdims=0)
Whereβs the loop? Itβs implicitβsomewhere inside the runtimeβs implementation of ReduceSum. You canβt see it, canβt modify it, canβt reason about it.
Morok makes loops explicit using RANGE operations. The same reduction becomes:
[REDUCE(Add)]
βββ [LOAD]
β βββ [INDEX]
β βββ [BUFFER]
β βββ [RANGE(axis=0, Global)] # outer loop (parallelized)
β β βββ [CONST(128)]
β βββ [RANGE(axis=1, Reduce)] # reduction loop
β βββ [CONST(64)]
βββ [RANGE(axis=1, Reduce)] # same RANGE via hash consing
Each RANGE has an AxisType that tells the code generator how to compile it:
| AxisType | CPU | CUDA | Meaning |
|---|---|---|---|
| Global | Thread pool | blockIdx | Outer parallel dimension |
| Local | (N/A) | threadIdx | Workgroup parallelism |
| Loop | for loop | for loop | Sequential iteration |
| Reduce | Accumulator | Warp reduce | Reduction dimension |
| Upcast | SIMD vector | Register tile | Vectorization |
| Unroll | Unrolled | Unrolled | Loop unrolling |
The AxisType hierarchy (Global β Local β Loop β Reduce β Upcast β Unroll) maps directly to GPU programming models. A RANGE with AxisType::Global becomes blockIdx.x in CUDA. A RANGE with AxisType::Local becomes threadIdx.x.
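A minimal sketch of what that mapping could look like inside a code generator. The names here (`AxisType`, `cuda_index_expr`) are hypothetical and do not reflect Morok's actual codegen API:

```rust
// Hypothetical sketch: how an axis type might pick the CUDA index
// expression for a loop axis during code generation.
#[derive(Clone, Copy)]
enum AxisType {
    Global, // parallel across blocks
    Local,  // parallel within a block
    Loop,   // sequential for-loop
    Reduce, // sequential with an accumulator
}

fn cuda_index_expr(axis: AxisType, dim: char) -> String {
    match axis {
        AxisType::Global => format!("blockIdx.{dim}"),
        AxisType::Local => format!("threadIdx.{dim}"),
        // Sequential axes stay as ordinary loop counters.
        AxisType::Loop | AxisType::Reduce => format!("i_{dim}"),
    }
}

fn main() {
    assert_eq!(cuda_index_expr(AxisType::Global, 'x'), "blockIdx.x");
    assert_eq!(cuda_index_expr(AxisType::Local, 'y'), "threadIdx.y");
    assert_eq!(cuda_index_expr(AxisType::Reduce, 'z'), "i_z");
}
```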
Why explicit loops matter:
- Optimization is visible. You can see which loops will be parallelized, which will be unrolled, which will use SIMD.
- Scheduling is graph rewriting. Changing loop order, tiling, or unrolling is just a pattern transformation, with no special "scheduling pass".
- Same IR at every stage. The `RANGE` that represents "iterate over batch dimension" at the tensor level is the same `RANGE` that becomes `for (int i = 0; i < N; i++)` in generated code.
Graph Rewriting: One Transformation Mechanism
Traditional compilers have dozens of specialized passes: constant folding, dead code elimination, loop unrolling, operator fusion. Each pass has custom logic, custom data structures, custom bugs.
Morok uses one mechanism: pattern-based graph rewriting.
#![allow(unused)]
fn main() {
patterns! {
// Identity folding: x + 0 β x
Add[x, @zero] ~> x,
// Constant folding: 3 + 4 β 7
Add(a @const(a_val), b @const(b_val))
=> eval_add(a_val, b_val).map(|r| UOp::const_(a.dtype(), r)),
// Self-folding: x / x β 1
Idiv(x, x) ~> UOp::one(x.dtype()),
// Dead code: if(true) { x } else { y } β x
Where(@true, t, _f) ~> t,
}
}
The DSL is expressive:
- [x, y]: commutative. Try both orderings (for ADD, MUL, etc.)
- (x, y): ordered. Match exactly this order.
- @zero, @one, @true: semantic constants. Work for any dtype.
- @const(val): extract value. For compile-time computation.
- x, x: same operand. Detects pointer equality.
- ~> vs =>: infallible vs fallible rewrite.
The rewrite engine applies patterns bottom-up until no more matches:
Original: Add(Mul(x, 1), 0)
After Mul: Add(x, 0) # Mul(x, 1) β x
After Add: x # Add(x, 0) β x
This single mechanism handles:
- Algebraic simplification β constant folding, identity removal
- Rangeify transformation β movement ops β explicit loops
- Kernel optimization β vectorization, unrolling, tensor cores
- Code generation β lowering to hardware primitives
Same patterns, same engine, different pattern sets for each stage.
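The Add/Mul cascade traced above can be modeled by a toy bottom-up rewriter in plain Rust. This is a conceptual sketch of the fixpoint idea, not Morok's engine; the `E` type and rules are invented for illustration:

```rust
use std::rc::Rc;

// Toy IR for illustration (not Morok's UOp).
#[derive(Clone, PartialEq)]
enum E {
    Var,
    Const(i64),
    Add(Rc<E>, Rc<E>),
    Mul(Rc<E>, Rc<E>),
}

// One rewrite attempt on a single node: identity folding.
fn step(e: &E) -> Option<Rc<E>> {
    match e {
        E::Add(a, b) if **b == E::Const(0) => Some(a.clone()), // x + 0 -> x
        E::Mul(a, b) if **b == E::Const(1) => Some(a.clone()), // x * 1 -> x
        _ => None,
    }
}

// Bottom-up rewrite: children first, then retry the rebuilt parent
// until no pattern fires (the fixpoint).
fn rewrite(e: &Rc<E>) -> Rc<E> {
    let rebuilt = match &**e {
        E::Add(a, b) => Rc::new(E::Add(rewrite(a), rewrite(b))),
        E::Mul(a, b) => Rc::new(E::Mul(rewrite(a), rewrite(b))),
        _ => e.clone(),
    };
    match step(&rebuilt) {
        Some(next) => rewrite(&next),
        None => rebuilt,
    }
}

fn main() {
    // Add(Mul(x, 1), 0) collapses to x, as in the trace above.
    let x = Rc::new(E::Var);
    let expr = Rc::new(E::Add(
        Rc::new(E::Mul(x.clone(), Rc::new(E::Const(1)))),
        Rc::new(E::Const(0)),
    ));
    assert!(matches!(*rewrite(&expr), E::Var));
}
```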
Worked Example: Matmul Journey
Letβs trace C = A @ B (a 4Γ4 matrix multiply) through the entire pipeline.
Stage 1: Tensor Construction
When you write A.matmul(&B), Morok builds a high-level UOp graph:
[REDUCE_AXIS(Add, axes=[2])]
βββ [MUL]
β βββ [EXPAND] # A: [4,4] β [4,4,4]
β β βββ [BUFFER(A)]
β βββ [EXPAND] # B: [4,4] β [4,4,4]
β βββ [PERMUTE] # transpose for broadcasting
β βββ [BUFFER(B)]
This is pure math: βexpand A and B to align dimensions, multiply elementwise, sum along the contracted axis.β
Stage 2: Rangeify
The rangeify pass converts movement ops (EXPAND, PERMUTE) into explicit index computations with RANGE loops:
[STORE]
βββ [DEFINE_GLOBAL(C)]
βββ [INDEX]
β βββ [DEFINE_GLOBAL(C)]
β βββ [RANGE(i, Global)] # i β [0, 4)
β β βββ [CONST(4)]
β βββ [RANGE(j, Global)] # j β [0, 4)
β βββ [CONST(4)]
βββ [REDUCE(Add)]
βββ [MUL]
β βββ [LOAD(A)]
β β βββ [INDEX]
β β βββ [RANGE(i)] # same i (hash consing)
β β βββ [RANGE(k, Reduce)]
β βββ [LOAD(B)]
β βββ [INDEX]
β βββ [RANGE(k)] # same k
β βββ [RANGE(j)] # same j
βββ [RANGE(k, Reduce)] # k β [0, 4)
βββ [CONST(4)]
Now we see the loop structure: i and j are Global (parallelized), k is Reduce (accumulated).
Stage 3: Symbolic Simplification
Pattern rewrites clean up redundant operations, fold constants, and simplify index arithmetic.
Stage 4: Code Generation
The final IR translates directly to loops:
// GPU kernel (conceptual)
__global__ void matmul(float* C, float* A, float* B) {
int i = blockIdx.x; // from RANGE(i, Global)
int j = blockIdx.y; // from RANGE(j, Global)
float acc = 0.0f;
for (int k = 0; k < 4; k++) { // from RANGE(k, Reduce)
acc += A[i*4 + k] * B[k*4 + j];
}
C[i*4 + j] = acc;
}
The key observation: structure is visible at every stage. Thereβs no magic fusion pass that turns three nested loops into something unrecognizable. The RANGE structure you see in Stage 2 is exactly what becomes loops in Stage 4.
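For intuition, here is a hand-written CPU analogue of that kernel in plain Rust. Morok generates this structure through its backend; this version merely mirrors the loop nest by hand:

```rust
// CPU analogue of the generated kernel: i and j are the Global axes,
// k is the Reduce axis with an accumulator.
fn matmul_4x4(a: &[f32; 16], b: &[f32; 16], c: &mut [f32; 16]) {
    for i in 0..4 {                 // RANGE(i, Global)
        for j in 0..4 {             // RANGE(j, Global)
            let mut acc = 0.0f32;   // reduction identity for Add
            for k in 0..4 {         // RANGE(k, Reduce)
                acc += a[i * 4 + k] * b[k * 4 + j];
            }
            c[i * 4 + j] = acc;
        }
    }
}

fn main() {
    // Multiply by the identity matrix: output equals input.
    let a: [f32; 16] = core::array::from_fn(|n| n as f32);
    let mut id = [0.0f32; 16];
    for d in 0..4 {
        id[d * 4 + d] = 1.0;
    }
    let mut c = [0.0f32; 16];
    matmul_4x4(&a, &id, &mut c);
    assert_eq!(c, a);
}
```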
Comparison: How Other IRs Differ
Different IRs make different tradeoffs. Hereβs how they stack up:
| Aspect | ONNX | XLA HLO | Triton | Morok |
|---|---|---|---|---|
| Purpose | Model interchange | Backend optimization | GPU kernel DSL | Full compilation |
| Operators | ~200 high-level | ~100β150 high-level | Tile operations | ~80 multi-level |
| Loop model | Implicit | Implicit | Tile-based | Explicit RANGE |
| Memory | Pure values | Pure values β buffers | Explicit pointers | Explicit LOAD/STORE |
| Optimization | None | Specialized passes | MLIR patterns | Unified rewriting |
| Targets | Runtime engines | CPU/GPU/TPU | GPU only | CPU/GPU |
ONNX maximizes portability. Operations like Conv and MatMul hide all implementation details. Great for model exchange, but you canβt optimize what you canβt see.
XLA HLO is functional and pureβno side effects, immutable tensors. This enables algebraic optimization but requires a separate βbuffer assignmentβ phase before code generation. The transition from HLO to LMHLO (buffer-based) is a fundamental boundary.
Triton exposes more than ONNX but less than Morok. You write βtile-levelβ codeβoperations on blocks of dataβand the compiler handles thread-level details. Explicit memory (tl.load, tl.store) but implicit parallelization within tiles.
Morok exposes everything: loops are explicit (RANGE), memory is explicit (LOAD/STORE), parallelization is explicit (AxisType). This means more to learn, but nothing is hidden.
Why This Matters: Practical Benefits
Morokβs transparent IR has practical benefits for ML engineers:
Debugging is direct. Print the graph at any stage:
#![allow(unused)]
fn main() {
println!("{}", tensor.uop().tree());
}
Youβll see exactly what operations exist, how they connect, and where the computation happens. No βkernel Xβ mysteries.
Performance tuning is informed. See which loops are parallelized:
[RANGE(batch, Global)] # parallelized across GPU blocks
[RANGE(channel, Local)] # parallelized within blocks
[RANGE(pixel, Loop)] # sequential β might be slow!
If something should be parallel but isnβt, you can see it.
The mental model is simple. Thereβs one IR, one transformation mechanism, one set of operations. You donβt need to learn XLA HLO and MLIR and Triton and LLVM. Just UOps.
Optimization is composable. Want a custom rewrite? Add a pattern:
#![allow(unused)]
fn main() {
patterns! {
// Your custom optimization
MyPattern(x, y) ~> better_version(x, y),
}
}
It works with the same engine as constant folding, fusion, and everything else.
The Deeper Insight
Morok/Tinygrad proves that compiler complexity is often accidental, not essential. The multi-layer IR stacks in TensorFlow and PyTorch accumulated organicallyβeach layer solved a real problem, but the combined system is harder to understand than any individual part.
One well-designed IR, one transformation mechanism, and principled composition can replace thousands of lines of specialized passes. Itβs the Unix philosophy applied to compilers: do one thing well, and compose.
The cost is explicitnessβyou see loops, memory accesses, and parallelization hints that other IRs hide. But visibility is a feature, not a bug. When your model is slow, you want to see why, not hope the compiler figures it out.
Thatβs the bet Morok makes: transparent complexity beats hidden complexity.
Pattern-Based Optimization
Open any production ML compiler and youβll find dozens of optimization passes: constant folding, dead code elimination, operator fusion, loop tiling, vectorization, memory layout optimization. Each pass has its own data structures, its own traversal logic, its own bugs.
Morok takes a different approach: one mechanism for everything.
Traditional Compiler: Morok:
βββββββββββββββββββββββββββ βββββββββββββββββββββββββββ
β Constant Folding β β β
β Dead Code Elimination β β patterns! { β
β Loop Unrolling β β Add[x, @zero] ~> xβ
β Operator Fusion β β Mul[x, @zero] ~> 0β
β Vectorization β β // ...more β
β Memory Planning β β } β
β ...20 more passes β β β
βββββββββββββββββββββββββββ β graph_rewrite(...) β
Custom logic each βββββββββββββββββββββββββββ
One mechanism
Every optimization in Morok is expressed as a pattern: βwhen you see this structure, replace it with that structure.β The same graph_rewrite() function applies constant folding, converts movement ops to loops, optimizes memory access patterns, and lowers to hardware primitives.
This chapter explains how pattern-based optimization works and why itβs powerful.
The patterns! DSL
Morok provides a domain-specific language for writing optimization patterns. Hereβs what it looks like:
#![allow(unused)]
fn main() {
patterns! {
// Identity folding: x + 0 β x
Add[x, @zero] ~> |x| x.clone(),
// Constant folding: 3 + 4 β 7
Add(a @const(a_val), b @const(b_val))
=> |a, a_val, b_val| eval_add(a_val, b_val).map(|r| UOp::const_(a.dtype(), r)),
// Self-folding: x / x β 1
Idiv(x, x) ~> |x| UOp::one(x.dtype()),
// Dead code elimination: if(true) { t } else { f } β t
Where(@true, t, _f) ~> |t| t.clone(),
}
}
The macro compiles these patterns into efficient Rust code. Letβs break down the syntax:
| Syntax | Meaning | Example |
|---|---|---|
| (x, y) | Ordered. Match in exact order. | Sub(x, @zero) ~> x |
| [x, y] | Commutative. Try both orderings. | Add[x, @zero] ~> x |
| @zero | Zero constant. Matches 0 or 0.0. | Mul[_, z @ @zero] ~> z |
| @one | One constant. Matches 1 or 1.0. | Mul[x, @one] ~> x |
| @const(val) | Extract constant. Binds the value. | Add(@const(a), @const(b)) |
| x, x | Same operand. Auto-generates ptr_eq check. | Idiv(x, x) ~> UOp::one(...) |
| ~> | Infallible. Always succeeds, returns Arc<UOp>. | Add[x, @zero] ~> x |
| => | Fallible. May fail, returns Option<Arc<UOp>>. | => eval(...).map(...) |
| for op in binary [...] | Template. Generate patterns for multiple ops. | See below |
| @context Type | Stateful. Access mutable context in patterns. | See below |
Template Expansion
Instead of writing the same pattern for every binary operation, use a for-loop:
#![allow(unused)]
fn main() {
patterns! {
for op in binary [Add, Mul, Sub, Idiv, Fdiv, Max] {
op(a @const(a_val), b @const(b_val))
=> |a, a_val, b_val| eval_binary(op, a_val, b_val)
.map(|r| UOp::const_(a.dtype(), r))
}
}
}
This expands to six separate patterns at compile timeβone for each operation.
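Rust's own declarative macros illustrate what this kind of compile-time expansion does. The sketch below is not the patterns! implementation, just an analogy: one macro invocation stamps out one function per operation.

```rust
// Analogy for template expansion: a declarative macro that generates
// one constant-folding function per binary operator at compile time.
macro_rules! fold_binary {
    ($($name:ident => $op:tt),*) => {
        $(fn $name(a: i64, b: i64) -> i64 { a $op b })*
    };
}

// One invocation expands into three separate functions.
fold_binary!(fold_add => +, fold_mul => *, fold_sub => -);

fn main() {
    assert_eq!(fold_add(3, 4), 7);
    assert_eq!(fold_mul(3, 4), 12);
    assert_eq!(fold_sub(3, 4), -1);
}
```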
Stateful Patterns
Some optimizations need context (e.g., which kernel weβre in, what ranges are active). Declare a context type:
#![allow(unused)]
fn main() {
patterns! {
@context KernelContext;
ReduceAxis { src } => |reduce, src, ctx| {
ctx.record_reduction(reduce);
transform_reduce(reduce, src, ctx)
}
}
}
The context is passed as the last argument to pattern closures.
How Pattern Matching Works
The patterns! macro generates a SimplifiedPatternMatcher that dispatches patterns in O(1) time.
The OpKey Index
Every UOp has an operation type (Add, Mul, Load, etc.). The #[derive(PatternEnum)] macro generates an OpKey enum that maps operations to hashable keys:
#![allow(unused)]
fn main() {
pub enum OpKey {
Binary(BinaryOp), // Add, Mul, Sub, ...
Unary(UnaryOp), // Neg, Sqrt, Exp, ...
Ternary(TernaryOp), // Where, MulAcc
Const,
Load,
Store,
// ... one variant per operation category
}
}
The Matcher Structure
#![allow(unused)]
fn main() {
pub struct SimplifiedPatternMatcher<C = ()> {
indexed: HashMap<OpKey, Vec<PatternClosure<C>>>, // O(1) lookup
wildcards: Vec<PatternClosure<C>>, // patterns matching any op
}
}
When matching a UOp:
- Extract OpKey from the UOpβs operation
- Lookup in the HashMapβO(1)
- Try each closure until one matches
- Fall back to wildcards if no indexed pattern matches
This is 5-10x faster than scanning all patterns linearly.
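The dispatch scheme can be modeled self-contained. The `OpKey`, `Pattern`, and `Matcher` types below are toy stand-ins for illustration, not Morok's actual matcher:

```rust
use std::collections::HashMap;

// Toy op key: real code derives this from the UOp's operation.
#[derive(Hash, PartialEq, Eq, Clone, Copy)]
enum OpKey {
    Add,
    Mul,
}

// A pattern here is just a fallible fold over two operands.
type Pattern = fn(i64, i64) -> Option<i64>;

struct Matcher {
    indexed: HashMap<OpKey, Vec<Pattern>>,
}

impl Matcher {
    fn try_match(&self, key: OpKey, a: i64, b: i64) -> Option<i64> {
        // O(1) bucket lookup, then a short linear scan of candidates.
        self.indexed.get(&key)?.iter().find_map(|p| p(a, b))
    }
}

fn main() {
    let mut indexed: HashMap<OpKey, Vec<Pattern>> = HashMap::new();
    // x + 0 -> x
    indexed.entry(OpKey::Add).or_default().push(|a, b| (b == 0).then_some(a));
    // x * 1 -> x
    indexed.entry(OpKey::Mul).or_default().push(|a, b| (b == 1).then_some(a));
    let m = Matcher { indexed };
    assert_eq!(m.try_match(OpKey::Add, 7, 0), Some(7));
    assert_eq!(m.try_match(OpKey::Mul, 7, 1), Some(7));
    assert_eq!(m.try_match(OpKey::Add, 7, 3), None); // no pattern fires
}
```

Only patterns bucketed under the node's key are tried, which is where the speedup over a full linear scan comes from.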
Commutative Handling
For patterns like Add[x, @zero], the macro generates code that tries both orderings:
#![allow(unused)]
fn main() {
// Try (x, @zero)
if let Some(result) = try_match_ordered(&children[0], &children[1]) {
return result;
}
// Try (@zero, x)
if let Some(result) = try_match_ordered(&children[1], &children[0]) {
return result;
}
}
Duplicate Detection
When you write Idiv(x, x), the pattern should only match if both operands are the same UOp (pointer equality, not structural equality). The macro automatically generates this check:
#![allow(unused)]
fn main() {
// Generated code for Idiv(x, x)
let x = &children[0];
let x_dup = &children[1];
if !Arc::ptr_eq(x, x_dup) {
return NoMatch;
}
// ... rest of pattern
}
This leverages hash consingβidentical subexpressions share the same pointer.
The Rewrite Engine: Two-Stage Algorithm
Pattern matching alone isnβt enough. Consider this expression:
WHERE(Lt(3, 5), t, f)
To simplify it, we need two steps:
1. Lt(3, 5) → true (constant folding)
2. WHERE(true, t, f) → t (dead code elimination)
But the WHERE pattern wonβt match until its child is simplified. The rewrite engine solves this with a two-stage algorithm.
Stage 0: Pattern Application
#![allow(unused)]
fn main() {
fn rewrite_stage0(&mut self, uop: &Arc<UOp>) -> RewriteResult {
match self.matcher.try_match(uop) {
Some(replacement) => RewriteResult::Rewritten(replacement),
None => RewriteResult::Gate(uop.clone()), // process children
}
}
}
If no pattern matches, return Gateβa signal to process children first.
Stage 1: Source Reconstruction
After children are rewritten, rebuild the node with new children and try patterns again:
#![allow(unused)]
fn main() {
fn rewrite_stage1(&mut self, uop: &Arc<UOp>, new_children: Vec<Arc<UOp>>) -> Arc<UOp> {
// Rebuild with optimized children
let rebuilt = uop.with_sources(new_children);
// Try patterns againβmight match now!
match self.matcher.try_match(&rebuilt) {
Some(replacement) => replacement,
None => rebuilt,
}
}
}
The Magic: Cascading Optimizations
Stage 0: WHERE(Lt(3, 5), t, f) β Gate (no match, process children)
βββ Lt(3, 5) β true (constant folding matches!)
Stage 1: WHERE(true, t, f) β t (dead code elimination matches!)
The reconstruction stage re-applies patterns, enabling multi-step optimizations in a single traversal.
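The cascade can be made concrete with a toy version of the two-stage algorithm applied to the WHERE example. This is illustrative only (invented `E` type, not Morok's engine): children are rewritten first, the parent is rebuilt, and matching is retried on the rebuilt node.

```rust
use std::rc::Rc;

// Toy IR modeling the WHERE(Lt(3, 5), t, f) example.
#[derive(Clone, PartialEq, Debug)]
enum E {
    Bool(bool),
    Int(i64),
    Lt(Rc<E>, Rc<E>),
    Where(Rc<E>, Rc<E>, Rc<E>),
}

// Stage 0: try patterns on one node.
fn try_match(e: &E) -> Option<Rc<E>> {
    match e {
        // Constant folding: Lt(const, const) -> bool
        E::Lt(a, b) => match (&**a, &**b) {
            (E::Int(x), E::Int(y)) => Some(Rc::new(E::Bool(x < y))),
            _ => None,
        },
        // Dead code: WHERE(const bool, t, f) -> t or f
        E::Where(c, t, f) => match &**c {
            E::Bool(true) => Some(t.clone()),
            E::Bool(false) => Some(f.clone()),
            _ => None,
        },
        _ => None,
    }
}

// Stage 1: rebuild with rewritten children, then retry the parent.
fn rewrite(e: &Rc<E>) -> Rc<E> {
    let rebuilt = match &**e {
        E::Lt(a, b) => Rc::new(E::Lt(rewrite(a), rewrite(b))),
        E::Where(c, t, f) => Rc::new(E::Where(rewrite(c), rewrite(t), rewrite(f))),
        _ => e.clone(),
    };
    match try_match(&rebuilt) {
        Some(r) => rewrite(&r), // cascade: the replacement may match again
        None => rebuilt,
    }
}

fn main() {
    let cond = Rc::new(E::Lt(Rc::new(E::Int(3)), Rc::new(E::Int(5))));
    let expr = Rc::new(E::Where(cond, Rc::new(E::Int(10)), Rc::new(E::Int(20))));
    // Two patterns fire in one traversal: Lt folds, then WHERE selects t.
    assert_eq!(*rewrite(&expr), E::Int(10));
}
```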
Safety Limits
To prevent infinite loops, the engine has limits:
- 1000 iterations per node maximum
- 100,000 iterations total maximum
- Panics with diagnostic info if limits exceeded
In practice, well-formed patterns converge quickly.
The Full Optimization Pipeline
Pattern matching is one part of a larger pipeline. When you call tensor.realize(), hereβs what happens:
Tensor.realize()
β
βΌ
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β RANGEIFY β
β Convert movement ops (RESHAPE, PERMUTE, EXPAND) β
β into explicit RANGE loops with INDEX operations β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β
βΌ
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β KERNEL SPLITTING β
β Split computation graph at STORE boundaries β
β Each STORE becomes a separate kernel β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β
βΌ
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β FOR EACH KERNEL: β
β β
β 1. Symbolic Simplification (algebraic patterns) β
β β
β 2. Scheduler Creation β
β βββ Convert LOOP β GLOBAL for GPU parallelization β
β β
β 3. Kernel Optimization (heuristic OR beam search) β
β βββ Tensor Cores (WMMA) for matmul β
β βββ Vectorization (UPCAST) β
β βββ Loop Unrolling (UNROLL) β
β βββ GPU Local Memory (LOCAL) β
β βββ Grouped Reductions (GROUP) β
β βββ Threading (THREAD) for CPU β
β β
β 4. Post-Optimization Passes β
β βββ Devectorize (memory coalescing) β
β βββ Expand (UNROLL β vector operations) β
β βββ FMA Decomposition (a*b+c β MulAcc) β
β βββ Bool Storage (cast boolβuint8 for memory) β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β
βΌ
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
β CODE GENERATION β
β Render optimized AST to LLVM IR, compile, execute β
βββββββββββββββββββββββββββββββββββββββββββββββββββββββββββ
Each box uses pattern-based rewriting. The difference is which patterns are applied:
- Rangeify: Movement op β BUFFERIZE + INDEX patterns
- Symbolic: Algebraic simplification patterns
- Post-opt: Memory access optimization patterns
Kernel Optimization: Heuristics vs Beam Search
After symbolic simplification, each kernel needs scheduling decisions: how to tile loops, where to parallelize, whether to use tensor cores. Morok offers two strategies.
Heuristics (Default)
The heuristic optimizer applies optimizations in a fixed order:
#![allow(unused)]
fn main() {
pub fn hand_coded_optimizations(scheduler: &mut Scheduler) {
// 1. Tensor cores (if matmul pattern detected)
if let Some(tc) = detect_tensor_core_pattern(scheduler) {
apply_tensor_core(scheduler, tc);
return; // TC handles everything
}
// 2. Grouped reductions (two-stage for large reductions)
apply_grouped_reduction_if_needed(scheduler);
// 3. Vectorization (UPCAST output dimensions)
apply_upcast(scheduler, 4);
// 4. GPU local memory (workgroup dimensions)
apply_local_dims(scheduler);
// 5. CPU threading
apply_threading(scheduler);
}
}
Pros: Fast (~50ms per kernel), predictable, no hardware measurement needed.
Cons: May miss optimization opportunities, fixed heuristics donβt adapt to workload.
Beam Search (Optional)
For production workloads, beam search finds better schedules:
#![allow(unused)]
fn main() {
pub fn beam_search(scheduler: Scheduler, config: BeamConfig) -> Scheduler {
let mut beam = vec![scheduler];
for iteration in 0..config.max_iterations {
let mut candidates = vec![];
for state in &beam {
// Generate all valid next actions
for action in generate_actions(state) {
if let Ok(next) = state.apply(action) {
candidates.push(next);
}
}
}
// Compile and time each candidate (par_iter from rayon)
let timed: Vec<_> = candidates.par_iter()
.map(|c| (c.clone(), measure_kernel_time(c)))
.collect();
// Keep top K by execution time (sorted_by_key from itertools)
beam = timed.into_iter()
.sorted_by_key(|(_, time)| *time)
.take(config.beam_width)
.map(|(c, _)| c)
.collect();
}
beam.into_iter().next().unwrap()
}
}
The action space includes ~500 predefined actions:
- UPCAST(axis, amount): vectorize output dimension
- UNROLL(axis, amount): unroll reduction loop
- LOCAL(axis, amount): use GPU shared memory
- GROUP(axis, amount): two-stage reduction
- THREAD(axis, amount): CPU parallelization
- SWAP(axis1, axis2): reorder global dimensions
Pros: Finds near-optimal schedules, adapts to hardware.
Cons: Minutes per kernel (but results are cached by AST hash).
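Stripped of compilation and timing, the skeleton of the search is small. The sketch below is generic and self-contained, with a toy cost function standing in for measured kernel time (the real optimizer measures candidates on hardware):

```rust
// Generic beam search skeleton. `expand` proposes next schedules;
// `cost` stands in for measured kernel time (a toy model here).
fn beam_search<S: Clone>(
    start: S,
    expand: impl Fn(&S) -> Vec<S>,
    cost: impl Fn(&S) -> u64,
    width: usize,
    iterations: usize,
) -> S {
    let mut beam = vec![start.clone()];
    let mut best = start;
    for _ in 0..iterations {
        let mut candidates: Vec<S> = beam.iter().flat_map(|s| expand(s)).collect();
        if candidates.is_empty() {
            break;
        }
        candidates.sort_by_key(|c| cost(c));
        candidates.truncate(width); // keep top-K by cost
        if cost(&candidates[0]) < cost(&best) {
            best = candidates[0].clone(); // track best-so-far across iterations
        }
        beam = candidates;
    }
    best
}

fn main() {
    // Toy schedule space: unroll factors that double each step, with a
    // hypothetical cost model whose sweet spot is a factor of 8.
    let best = beam_search(1u64, |&f| vec![f * 2], |&f| f.abs_diff(8) + f / 16, 4, 6);
    assert_eq!(best, 8);
}
```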
Configuration
# Disable optimization (debugging)
MOROK_NOOPT=1 cargo run
# Enable beam search with width 8
MOROK_BEAM=8 cargo run
Or programmatically:
#![allow(unused)]
fn main() {
let config = OptimizerConfig::builder()
.strategy(OptStrategy::Beam { width: 8 })
.build();
tensor.realize_with(config)?;
}
Comparison: How Other Compilers Optimize
Different ML compilers take different approaches to optimization:
| Aspect | XLA | TVM/Ansor | Triton | Morok |
|---|---|---|---|---|
| Philosophy | Fixed heuristics | Search-based | Programmer-guided | Pattern-based |
| Fusion | Conservative rules | Tile-and-fuse | Block-level | Graph rewriting |
| Auto-tuning | None | Evolutionary + cost model | Grid search | Beam search |
| Tuning cost | 0 | Hours | Minutes | Minutes (cached) |
| Flexibility | Low | High | Medium | High |
| Transparency | Low (C++ passes) | Medium (Python) | Medium (DSL) | High (patterns!) |
XLA β Production Conservative
XLA uses fixed heuristics for fusion decisions. Safe and predictable, but leaves performance on the table. The fusion rules are hard-coded in C++βextending them requires deep compiler knowledge.
TVM/Ansor β Maximum Auto-Tuning
TVM separates what to compute from how to compute it. Ansor uses evolutionary search with a learned cost model to explore the schedule space. Can achieve best-in-class performance, but tuning takes hours per model.
Triton β Programmer-Guided
Triton exposes a Python-like DSL where you write blocked algorithms explicitly. The compiler handles register allocation and memory management. Good balance of control and automation, but requires GPU programming expertise.
Morok β Pattern Composition
Morokβs insight: express optimizations as composable patterns. Each pattern is local and verifiable. Complex optimizations emerge from composition. Beam search adds auto-tuning when needed, with results cached for reuse.
Why This Matters: Practical Benefits
Pattern-based optimization has concrete advantages for developers:
Debugging is direct. Patterns are readable code. Add a println! to any pattern to trace when it fires:
#![allow(unused)]
fn main() {
patterns! {
Add[x, @zero] ~> |x| {
println!("Folding add-zero: {:?}", x);
x.clone()
}
}
}
Extensibility is easy. Adding a custom optimization is two lines:
#![allow(unused)]
fn main() {
patterns! {
// Your domain-specific optimization
MyOp(x, y) if is_special_case(x, y) ~> transform(x, y)
}
}
No need to understand compiler internals, write visitors, or modify pass managers.
Correctness is local. Each pattern is a small theorem: "if this structure appears, replacing it with that structure preserves semantics." Verify each pattern independently. Composition of correct patterns yields correct programs.
Performance is tunable. O(1) pattern dispatch is fast by default. Enable beam search for production workloads. Cache results by AST hash: tune once, benefit forever.
The Deeper Insight
Pattern matching trades generality for composability.
A general-purpose optimization pass can do anything, but that's exactly the problem. It's hard to verify, hard to extend, hard to compose with other passes. Ordering matters. Interactions are subtle.
A pattern is constrained: it matches a specific structure and produces a specific replacement. But constraints enable composition. Run patterns in any order: the result converges to the same fixed point. Add new patterns without breaking existing ones. Delete patterns without cascading failures.
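A toy sketch of this convergence, using a hypothetical expression type rather than Morok's `UOp`: local rules are applied until none fires, and the result is the same regardless of which subtree the rewriter visits first.

```rust
// Toy illustration of rewriting to a fixed point. `Expr` stands in for
// Morok's UOp graph; the rules mirror add-zero removal and constant folding.
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Const(i64),
    Add(Box<Expr>, Box<Expr>),
}

// One bottom-up rewrite pass: rewrite children first, then try local rules.
fn rewrite(e: Expr) -> Expr {
    match e {
        Expr::Add(a, b) => match (rewrite(*a), rewrite(*b)) {
            (Expr::Const(0), x) | (x, Expr::Const(0)) => x, // x + 0 ~> x
            (Expr::Const(a), Expr::Const(b)) => Expr::Const(a + b), // fold
            (a, b) => Expr::Add(Box::new(a), Box::new(b)),
        },
        leaf => leaf,
    }
}

// Iterate until no rule changes the expression: the fixed point.
fn to_fixed_point(mut e: Expr) -> Expr {
    loop {
        let next = rewrite(e.clone());
        if next == e {
            return e;
        }
        e = next;
    }
}

fn main() {
    use Expr::*;
    // (0 + 2) + 3 converges to 5 no matter how the rules are ordered.
    let e = Add(
        Box::new(Add(Box::new(Const(0)), Box::new(Const(2)))),
        Box::new(Const(3)),
    );
    assert_eq!(to_fixed_point(e), Const(5));
}
```

Each rule is verifiable on its own; the loop only supplies termination and convergence.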
Each pattern is a theorem about semantic equivalence. The rewrite engine is a theorem prover, finding derivations from input to optimized output. Correctness follows from the correctness of individual steps.
This is the Unix philosophy applied to compilers: small, focused tools that compose. Pattern-based optimization won't solve every problem, but for the problems it solves, it solves them elegantly.
Op Bestiary: A Field Guide to UOp Operations
When debugging Morok IR dumps, you'll encounter operations that aren't obvious from their names. This chapter documents non-trivial operations with signatures, field explanations, and examples.
What's covered: Operations that require explanation – loop control, reductions, memory operations, kernel structure, vectorization, tensor cores.
What's NOT covered: Trivial ALU operations (Add, Mul, Sqrt, etc.) that work exactly as you'd expect.
Loop Control: RANGE and END
RANGE – Loop Scope Opener
#![allow(unused)]
fn main() {
Range {
end: Arc<UOp>, // loop bound (exclusive)
axis_id: AxisId, // identifier for deduplication
axis_type: AxisType, // scheduling behavior
}
}
Fields:
| Field | Type | Purpose |
|---|---|---|
| end | Arc<UOp> | Upper bound (exclusive), typically a CONST |
| axis_id | AxisId | Unrenumbered(n) before kernel splitting, Renumbered(n) after |
| axis_type | AxisType | Determines how the loop is scheduled (see below) |
AxisType Hierarchy:
| Type | Priority | GPU Mapping | Purpose |
|---|---|---|---|
| Outer | -2 | (none) | Kernel boundary marker |
| Loop | -1 | for loop | Sequential iteration |
| Global | 0 | blockIdx | Grid parallelism |
| Thread | 0 | thread pool | CPU parallelism |
| Warp | 1 | warp/wavefront | Sub-group parallelism |
| Local | 2 | threadIdx | Workgroup parallelism |
| GroupReduce | 2 | shared memory | Two-stage reduction |
| Upcast | 3 | SIMD | Vectorization |
| Reduce | 4 | accumulator | Reduction dimension |
| Unroll | 5 | unrolled | Loop unrolling |
Priority determines loop nesting order: lower values are outer loops.
Example:
RANGE(end=128, axis_id=R0, type=Global)
└── CONST(128) : Index
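The nesting rule can be sketched in a few lines of plain Rust (illustrative helper, not Morok's scheduler): sorting a kernel's axes by the priority column above yields the loop order.

```rust
// Hypothetical sketch: lower priority = outer loop. A stable sort by the
// priority column reproduces the nesting order (ties keep their order).
fn nesting_order(mut axes: Vec<(&'static str, i32)>) -> Vec<&'static str> {
    axes.sort_by_key(|&(_, priority)| priority);
    axes.into_iter().map(|(name, _)| name).collect()
}

fn main() {
    let axes = vec![("Unroll", 5), ("Global", 0), ("Reduce", 4), ("Local", 2)];
    // Global (grid) is outermost; Unroll is innermost.
    assert_eq!(nesting_order(axes), vec!["Global", "Local", "Reduce", "Unroll"]);
}
```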
END – Loop Scope Closer
#![allow(unused)]
fn main() {
End {
computation: Arc<UOp>, // value computed inside loop
ranges: SmallVec<[Arc<UOp>; 4]>, // ranges being closed
}
}
END closes one or more RANGE scopes and removes them from the active set. Multiple ranges can be closed simultaneously.
Example:
END
├── STORE(...)        ← computation
├── RANGE(R0, Global) ← first range closed
└── RANGE(R1, Local)  ← second range closed
Reduction: REDUCE vs REDUCE_AXIS
Two operations with similar names serve different purposes.
REDUCE_AXIS – Tensor Dimension Reduction (High-Level)
#![allow(unused)]
fn main() {
ReduceAxis {
src: Arc<UOp>, // input tensor
reduce_op: ReduceOp, // Add, Mul, Max, Min
axes: Vec<usize>, // axes to reduce
}
}
Used before rangeify. Operates on tensor dimensions like NumPy's .sum(axis=0).
Example:
REDUCE_AXIS(Add, axes=[1])
└── BUFFER[10, 20] : Float32
This reduces a [10, 20] tensor to [10] by summing along axis 1.
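The semantics can be modeled in plain Rust (illustrative names, not Morok's API): each output element folds one slice of the reduced axis, here for a small [2, 3] buffer.

```rust
// Plain-Rust model of REDUCE_AXIS(Add, axes=[1]) on a row-major
// [rows, cols] buffer: sum each row into one output element.
fn reduce_axis1_add(data: &[f32], rows: usize, cols: usize) -> Vec<f32> {
    (0..rows)
        .map(|r| data[r * cols..(r + 1) * cols].iter().sum())
        .collect()
}

fn main() {
    // A [2, 3] buffer reduced along axis 1 has shape [2].
    let data = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    assert_eq!(reduce_axis1_add(&data, 2, 3), vec![6.0, 15.0]);
}
```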
REDUCE – Range Iteration Reduction (Low-Level)
#![allow(unused)]
fn main() {
Reduce {
src: Arc<UOp>, // value to accumulate
ranges: SmallVec<[Arc<UOp>; 4]>, // ranges being reduced
reduce_op: ReduceOp, // Add, Mul, Max, Min
}
}
Used after rangeify. Accumulates values across RANGE iterations and closes the specified ranges.
ReduceOp Variants:
| Op | Identity | Operation | Tinygrad |
|---|---|---|---|
| Add | 0 | acc + value | ✓ |
| Mul | 1 | acc * value | ✓ |
| Max | -∞ | max(acc, value) | ✓ |
| Min | +∞ | min(acc, value) | Morok-only |
Compatibility: Tinygrad's spec restricts REDUCE_AXIS to {Add, Mul, Max}; Morok extends this with Min.
Example:
REDUCE(Add)
├── MUL ← value to accumulate
│   ├── LOAD(A, ...)
│   └── LOAD(B, ...)
└── RANGE(R2, Reduce) ← range being reduced
    └── CONST(64)
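Each ReduceOp pairs the identity from the table above with a combine function; a fold starting from the identity reproduces each variant. A sketch with illustrative names (not Morok's actual enum):

```rust
// Each reduction starts its accumulator at the identity element from the
// table, then combines one value per RANGE iteration.
#[derive(Clone, Copy)]
enum ReduceOp { Add, Mul, Max, Min }

fn identity(op: ReduceOp) -> f32 {
    match op {
        ReduceOp::Add => 0.0,
        ReduceOp::Mul => 1.0,
        ReduceOp::Max => f32::NEG_INFINITY,
        ReduceOp::Min => f32::INFINITY,
    }
}

fn reduce(op: ReduceOp, values: &[f32]) -> f32 {
    values.iter().fold(identity(op), |acc, &v| match op {
        ReduceOp::Add => acc + v,
        ReduceOp::Mul => acc * v,
        ReduceOp::Max => acc.max(v),
        ReduceOp::Min => acc.min(v),
    })
}

fn main() {
    let xs = [3.0, -1.0, 2.0];
    assert_eq!(reduce(ReduceOp::Add, &xs), 4.0);
    assert_eq!(reduce(ReduceOp::Mul, &xs), -6.0);
    assert_eq!(reduce(ReduceOp::Max, &xs), 3.0);
    assert_eq!(reduce(ReduceOp::Min, &xs), -1.0);
}
```

Starting from the identity is what makes an empty range well-defined: reducing zero iterations yields the identity itself.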
ALLREDUCE – Cross-Device Reduction
#![allow(unused)]
fn main() {
AllReduce {
src: Arc<UOp>, // local partial result
device: Arc<UOp>, // device specification
reduce_op: ReduceOp, // reduction operation
}
}
Performs distributed reduction across multiple devices. Used for multi-GPU training.
Buffer Operations
BUFFER – Buffer Declaration
#![allow(unused)]
fn main() {
Buffer {
unique: Arc<UOp>, // UNIQUE op for identity
device: Arc<UOp>, // DEVICE op
size: usize, // total element count
}
}
Declares a buffer for tensor storage. The unique field ensures distinct buffers even with identical size/device.
BUFFERIZE – Materialization Marker
#![allow(unused)]
fn main() {
Bufferize {
compute: Arc<UOp>, // computation to materialize
ranges: SmallVec<[Arc<UOp>; 4]>, // output dimensions
opts: BufferizeOpts, // address space, device
}
}
Marks where computation should materialize to memory. Triggers kernel splitting.
BufferizeOpts:
| Field | Type | Purpose |
|---|---|---|
| device | Option<DeviceSpec> | Target device, None for local |
| addrspace | AddrSpace | Global (device) or Local (shared) |
Example:
BUFFERIZE(opts={addrspace=Global})
├── REDUCE(Add, ...)  ← computation
├── RANGE(R0, Global) ← output dim 0
└── RANGE(R1, Global) ← output dim 1
INDEX – Multi-Dimensional Buffer Access
#![allow(unused)]
fn main() {
Index {
buffer: Arc<UOp>, // BUFFER or DEFINE_GLOBAL
indices: SmallVec<[Arc<UOp>; 4]>, // index per dimension
gate: Option<Arc<UOp>>, // optional predicate
}
}
Computes memory address from multi-dimensional indices. Returns element dtype (not pointer).
Example:
INDEX : Float32
├── DEFINE_GLOBAL(0)
├── RANGE(R0, Global) ← index for dim 0
├── RANGE(R1, Loop)   ← index for dim 1
└── MUL(...)          ← index for dim 2
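What INDEX computes for a row-major buffer can be modeled in plain Rust (helper name is illustrative): Horner-style flattening of the per-dimension indices against the buffer's shape.

```rust
// Flat offset of a multi-dimensional index into a row-major buffer:
// offset = ((i0 * d1 + i1) * d2 + i2) * ...
fn flat_offset(indices: &[usize], shape: &[usize]) -> usize {
    indices
        .iter()
        .zip(shape)
        .fold(0, |acc, (&i, &dim)| acc * dim + i)
}

fn main() {
    // Element [1, 2] of a [10, 20] buffer lives at offset 1*20 + 2 = 22.
    assert_eq!(flat_offset(&[1, 2], &[10, 20]), 22);
    assert_eq!(flat_offset(&[0, 0, 0], &[4, 5, 6]), 0);
}
```

After linearization this arithmetic is made explicit, which is where POINTER_INDEX (below) takes over.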
POINTER_INDEX – Low-Level Pointer Arithmetic
#![allow(unused)]
fn main() {
PointerIndex {
ptr: Arc<UOp>, // base pointer
offset: Arc<UOp>, // byte offset
}
}
Direct pointer arithmetic. Used after linearization when indices are flattened.
Compatibility: Tinygrad uses INDEX with a ptr=True flag instead of a separate operation.
LOAD – Memory Read
#![allow(unused)]
fn main() {
Load {
buffer: Arc<UOp>, // buffer or pointer
index: Arc<UOp>, // INDEX op
}
}
Read value from buffer at index. For gated loads, use an INDEX with a gate (INDEX has an optional gate field).
Example:
LOAD : Float32
├── DEFINE_GLOBAL(1)
└── INDEX
    ├── DEFINE_GLOBAL(1)
    ├── RANGE(R0)
    └── RANGE(R2)
STORE – Memory Write
#![allow(unused)]
fn main() {
Store {
buffer: Arc<UOp>, // output buffer
index: Arc<UOp>, // INDEX op
value: Arc<UOp>, // value to write
ranges: SmallVec<[Arc<UOp>; 4]>, // ranges being closed
}
}
Write value to buffer. STORE closes the specified ranges, which represent output iteration dimensions. The ranges field is used for output upcasting: when a Range(Upcast) is included, it becomes UNROLL during expansion, then contracted via CONTRACT.
For gated stores, use an INDEX with a gate (INDEX has an optional gate field).
Compatibility: Morok's STORE has an explicit index field (sources: buffer=0, index=1, value=2, ranges=3+). Tinygrad's STORE combines buffer and value differently (range_start=2).
Example:
STORE
├── DEFINE_GLOBAL(0)  ← output buffer
├── INDEX[R0, R1]     ← write address
├── REDUCE(Add, ...)  ← value
├── RANGE(R0, Global) ← output dim 0 (closed)
└── RANGE(R1, Global) ← output dim 1 (closed)
Kernel Structure
KERNEL – Kernel Wrapper
#![allow(unused)]
fn main() {
Kernel {
sources: SmallVec<[Arc<UOp>; 4]>, // arguments
ast: Arc<UOp>, // computation (usually SINK)
}
}
Wraps a complete kernel for code generation. Sources are kernel arguments (DefineGlobal, DefineLocal, DefineVar).
Example:
KERNEL
├── DEFINE_GLOBAL(0) ← output buffer arg
├── DEFINE_GLOBAL(1) ← input A arg
├── DEFINE_GLOBAL(2) ← input B arg
└── SINK             ← computation
    └── STORE(...)
SINK – Multiple Root Collector
#![allow(unused)]
fn main() {
Sink {
sources: SmallVec<[Arc<UOp>; 4]>,
}
}
Collects multiple outputs into a single root. Every kernel's ast is typically a SINK containing STORE operations.
Example:
SINK
├── STORE(output_0, ...)
├── STORE(output_1, ...)
└── STORE(output_2, ...)
AFTER – Dependency Marker
#![allow(unused)]
fn main() {
After {
passthrough: Arc<UOp>, // value that flows through
deps: SmallVec<[Arc<UOp>; 4]>, // operations that must complete
}
}
Expresses execution dependencies between kernels without data dependency. The passthrough value is returned unchanged, but only after all deps complete.
Example:
SINK
├── AFTER
│   ├── DEFINE_GLOBAL(0) ← passthrough (buffer reference)
│   └── KERNEL(...)      ← must complete first
└── KERNEL(...)          ← can use buffer after AFTER
BARRIER – Synchronization Fence
#![allow(unused)]
fn main() {
Barrier {
src: Arc<UOp>, // value passing through
deps: SmallVec<[Arc<UOp>; 4]>, // operations to wait for
}
}
GPU workgroup synchronization. Ensures all threads in a workgroup reach the barrier before continuing.
Vector Operations
VECTORIZE – Create Vector from Scalars
#![allow(unused)]
fn main() {
Vectorize {
elements: SmallVec<[Arc<UOp>; 4]>,
}
}
Combines N scalar values into a vector of size N. All elements must have the same base dtype.
Example:
VECTORIZE : <4 x Float32>
├── CONST(1.0)
├── CONST(2.0)
├── CONST(3.0)
└── CONST(4.0)
GEP – Get Element Pointer (Vector Extract)
#![allow(unused)]
fn main() {
Gep {
vector: Arc<UOp>, // source vector
indices: Vec<usize>, // positions to extract
}
}
Extracts elements from a vector:
- Single index → scalar
- Multiple indices → smaller vector
Example:
GEP([0, 2]) : <2 x Float32>
└── VECTORIZE : <4 x Float32>
    └── ...
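GEP's selection semantics in plain Rust (illustrative helper, not Morok's API): pick the listed positions out of the source vector.

```rust
// Select positions from a vector. A single index yields a length-1 result
// (the scalar case); multiple indices yield a smaller vector.
fn gep(vector: &[f32], indices: &[usize]) -> Vec<f32> {
    indices.iter().map(|&i| vector[i]).collect()
}

fn main() {
    let v = [1.0, 2.0, 3.0, 4.0];
    // GEP([0, 2]) on a <4 x Float32> yields a <2 x Float32>.
    assert_eq!(gep(&v, &[0, 2]), vec![1.0, 3.0]);
    // GEP([3]) yields the scalar 4.0.
    assert_eq!(gep(&v, &[3]), vec![4.0]);
}
```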
VConst – Vector Constant
#![allow(unused)]
fn main() {
VConst {
values: Vec<ConstValue>,
}
}
Vector of compile-time constants. More efficient than VECTORIZE of CONST nodes.
CAT – Concatenate Vectors
#![allow(unused)]
fn main() {
Cat {
sources: SmallVec<[Arc<UOp>; 4]>,
}
}
Concatenates vectors into a larger vector. Output vcount = sum of input vcounts.
Example:
CAT : <8 x Float32>
├── VECTORIZE : <4 x Float32>
└── VECTORIZE : <4 x Float32>
PtrCat – Concatenate Pointers
#![allow(unused)]
fn main() {
PtrCat {
sources: SmallVec<[Arc<UOp>; 4]>,
}
}
Groups memory accesses for vectorized load/store. Used by the devectorizer pass.
Expansion: UNROLL and CONTRACT
UNROLL – Expand Computation Across Iterations
#![allow(unused)]
fn main() {
Unroll {
src: Arc<UOp>, // computation to expand
unroll_axes: Vec<(usize, usize)>, // (axis_index, factor) pairs
}
}
Creates multiple versions of computation for different iteration values. Used for loop unrolling optimization.
Example: UNROLL(unroll_axes=[(0, 4)]) expands computation 4 times with different index values.
CONTRACT – Collapse Unrolled Values to Vector
#![allow(unused)]
fn main() {
Contract {
src: Arc<UOp>, // unrolled computation
upcast_ranges: Vec<(usize, usize)>, // (axis_index, factor) pairs
}
}
The inverse of UNROLL: collects expanded scalar values into a vector. Output vector size = product of factors.
Example:
CONTRACT(upcast_ranges=[(0, 4)]) : <4 x Float32>
└── UNROLL(unroll_axes=[(0, 4)])
    └── LOAD(...)
This pattern vectorizes a load: expand 4 iterations, then pack results into a 4-element vector.
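The pair can be modeled in plain Rust (illustrative helper, not Morok IR): expand one body across the iteration values, then collect the scalars into a vector.

```rust
// UNROLL: evaluate `body` once per iteration value.
// CONTRACT: the collected Vec plays the role of the packed vector.
fn unroll<F: Fn(usize) -> f32>(factor: usize, body: F) -> Vec<f32> {
    (0..factor).map(body).collect()
}

fn main() {
    let data = [10.0, 20.0, 30.0, 40.0];
    // Expand a load across 4 iteration values, then pack into <4 x Float32>.
    let packed = unroll(4, |i| data[i]);
    assert_eq!(packed, vec![10.0, 20.0, 30.0, 40.0]);
}
```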
Tensor Cores: WMMA
WMMA – Warp Matrix Multiply-Accumulate
#![allow(unused)]
fn main() {
Wmma {
a: Arc<UOp>, // matrix A fragment
b: Arc<UOp>, // matrix B fragment
c: Arc<UOp>, // accumulator C fragment
metadata: WmmaMetadata, // hardware configuration
}
}
Hardware tensor core operation: D = A × B + C. Requires specific matrix shapes and data layouts.
WmmaMetadata Fields:
| Field | Type | Purpose |
|---|---|---|
| name | String | Instruction name (e.g., "__hmma...") |
| dims | (N, M, K) | Matrix dimensions (e.g., (16, 16, 16)) |
| dtype_in | DType | Input matrix precision (e.g., Float16) |
| dtype_out | DType | Output precision (e.g., Float32) |
| device | String | Target device string |
| threads | usize | Threads per warp (typically 32) |
| upcast_axes | Vec<(usize, usize)> | Vectorization for output |
| reduce_axes | Vec<(usize, usize)> | Contraction axes |
Example:
WMMA(dims=(16, 16, 16), dtype_in=Float16, dtype_out=Float32)
├── A fragment : <8 x Float16>
├── B fragment : <8 x Float16>
└── C accumulator : <8 x Float32>
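The math WMMA performs, as a plain-Rust reference (illustrative only): real tensor cores use fixed shapes like (16, 16, 16) with warp-distributed fragments, but the semantics are just a multiply-accumulate over row-major matrices.

```rust
// Reference semantics of D = A x B + C for an (n, m, k) tile:
// A is n*k, B is k*m, C and D are n*m, all row-major.
fn wmma(n: usize, m: usize, k: usize, a: &[f32], b: &[f32], c: &[f32]) -> Vec<f32> {
    let mut d = c.to_vec(); // accumulator starts at C
    for i in 0..n {
        for j in 0..m {
            for p in 0..k {
                d[i * m + j] += a[i * k + p] * b[p * m + j];
            }
        }
    }
    d
}

fn main() {
    // 2x2x2 tile with A = identity, so D = B + C.
    let a = [1.0, 0.0, 0.0, 1.0];
    let b = [1.0, 2.0, 3.0, 4.0];
    let c = [10.0, 10.0, 10.0, 10.0];
    assert_eq!(wmma(2, 2, 2, &a, &b, &c), vec![11.0, 12.0, 13.0, 14.0]);
}
```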
Control Flow
IF / ENDIF – Conditional Execution
#![allow(unused)]
fn main() {
If {
condition: Arc<UOp>, // boolean predicate
body: SmallVec<[Arc<UOp>; 4]>, // operations to execute
}
EndIf {
if_op: Arc<UOp>, // corresponding IF op
}
}
Execute body only when condition is true. Used for boundary checks and sparse operations.
Example:
IF
├── LT(idx, bound) ← condition (src[0])
├── STORE(...)     ← body[0]
└── STORE(...)     ← body[1]
ENDIF
└── IF(...)        ← references IF op
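The boundary-check use case, sketched in plain Rust (illustrative function, not generated code): the store runs only when the predicate holds.

```rust
// Semantics of an IF-gated store: the body executes only when the
// LT(idx, bound) predicate is true, so out-of-range threads write nothing.
fn gated_store(buf: &mut [f32], idx: usize, value: f32) {
    if idx < buf.len() {   // IF: LT(idx, bound)
        buf[idx] = value;  // body: STORE
    }                      // ENDIF
}

fn main() {
    let mut buf = [0.0f32; 4];
    gated_store(&mut buf, 2, 7.0);
    gated_store(&mut buf, 9, 9.0); // out of bounds: no write, no panic
    assert_eq!(buf, [0.0, 0.0, 7.0, 0.0]);
}
```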
Definition Operations
DEFINE_GLOBAL – Device Memory Argument
#![allow(unused)]
fn main() {
DefineGlobal(usize) // argument index
}
Kernel argument for device (global) memory. Index refers to position in kernel argument list.
DEFINE_LOCAL – Shared Memory Allocation
#![allow(unused)]
fn main() {
DefineLocal(usize) // local memory index
}
GPU shared memory (LDS) allocation. Visible within a workgroup.
DEFINE_VAR – Symbolic Runtime Variable
#![allow(unused)]
fn main() {
DefineVar {
name: String, // variable name
min_val: i64, // minimum bound
max_val: i64, // maximum bound
}
}
Runtime variable with known bounds. Used for dynamic shapes where bounds are known.
Example:
DEFINE_VAR(name="batch_size", min=1, max=128) : Index
DEFINE_REG – Register Allocation
#![allow(unused)]
fn main() {
DefineReg {
size: usize, // register size
}
}
Allocates a register for intermediate storage. Used in code generation.
BIND – Variable Binding
#![allow(unused)]
fn main() {
Bind {
var: Arc<UOp>, // DEFINE_VAR
value: Arc<UOp>, // concrete value
}
}
Binds a symbolic variable to a concrete value at runtime.
Special Operations
SPECIAL – Hardware-Provided Values
#![allow(unused)]
fn main() {
Special {
end: Arc<UOp>, // upper bound for this dimension
name: String, // e.g., "blockIdx.x", "threadIdx.y"
}
}
Accesses hardware-provided values (thread/block indices). Not a loop: the hardware provides the value directly.
Example:
SPECIAL(name="blockIdx.x", end=128) : Index
└── CONST(128)
UNIQUE – Identity Marker
#![allow(unused)]
fn main() {
Unique(usize) // unique identifier
}
Creates a unique identity for buffer disambiguation. Two buffers with different UNIQUE values are distinct even if otherwise identical.
DEVICE – Device Specification
#![allow(unused)]
fn main() {
Device(DeviceSpec) // device specification
}
Specifies target device for computation.
Movement Operations
High-level tensor shape transformations. These are converted to explicit INDEX operations during rangeify.
| Operation | Signature | Purpose |
|---|---|---|
| Reshape | { src, new_shape } | Change shape, same elements |
| Permute | { src, axes: Vec<usize> } | Transpose/reorder axes |
| Expand | { src, new_shape } | Broadcast to larger shape |
| Pad | { src, begin_pads, end_pads } | Add padding |
| Shrink | { src, begins, ends } | Extract sub-region |
| Flip | { src, axes: Vec<bool> } | Reverse along axes |
Example: RESHAPE
RESHAPE(new_shape=[6, 4]) : Shape[6, 4]
├── BUFFER[2, 3, 4] : Float32
└── CONST([6, 4]) : Shape
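Reshape never moves data; it only requires that old and new shapes describe the same element count, while the flat row-major storage stays put. A minimal sketch of that invariant (illustrative, not Morok's validation code):

```rust
// RESHAPE is legal exactly when the total element counts agree.
fn can_reshape(old: &[usize], new: &[usize]) -> bool {
    old.iter().product::<usize>() == new.iter().product::<usize>()
}

fn main() {
    assert!(can_reshape(&[2, 3, 4], &[6, 4]));  // 24 == 24
    assert!(!can_reshape(&[2, 3, 4], &[5, 5])); // 24 != 25
}
```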
Quick Reference
By Category
| Category | Operations |
|---|---|
| Loop Control | RANGE, END |
| Reduction | REDUCE_AXIS, REDUCE, ALLREDUCE |
| Memory | BUFFER, BUFFERIZE, INDEX, POINTER_INDEX, LOAD, STORE |
| Kernel | KERNEL, SINK, AFTER, BARRIER |
| Vector | VECTORIZE, GEP, VCONST, CAT, PTRCAT |
| Expansion | UNROLL, CONTRACT |
| Hardware | WMMA, SPECIAL |
| Control | IF, ENDIF |
| Definition | DEFINE_GLOBAL, DEFINE_LOCAL, DEFINE_VAR, DEFINE_REG, BIND, UNIQUE, DEVICE |
| Movement | RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, FLIP |
| ALU | Unary(...), Binary(...), Ternary(...), Cast, BitCast |
Range-Ending Operations
Operations that close RANGE scopes (remove ranges from active set):
| Operation | Range Start Index |
|---|---|
| BUFFERIZE | 1 (compute=0, ranges=1+) |
| REDUCE | 1 (src=0, ranges=1+) |
| STORE | 3 (buffer=0, index=1, value=2, ranges=3+) |
| WMMA | 3 (a=0, b=1, c=2) |
| END | 1 (computation=0, ranges=1+) |
Expandable Operations
Operations that propagate UNROLL through the computation graph:
- ALU: Unary, Binary, Ternary
- Type: Cast, BitCast
- Vector: Gep, Vectorize
- Memory: Load, Store, Index, PointerIndex
- Control: Reduce, End, After
- Buffer: Bufferize
- Hardware: Wmma