
Morok

⚠️ Pre-alpha software. APIs are unstable and may change without notice. Not recommended for production use. 🚧💀

Morok is a Rust-based ML compiler inspired by Tinygrad. It features lazy tensor evaluation with UOp-based IR, pattern-driven optimization, and multi-backend code generation.

Highlights

| Feature | Description |
|---|---|
| Declarative Optimization | patterns! DSL for graph rewrites with Z3-verified correctness |
| Lazy Evaluation | Tensors build computation graphs, compiled only at realize() |
| CUDA Support | Unified memory, D2D copy, LRU buffer caching |
| Provenance Tracking | #[track_caller] traces every UOp to source location |
| 80+ IR Operations | Arithmetic, memory, control flow, WMMA tensor cores |
| 20+ Optimizations | Constant folding, tensor cores, vectorization, loop unrolling |

Quick Example

use morok_tensor::Tensor;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Build the lazy computation graph
    let a = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
    let b = Tensor::from_slice(&[4.0, 5.0, 6.0], &[3])?;
    let c = (a + b).sum();

    // Compile and execute
    let result = c.realize()?;
    Ok(())
}

License

MIT

Hands-On: From Tensors to Models

This chapter teaches Morok through progressive examples. You'll start with basic tensor operations and build up to a working neural network classifier.

What you'll learn:

  • Creating and manipulating tensors
  • Shape operations (reshape, transpose, broadcast)
  • Matrix multiplication
  • Building reusable layers
  • Composing a complete model

Prerequisites:

  • Basic Rust knowledge
  • Add morok_tensor to your Cargo.toml

Key pattern: Morok uses lazy evaluation. Operations build a computation graph without executing. Call realize() to compile and run everything at once.


Example 1: Hello Tensor

Let's create tensors, perform operations, and get results.

use morok_tensor::Tensor;

fn main() -> Result<(), Box<dyn std::error::Error>> {
    // Create tensors from slices
    let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);
    let b = Tensor::from_slice(&[10.0f32, 20.0, 30.0, 40.0]);

    // Lazy operations (no execution yet)
    let sum = &a + &b;
    let scaled = &sum * &Tensor::from_slice(&[0.1f32]);

    // Execute and get results
    let result = scaled.realize()?;
    let data = result.to_ndarray::<f32>()?;
    println!("Result: {:?}", data);
    // Output: [1.1, 2.2, 3.3, 4.4]

    Ok(())
}

What's happening:

  1. Tensor::from_slice() creates a tensor from a Rust slice. The f32 suffix tells Rust the element type.

  2. &a + &b doesn't compute anything yet. It returns a new Tensor that represents the addition. The & borrows the tensors so we can reuse them.

  3. realize() is where the magic happens. Morok:

    • Analyzes the computation graph
    • Fuses operations where possible
    • Generates optimized code
    • Executes on the target device
  4. to_ndarray() extracts the result as an ndarray::ArrayD for inspection.

Try this: remove the realize() call. The code still runs, but data would be empty; nothing was computed.


Example 2: Shape Gymnastics

Neural networks constantly reshape data. Let's master the basics.

#![allow(unused)]
fn main() {
fn shape_example() -> Result<(), Box<dyn std::error::Error>> {
    // Create a 1D tensor with 6 elements
    let data = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0]);
    println!("Original shape: {:?}", data.shape());  // [6]

    // Reshape to a 2x3 matrix
    let matrix = data.try_reshape(&[2, 3])?;
    println!("Matrix shape: {:?}", matrix.shape());  // [2, 3]
    // [[1, 2, 3],
    //  [4, 5, 6]]

    // Transpose to 3x2
    let transposed = matrix.try_transpose(0, 1)?;
    println!("Transposed shape: {:?}", transposed.shape());  // [3, 2]
    // [[1, 4],
    //  [2, 5],
    //  [3, 6]]

    // Broadcasting: add a row vector to every row
    // [3, 2] + [1, 2] β†’ [3, 2]
    let bias = Tensor::from_slice(&[100.0f32, 200.0])
        .try_reshape(&[1, 2])?;
    let biased = &transposed + &bias;

    let result = biased.realize()?;
    println!("{:?}", result.to_ndarray::<f32>()?);
    // [[101, 204],
    //  [102, 205],
    //  [103, 206]]

    Ok(())
}
}

Key operations:

| Operation | What it does |
|---|---|
| try_reshape(&[2, 3]) | Change shape (same total elements) |
| try_reshape(&[-1, 3]) | Infer dimension from total size |
| try_transpose(0, 1) | Swap dimensions 0 and 1 |
| try_squeeze(dim) | Remove dimension of size 1 |
| try_unsqueeze(dim) | Add dimension of size 1 |

Broadcasting rules (same as NumPy/PyTorch):

  • Shapes align from the right
  • Each dimension must match or be 1
  • Dimensions of size 1 are "stretched" to match
[3, 2] + [1, 2] → [3, 2]  ✓ (1 broadcasts to 3)
[3, 2] + [2]    → [3, 2]  ✓ (implicit [1, 2])
[3, 2] + [3]    → error   ✗ (2 ≠ 3)
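The three rules can be captured in a short shape-inference helper. This is an illustrative sketch only; broadcast_shape is a hypothetical name, not Morok's API:

```rust
/// Compute the broadcast shape of two shapes, NumPy/PyTorch-style.
/// Shapes align from the right; each pair must match or contain a 1.
fn broadcast_shape(a: &[usize], b: &[usize]) -> Option<Vec<usize>> {
    let n = a.len().max(b.len());
    let mut out = Vec::with_capacity(n);
    for i in 0..n {
        // Missing leading dimensions are treated as size 1.
        let da = if i < n - a.len() { 1 } else { a[i - (n - a.len())] };
        let db = if i < n - b.len() { 1 } else { b[i - (n - b.len())] };
        match (da, db) {
            (x, y) if x == y => out.push(x),
            (1, y) => out.push(y),
            (x, 1) => out.push(x),
            _ => return None, // incompatible dimensions
        }
    }
    Some(out)
}

fn main() {
    assert_eq!(broadcast_shape(&[3, 2], &[1, 2]), Some(vec![3, 2]));
    assert_eq!(broadcast_shape(&[3, 2], &[2]), Some(vec![3, 2]));
    assert_eq!(broadcast_shape(&[3, 2], &[3]), None); // 2 != 3
}
```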

Example 3: Matrix Multiply

Matrix multiplication is the workhorse of neural networks. Every layer uses it.

#![allow(unused)]
fn main() {
fn matmul_example() -> Result<(), Box<dyn std::error::Error>> {
    // Input: 4 samples, 3 features each → shape [4, 3]
    let input = Tensor::from_slice(&[
        1.0f32, 2.0, 3.0,   // sample 0
        4.0, 5.0, 6.0,      // sample 1
        7.0, 8.0, 9.0,      // sample 2
        10.0, 11.0, 12.0,   // sample 3
    ]).try_reshape(&[4, 3])?;

    // Weights: 3 inputs → 2 outputs → shape [3, 2]
    let weights = Tensor::from_slice(&[
        0.1f32, 0.2,  // feature 0 → outputs
        0.3, 0.4,     // feature 1 → outputs
        0.5, 0.6,     // feature 2 → outputs
    ]).try_reshape(&[3, 2])?;

    // Matrix multiply: [4, 3] @ [3, 2] → [4, 2]
    let output = input.dot(&weights)?;

    let result = output.realize()?;
    println!("Output shape: {:?}", result.shape());  // [4, 2]
    println!("{:?}", result.to_ndarray::<f32>()?);
    // Each row: weighted sum of that sample's features

    Ok(())
}
}

Shape rules for dot():

| Left | Right | Result |
|---|---|---|
| [M, K] | [K, N] | [M, N] |
| [K] | [K, N] | [N] (vector-matrix) |
| [M, K] | [K] | [M] (matrix-vector) |
| [B, M, K] | [B, K, N] | [B, M, N] (batched) |

The inner dimensions must match (the K). Think of it as: "for each row of left, dot product with each column of right."
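That intuition is exactly the classic triple loop over row-major buffers. A plain-Rust sketch for the 2-D case (an illustration of the shape rule, not Morok's dot() implementation):

```rust
/// Naive matmul illustrating the shape rule: [m, k] @ [k, n] -> [m, n].
/// Buffers are flat, row-major.
fn matmul(a: &[f32], b: &[f32], m: usize, k: usize, n: usize) -> Vec<f32> {
    let mut out = vec![0.0; m * n];
    for i in 0..m {
        for j in 0..n {
            // Dot product of row i of `a` with column j of `b`.
            for p in 0..k {
                out[i * n + j] += a[i * k + p] * b[p * n + j];
            }
        }
    }
    out
}

fn main() {
    // [2, 3] @ [3, 2] -> [2, 2]
    let a = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0];
    let b = [1.0, 0.0, 0.0, 1.0, 1.0, 1.0];
    assert_eq!(matmul(&a, &b, 2, 3, 2), vec![4.0, 5.0, 10.0, 11.0]);
}
```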


Example 4: Building a Linear Layer

A linear layer computes y = x @ W.T + b. Let's build one from scratch.

#![allow(unused)]
fn main() {
use morok_tensor::{Tensor, Error};

struct Linear {
    weight: Tensor,  // shape: [out_features, in_features]
    bias: Tensor,    // shape: [out_features]
}

impl Linear {
    fn new(in_features: usize, out_features: usize) -> Self {
        // Simple initialization (real code would use proper random init)
        let weight_data: Vec<f32> = (0..in_features * out_features)
            .map(|i| (i as f32 * 0.1).sin() * 0.1)
            .collect();
        let bias_data = vec![0.0f32; out_features];

        Self {
            weight: Tensor::from_slice(&weight_data)
                .try_reshape(&[out_features as isize, in_features as isize])
                .expect("reshape failed"),
            bias: Tensor::from_slice(&bias_data),
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, Error> {
        // y = x @ W.T + b
        let weight_t = self.weight.try_transpose(0, 1)?;
        let out = x.dot(&weight_t)?;
        Ok(&out + &self.bias)
    }
}

fn linear_example() -> Result<(), Box<dyn std::error::Error>> {
    // Create a layer: 4 inputs β†’ 2 outputs
    let layer = Linear::new(4, 2);

    // Single sample with 4 features
    let input = Tensor::from_slice(&[1.0f32, 2.0, 3.0, 4.0]);

    // Forward pass
    let output = layer.forward(&input)?;

    let result = output.realize()?;
    println!("Output: {:?}", result.to_ndarray::<f32>()?);

    Ok(())
}
}

Why transpose the weights?

PyTorch convention stores weights as [out_features, in_features]. For a layer mapping 4 → 2:

  • Weight shape: [2, 4]
  • Input shape: [4] or [batch, 4]
  • We need: input @ weight.T = [batch, 4] @ [4, 2] = [batch, 2]

This convention makes it easy to read the weight matrix: row i contains all weights feeding into output i.


Example 5: MNIST Classifier

Let's build a complete neural network that could classify handwritten digits.

#![allow(unused)]
fn main() {
/// Two-layer neural network for MNIST
/// Architecture: 784 (28×28 pixels) → 128 (hidden) → 10 (digits)
struct MnistNet {
    fc1: Linear,
    fc2: Linear,
}

impl MnistNet {
    fn new() -> Self {
        Self {
            fc1: Linear::new(784, 128),
            fc2: Linear::new(128, 10),
        }
    }

    fn forward(&self, x: &Tensor) -> Result<Tensor, Error> {
        // Layer 1: linear + ReLU activation
        let x = self.fc1.forward(x)?;
        let x = x.relu()?;

        // Layer 2: linear (no activation; raw logits)
        self.fc2.forward(&x)
    }

    fn predict(&self, x: &Tensor) -> Result<Tensor, Error> {
        let logits = self.forward(x)?;
        // Convert logits to probabilities
        logits.softmax(-1)
    }
}

fn mnist_example() -> Result<(), Box<dyn std::error::Error>> {
    let model = MnistNet::new();

    // Simulate a 28×28 grayscale image (flattened to 784)
    let fake_image: Vec<f32> = (0..784)
        .map(|i| (i as f32) / 784.0)
        .collect();
    let input = Tensor::from_slice(&fake_image)
        .try_reshape(&[1, 784])?;  // batch size 1

    // Forward pass
    let logits = model.forward(&input)?;
    let probs = logits.softmax(-1)?;

    // Get results
    let probs_result = probs.realize()?;
    println!("Probabilities: {:?}", probs_result.to_ndarray::<f32>()?);

    // Get predicted class
    let prediction = logits.argmax(Some(-1))?;
    let pred_result = prediction.realize()?;
    println!("Predicted digit: {:?}", pred_result.to_ndarray::<i32>()?);

    Ok(())
}
}

Key concepts:

  1. ReLU activation: x.relu() returns max(0, x). It introduces non-linearity; without it, stacking linear layers would just be one big linear layer.

  2. Logits vs probabilities: The raw output of the last layer (logits) can be any real number. softmax() converts them to probabilities that sum to 1.

  3. argmax: Returns the index of the maximum value, i.e. the predicted class.

  4. Batch dimension: We use shape [1, 784] for a single image. For 32 images, use [32, 784]. The model handles batches automatically.
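To make the logits-to-prediction step concrete, here is a plain-Rust sketch of softmax and argmax over a slice. It mirrors what t.softmax(-1) and argmax do conceptually, but is not Morok code:

```rust
/// Numerically stable softmax over raw logits.
fn softmax(logits: &[f32]) -> Vec<f32> {
    // Subtract the max so exp() never overflows.
    let max = logits.iter().cloned().fold(f32::NEG_INFINITY, f32::max);
    let exps: Vec<f32> = logits.iter().map(|&x| (x - max).exp()).collect();
    let sum: f32 = exps.iter().sum();
    exps.iter().map(|&e| e / sum).collect()
}

/// Index of the largest value: the predicted class.
fn argmax(xs: &[f32]) -> usize {
    xs.iter()
        .enumerate()
        .max_by(|a, b| a.1.partial_cmp(b.1).unwrap())
        .map(|(i, _)| i)
        .unwrap()
}

fn main() {
    let logits = [1.0, 3.0, 0.5];
    let probs = softmax(&logits);
    // Probabilities sum to 1, and softmax preserves the argmax.
    assert!((probs.iter().sum::<f32>() - 1.0).abs() < 1e-6);
    assert_eq!(argmax(&probs), 1);
}
```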


Example 6: Under the Hood

Want to see what Morok generates? Here's how to inspect the IR and generated code.

#![allow(unused)]
fn main() {
fn inspect_compilation() -> Result<(), Box<dyn std::error::Error>> {
    let a = Tensor::from_slice(&[1.0f32, 2.0, 3.0]);
    let b = Tensor::from_slice(&[4.0f32, 5.0, 6.0]);
    let c = &a + &b;

    // Print the computation graph (before compilation)
    println!("=== IR Graph ===");
    println!("{}", c.uop().tree());

    // Compile and execute
    let result = c.realize()?;

    // Inspect generated kernels
    println!("\n=== Generated Kernels ===");
    for (i, kernel) in result.kernels().iter().enumerate() {
        println!("Kernel {}: {}", i, kernel.name);
        println!("Backend: {}", kernel.backend);
        println!("Code:\n{}\n", kernel.code);
    }

    Ok(())
}
}

What you'll see:

  1. IR Graph: The UOp tree shows operations like BUFFER, LOAD, ADD, STORE. This is Morok's intermediate representation before optimization.

  2. Generated Code: The actual LLVM IR or GPU code that runs. Notice how Morok fuses the loads and add into a single kernel; no intermediate buffers needed.

Debugging tip: If something seems slow or wrong, print the IR tree. Look for:

  • Unexpected operations (redundant reshapes, extra copies)
  • Missing fusion (separate kernels where one would do)
  • Shape mismatches (often the root cause of errors)

Summary

You've learned the core patterns for using Morok:

| Task | Code |
|---|---|
| Create tensor | Tensor::from_slice(&[1.0f32, 2.0]) |
| Arithmetic | &a + &b, &a * &b, -&a |
| Reshape | t.try_reshape(&[2, 3])? |
| Transpose | t.try_transpose(0, 1)? |
| Matrix multiply | a.dot(&b)? |
| Activation | t.relu()?, t.softmax(-1)? |
| Execute | t.realize()? |
| Extract data | result.to_ndarray::<f32>()? |

The lazy evaluation pattern:

  1. Build your computation graph with operations
  2. Call realize() once at the end
  3. Morok optimizes and executes everything together


From Tensor to Machine Code

In most ML frameworks, computation happens immediately. Write a + b in PyTorch and it runs now: the GPU crunches numbers before you can even inspect the result. This eager execution is simple to understand, but it leaves optimization opportunities on the table. How can a compiler optimize a computation it hasn't seen yet?

Morok takes the opposite approach: lazy evaluation. When you write a.try_add(&b)?, nothing computes. Morok builds a graph describing what to compute, not when. The magic happens when you call realize(): that single method triggers the entire compilation pipeline, from high-level tensor operations down to JIT-compiled machine code.

This chapter traces that journey.

tensor.realize()
    │
    ▼
┌─────────────────────────────────────────────────────────┐
│  LAZY GRAPH                                             │
│  Tensor ops build UOp DAG (no computation yet)          │
└─────────────────────────────────────────────────────────┘
    │
    ▼
┌─────────────────────────────────────────────────────────┐
│  RANGEIFY                                               │
│  Movement ops → explicit RANGE loops                    │
└─────────────────────────────────────────────────────────┘
    │
    ▼
┌─────────────────────────────────────────────────────────┐
│  KERNEL SPLITTING                                       │
│  Split at STORE boundaries → multiple KERNELs           │
└─────────────────────────────────────────────────────────┘
    │
    ▼
┌─────────────────────────────────────────────────────────┐
│  OPTIMIZATION & CODEGEN                                 │
│  Heuristics/beam → LLVM IR → JIT compile                │
└─────────────────────────────────────────────────────────┘
    │
    ▼
┌─────────────────────────────────────────────────────────┐
│  EXECUTION                                              │
│  Parallel kernel launch → result buffer                 │
└─────────────────────────────────────────────────────────┘

Each box is a distinct phase. Let's walk through them.


Lazy Evaluation: Building the Graph

A Tensor in Morok is surprisingly lightweight:

#![allow(unused)]
fn main() {
pub struct Tensor {
    entry: Arc<TensorEntry>,      // Computation graph
    buffer: Option<Arc<Buffer>>,  // Materialized data (if any)
}
}

The entry holds a TensorEntry containing the UOp graph: the computation this tensor represents. The buffer is optional: lazy tensors don't have one, only realized tensors do.

Three Ways to Create Tensors

1. Input tensors β€” buffer allocated immediately:

let a = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
// `a.buffer` = Some(Arc<Buffer>) with actual data

When you create a tensor from data, Morok allocates device memory and copies your bytes. The UOp graph contains a BUFFER node pointing to this allocation.

2. Lazy operations β€” no buffer, only graph:

let b = a.try_add(&a)?;   // b.buffer = None
let c = b.try_mul(&a)?;   // c.buffer = None

Arithmetic operations don't compute anything. They build a UOp graph: Binary(Add, a.uop, a.uop). The tensor exists purely as a description of future work.

3. Movement operations β€” shares the original buffer:

let d = a.try_reshape(&[1, 3])?;  // d.buffer = same as a.buffer
Reshape, permute, and similar operations create new views of existing data. The buffer is shared; only the UOp graph changes to describe the new indexing.

The Global Registry

Morok maintains three global maps (lock-free, thread-safe):

| Map | Key → Value | Purpose |
|---|---|---|
| TENSORS | tensor_id → Weak<TensorEntry> | Track all tensors for graph substitution |
| BUFFERS | uop_id → Arc<Buffer> | Find buffers during scheduling |
| UOP_TO_TENSOR | uop_id → tensor_id | Secondary index for lookups |

This registry enables a critical feature: global graph substitution. When an optimization transforms a UOp, all tensors referencing that UOp automatically see the updated version. No stale references, no manual updates.

Hash Consing in Action

Because UOps use hash consing (content-based deduplication), identical computations share memory:

let x = a.try_add(&b)?;
let y = a.try_add(&b)?;
// x.uop() and y.uop() point to the SAME Arc<UOp>

This matters for caching: when we compile kernels, we cache by UOp ID. Hash consing means identical computations automatically hit the cache, even if constructed separately.
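A minimal sketch of the interning idea, using a plain HashMap as the table. Morok's real table is lock-free and keyed by UOp contents; the types and names here are illustrative only:

```rust
use std::collections::HashMap;
use std::rc::Rc;

/// Hash-consing sketch: structurally identical expressions are
/// interned once and shared.
#[derive(Clone, PartialEq, Eq, Hash)]
enum Expr {
    Const(i64),
    Add(Rc<Expr>, Rc<Expr>),
}

struct Interner {
    table: HashMap<Expr, Rc<Expr>>,
}

impl Interner {
    /// Return the canonical Rc for `e`, creating it on first sight.
    fn intern(&mut self, e: Expr) -> Rc<Expr> {
        if let Some(existing) = self.table.get(&e) {
            return Rc::clone(existing);
        }
        let rc = Rc::new(e.clone());
        self.table.insert(e, Rc::clone(&rc));
        rc
    }
}

fn main() {
    let mut interner = Interner { table: HashMap::new() };
    let a = interner.intern(Expr::Const(1));
    let b = interner.intern(Expr::Const(2));
    // Built twice, independently...
    let x = interner.intern(Expr::Add(Rc::clone(&a), Rc::clone(&b)));
    let y = interner.intern(Expr::Add(a, b));
    // ...but interning yields the same pointer: a cache key for free.
    assert!(Rc::ptr_eq(&x, &y));
}
```

Because equality of interned nodes is pointer equality, "have I compiled this before?" becomes a trivial map lookup.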


Rangeify: Making Loops Explicit

When you write tensor.reshape([2, 3]).expand([4, 2, 3]).sum(axis=0), those movement operations (reshape, expand) are high-level descriptions. To generate actual loops, we need explicit iteration structure.

Rangeify transforms movement operations into RANGE loops and INDEX arithmetic. The entry point is rangeify() in schedule/src/rangeify/transforms.rs.

The 8-Pass Pipeline

Rangeify isn't a single transformation; it's eight coordinated passes:

| Pass | Purpose |
|---|---|
| 1. Range Assignment | Create RANGE UOps for each tensor dimension |
| 2. Early Rewrites | Remove DETACH, clean up trivial RESHAPE |
| 3. Split Large Reductions | Two-stage reduce for huge arrays (ratio > 32768) |
| 4. Core Rangeify | ReduceAxis → REDUCE, bufferization, movement removal |
| 5. Buffer Folding | Constant propagation through buffer expressions |
| 6. Dead Axis Removal | Filter ranges that don't affect the output |
| 7. Cost-Based Buffer Removal | Inline buffers when profitable (PContig optimization) |
| 8. Reduction Simplification | Lift range-independent code out of reductions |

Each pass uses pattern-based rewriting (see the Pattern-Based Optimization chapter). Patterns fire until no more match, then the next pass begins.

Before and After

Consider this tensor expression:

Before: BUFFER.reshape([2, 3]).expand([4, 2, 3]).sum(axis=0)

After rangeify, movement ops become explicit index computations:

After:
STORE
├── INDEX[RANGE(0..2), RANGE(0..3)]
└── REDUCE(Add)
    ├── LOAD
    │   └── INDEX[RANGE(0..4), RANGE(0..2), RANGE(0..3)]
    └── RANGE(0..4, Reduce)

The EXPAND became a RANGE(0..4) that doesn't affect the buffer index: broadcasting. The RESHAPE became different index arithmetic. The SUM became REDUCE(Add) with the first range marked as Reduce type.

Movement β†’ Index Arithmetic

Each movement operation has a specific transformation:

| Operation | Transformation |
|---|---|
| RESHAPE | Flatten/unflatten index expressions |
| PERMUTE | Reorder dimensions in INDEX |
| EXPAND | Index becomes 0 (or range doesn't affect index) |
| PAD | WHERE(in_bounds, LOAD, pad_value) |
| SHRINK | Offset adjustment in INDEX |
| FLIP | size - 1 - index |

After rangeify, there are no more movement ops: just arithmetic operations on indices.
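Assuming a row-major buffer, a couple of the transformations in the table above can be sketched as plain index functions (hypothetical helper names, for illustration only):

```rust
/// Row-major flat offset for a multi-dimensional index.
fn flat_index(idx: &[usize], shape: &[usize]) -> usize {
    idx.iter().zip(shape).fold(0, |acc, (&i, &d)| acc * d + i)
}

/// FLIP along one axis: index becomes size - 1 - index.
fn flip_index(idx: &[usize], shape: &[usize], axis: usize) -> Vec<usize> {
    let mut out = idx.to_vec();
    out[axis] = shape[axis] - 1 - idx[axis];
    out
}

/// PERMUTE: reorder the index to match the reordered dimensions.
fn permute_index(idx: &[usize], perm: &[usize]) -> Vec<usize> {
    perm.iter().map(|&p| idx[p]).collect()
}

fn main() {
    let shape = [2, 3];
    // FLIP axis 1: element [0, 0] of the flipped view reads source [0, 2].
    assert_eq!(flip_index(&[0, 0], &shape, 1), vec![0, 2]);
    // PERMUTE [1, 0] (transpose): view index [2, 1] reads source [1, 2].
    assert_eq!(permute_index(&[2, 1], &[1, 0]), vec![1, 2]);
    // Flat offset of [1, 2] in a [2, 3] buffer is 1*3 + 2 = 5.
    assert_eq!(flat_index(&[1, 2], &shape), 5);
}
```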


Kernel Splitting: Finding the Boundaries

A computation graph might have multiple outputs, or intermediate values that need materialization. Kernel splitting identifies these boundaries and creates separate kernels.

The entry point is run_kernel_split_pipeline() in schedule/src/rangeify/kernel.rs.

Two-Phase Transformation

Phase 1: BUFFERIZE β†’ STORE

BUFFERIZE nodes mark where values should materialize. Phase 1 converts them to explicit STORE operations:

Before: BUFFERIZE(computation, ranges)
After:  END(STORE(buffer, INDEX(...), computation), ranges)

The END wrapper captures which ranges scope this store. Buffers are allocated and assigned IDs during this phase.

Phase 2: STORE → KERNEL

Each STORE becomes its own kernel:

Before: END(STORE(...), ranges)
After:  KERNEL(SINK(STORE(...)), ranges, buffer_list)

The KERNEL node wraps everything: the computation (as a SINK), the iteration ranges, and the list of buffers this kernel reads and writes.

Tracking Dependencies

When one kernel's output feeds another kernel's input, we need dependency tracking:

  1. fix_assign() maps each buffer_id to the kernel that writes it
  2. When kernel B reads a buffer written by kernel A, B depends on A
  3. resolve_kernel_dependencies() builds the dependency graph

Dependencies appear as AFTER nodes in the IR, ensuring kernels execute in valid order.

Buffer Renumbering

Each kernel sees buffers in a specific order (outputs first, then inputs). renumber_define_globals() remaps buffer IDs to match this ordering:

Original: buffer_3, buffer_1, buffer_7
Kernel view: buffer_0 (output), buffer_1, buffer_2 (inputs)

This simplifies code generation: buffer N is always argument N.
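The remapping above can be sketched in a few lines (renumber is a hypothetical stand-in for renumber_define_globals):

```rust
use std::collections::HashMap;

/// Remap arbitrary global buffer IDs to a dense, kernel-local ordering:
/// outputs first, then inputs, so buffer N is always argument N.
fn renumber(outputs: &[u32], inputs: &[u32]) -> HashMap<u32, u32> {
    outputs
        .iter()
        .chain(inputs.iter())
        .enumerate()
        .map(|(local, &global)| (global, local as u32))
        .collect()
}

fn main() {
    // The example above: buffer_3 is the output, buffers 1 and 7 are inputs.
    let map = renumber(&[3], &[1, 7]);
    assert_eq!(map[&3], 0); // output is argument 0
    assert_eq!(map[&1], 1);
    assert_eq!(map[&7], 2);
}
```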


Schedule Creation: Preparing for Execution

Once kernels are split, we need to schedule them: determine execution order, allocate buffers, and prepare for compilation.

create_schedule() in tensor/src/schedule.rs produces a Vec<ScheduleItem>:

#![allow(unused)]
fn main() {
pub struct ScheduleItem {
    pub kernel: Arc<UOp>,              // KERNEL wrapper
    pub ast: Arc<UOp>,                 // Inner computation (for codegen)
    pub buffers: Vec<Buffer>,          // Device buffers
    pub dependencies: Vec<u64>,        // Producer kernel IDs
    pub fixedvars: HashMap<String, i64>,  // Bound iteration variables
}
}

Buffer Allocation Strategy

  • Input buffers: Already allocated (from Tensor::from_slice)
  • Intermediate buffers: Allocated during scheduling (for kernel outputs that feed other kernels)
  • Output buffer: Allocated and registered with the final tensor

Parallel Group Analysis

Not all kernels need sequential execution. Independent kernels can run in parallel:

Kernel A (writes buf0)
Kernel B (writes buf1)  ─── no dependency ─── can run in parallel
Kernel C (reads buf0, buf1)  ─── depends on A and B

The scheduler uses Kahn's algorithm to find parallel groups:

  1. Build the kernel dependency DAG
  2. Find all kernels with no incoming edges → Group 1
  3. Remove Group 1, repeat → Group 2, etc.

Each group's kernels execute in parallel, then the next group starts.
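The three steps can be sketched as a layered Kahn's algorithm. This is illustrative only; Morok's scheduler does this over KERNEL UOps and buffer IDs rather than plain indices:

```rust
/// Layered topological sort: each returned group contains kernels whose
/// dependencies are already satisfied, so its members can run in parallel.
/// `deps[i]` lists the kernel indices that kernel i depends on.
fn parallel_groups(deps: &[Vec<usize>]) -> Vec<Vec<usize>> {
    let n = deps.len();
    let mut indegree: Vec<usize> = deps.iter().map(|d| d.len()).collect();
    let mut done = vec![false; n];
    let mut groups = Vec::new();
    while done.iter().any(|&d| !d) {
        // Every unscheduled kernel with no unmet dependencies.
        let group: Vec<usize> = (0..n)
            .filter(|&i| !done[i] && indegree[i] == 0)
            .collect();
        assert!(!group.is_empty(), "dependency cycle");
        for &g in &group {
            done[g] = true;
            // Completing g unblocks the kernels that read its output.
            for (i, d) in deps.iter().enumerate() {
                if !done[i] && d.contains(&g) {
                    indegree[i] -= 1;
                }
            }
        }
        groups.push(group);
    }
    groups
}

fn main() {
    // Kernels A and B are independent; C reads both of their outputs.
    let deps = vec![vec![], vec![], vec![0, 1]];
    assert_eq!(parallel_groups(&deps), vec![vec![0, 1], vec![2]]);
}
```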


Code Generation: From UOp to LLVM IR

With kernels scheduled, we generate actual code. Morok currently supports the LLVM backend:

| Backend | Compile Speed | Output Quality | Use Case |
|---|---|---|---|
| LLVM | Slower | Highly optimized | Production |

The Renderer trait abstracts code generation:

#![allow(unused)]
fn main() {
pub trait Renderer {
    fn render(&self, uop: &Arc<UOp>, name: Option<&str>) -> Result<RenderedKernel>;
}
}

LLVM CPU Renderer

The LLVM renderer (codegen/src/llvm/cpu/) traverses the UOp graph and emits LLVM IR:

define void @kernel_0(ptr %args, ptr %vars) {
entry:
  %buf0 = load ptr, ptr %args
  %buf1 = load ptr, ptr getelementptr(ptr, ptr %args, i64 1)
  ; ... loop nest ...
  br label %loop_0

loop_0:
  %i = phi i64 [ 0, %entry ], [ %i.next, %loop_0 ]
  ; ... computation ...
  %i.next = add i64 %i, 1
  %cond = icmp slt i64 %i.next, 128
  br i1 %cond, label %loop_0, label %exit

exit:
  ret void
}

The generated kernel takes two arguments:

  • args: Array of buffer pointers
  • vars: Array of symbolic variable values (for dynamic shapes)

Post-Optimization Passes

Before code generation, 13+ pattern-based passes clean up the IR:

| Pass | Purpose |
|---|---|
| pm_add_loads | Wrap INDEX operations in LOAD |
| pre_expand | Convert UNROLL/UPCAST ranges to explicit operations |
| devectorize | Group contiguous memory accesses |
| pm_reduce_devectorize | Handle vector reductions (K-vec, bool, horizontal) |
| pm_fma_decomposition | Convert a*b+c to fused multiply-add |
| bool_storage_patterns | Convert bool ↔ uint8 for memory operations |

These passes transform the optimized AST into a form suitable for code generation. The result is clean, vectorized code with proper memory access patterns.


Execution: Running the Kernels

Code generation produces LLVM IR strings. Execution involves JIT compilation and kernel launch.

The ExecutionPlan

prepare_execution_plan() builds an ExecutionPlan:

#![allow(unused)]
fn main() {
pub struct ExecutionPlan {
    kernels: Vec<PreparedKernel>,       // Compiled kernels
    parallel_groups: Vec<ParallelGroup>,
    buffers: Vec<Buffer>,
    output_buffer_idx: usize,
}
}

The plan is reusable: compile once, execute many times with different data.

JIT Compilation

The LLVM runtime (runtime/src/llvm.rs) compiles IR to machine code:

  1. Parse the LLVM IR string into a module
  2. Verify the module is well-formed
  3. Optimize with LLVM's O3 pass pipeline
  4. JIT compile to native machine code
  5. Cache by (AST ID, device) for reuse
// Simplified JIT flow
let module = Module::parse_ir(context, ir_string)?;
module.verify()?;
pass_manager.run(&module);  // O3 optimization
let function = execution_engine.get_function::<KernelFn>(&name)?;
// Cache: (ast_id, device) → function

Parallel Execution

With kernels compiled, execution follows the parallel groups:

for group in &plan.parallel_groups {
    if group.kernel_indices.len() == 1 {
        // Single kernel: direct call
        execute_kernel(&kernels[group.kernel_indices[0]]);
    } else {
        // Multiple kernels: parallel execution
        rayon::scope(|s| {
            for &idx in &group.kernel_indices {
                s.spawn(move |_| execute_kernel(&kernels[idx]));
            }
        });
    }
}

Independent kernels run in parallel using Rayon's work-stealing scheduler.

Kernel Caching

Hash consing makes kernel caching highly effective:

  • Key: (UOp ID, device string)
  • Storage: Lock-free HashMap (papaya crate)
  • Hit rate: High, because identical computations share UOp IDs

When you compute the same expression twice, the second call hits the cache: no recompilation.
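A toy version of such a cache, keyed exactly as described. The key and value types are stand-ins (Morok stores compiled kernels in a lock-free papaya map; here a String stands in for machine code):

```rust
use std::collections::HashMap;

/// Kernel cache sketch keyed by (UOp id, device). With hash consing,
/// structurally identical expressions share an id, so repeated work hits
/// the cache without any deep graph comparison.
struct KernelCache {
    compiled: HashMap<(u64, String), String>,
    compiles: usize, // how many times we actually "compiled"
}

impl KernelCache {
    fn get_or_compile(&mut self, uop_id: u64, device: &str) -> &str {
        let key = (uop_id, device.to_string());
        if !self.compiled.contains_key(&key) {
            self.compiles += 1; // cache miss: this is where JIT would run
        }
        self.compiled
            .entry(key)
            .or_insert_with(|| format!("machine_code_for_{uop_id}_{device}"))
    }
}

fn main() {
    let mut cache = KernelCache { compiled: HashMap::new(), compiles: 0 };
    cache.get_or_compile(42, "cpu");
    cache.get_or_compile(42, "cpu"); // same UOp id, same device: cache hit
    assert_eq!(cache.compiles, 1);
}
```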


Worked Example: Matrix Multiply

Let's trace C = A @ B through the entire pipeline. Assume 4×4 matrices.

Stage 1: Lazy Graph Construction

let a = Tensor::from_slice(&a_data, &[4, 4])?;  // Input buffer allocated
let b = Tensor::from_slice(&b_data, &[4, 4])?;  // Input buffer allocated
let c = a.matmul(&b)?;                           // Graph built, no computation

At this point, c is a lazy tensor with this UOp graph:

REDUCE_AXIS(Add, axis=2)
└── MUL
    ├── EXPAND(A, [4, 4, 4])    - A: [4, 4] → [4, 1, 4] → [4, 4, 4]
    └── EXPAND(B, [4, 4, 4])    - B: [4, 4] → [1, 4, 4] → [4, 4, 4]

Stage 2: Rangeify

Movement ops become explicit loops:

STORE
├── BUFFER(C)
├── INDEX[RANGE(i, 0..4), RANGE(j, 0..4)]
└── REDUCE(Add)
    ├── MUL
    │   ├── LOAD(A)
    │   │   └── INDEX[RANGE(i), RANGE(k, 0..4, Reduce)]
    │   └── LOAD(B)
    │       └── INDEX[RANGE(k), RANGE(j)]
    └── RANGE(k, Reduce)

The i and j ranges are output dimensions. The k range is the reduction (contracted) dimension.

Stage 3: Kernel Splitting

Single STORE → single KERNEL:

KERNEL
├── SINK(STORE(...))
├── ranges: [i: 0..4, j: 0..4]
└── buffers: [C (output), A (input), B (input)]

Stage 4: Schedule

One ScheduleItem with:

  • kernel: The KERNEL UOp
  • ast: The inner SINK/STORE
  • buffers: [C, A, B]
  • dependencies: [] (no prior kernels)

Stage 5: Optimization

Heuristic optimizer applies:

  • Vectorization: UPCAST j dimension by 4
  • Loop ordering: Ensure good cache behavior

Stage 6: Code Generation

Generated LLVM IR (simplified):

define void @matmul(ptr %args, ptr %vars) {
entry:
  %C = load ptr, ptr %args
  %A = load ptr, ptr getelementptr(ptr, ptr %args, i64 1)
  %B = load ptr, ptr getelementptr(ptr, ptr %args, i64 2)
  br label %loop_i

loop_i:
  %i = phi i64 [ 0, %entry ], [ %i.next, %loop_i.end ]
  br label %loop_j

loop_j:
  %j = phi i64 [ 0, %loop_i ], [ %j.next, %loop_k.end ]
  %acc = ... ; initialize accumulator
  br label %loop_k

loop_k:
  %k = phi i64 [ 0, %loop_j ], [ %k.next, %loop_k ]
  %a_val = load float, ptr ...  ; A[i, k]
  %b_val = load float, ptr ...  ; B[k, j]
  %prod = fmul float %a_val, %b_val
  %acc.new = fadd float %acc, %prod
  %k.next = add i64 %k, 1
  %k.cond = icmp slt i64 %k.next, 4
  br i1 %k.cond, label %loop_k, label %loop_k.end

loop_k.end:
  store float %acc.new, ptr ...  ; C[i, j]
  ; ... continue j, i loops
}

Stage 7: Execution

  1. JIT compile the LLVM IR
  2. Execute: kernel([C_ptr, A_ptr, B_ptr], [])
  3. Result is in C buffer

Total: one function call, result ready.


Comparison: How Other Frameworks Execute

| Aspect | PyTorch | JAX | TVM | Morok |
|---|---|---|---|---|
| Evaluation | Eager (immediate) | Traced (jit decorator) | Lazy (te.compute) | Lazy (realize) |
| Graph capture | torch.compile | jax.jit trace | Explicit schedule | Implicit via ops |
| Compilation | TorchInductor | XLA backend | Auto-scheduler | Pattern + beam |
| Caching | Per-graph hash | Per-trace | Per-schedule | Per-AST (hash consing) |
| Parallelism | DataParallel/DDP | pmap/pjit | Parallel schedule | Parallel groups |

PyTorch: Eager by default, torch.compile for optimization. TorchInductor generates Triton or C++ code.

JAX: Functional transformations (jit, grad, vmap) trace computations. XLA compiles to optimized kernels.

TVM: Explicit separation of computation and schedule. Auto-scheduler searches for good schedules.

Morok: Fully lazy; nothing executes until realize(). Hash consing provides automatic caching. Pattern-based optimization with optional beam search for production quality.


The Deeper Insight

The pipeline embodies several design principles:

Lazy evaluation enables global optimization. By deferring computation, we see the entire graph before generating code. No local decision limits global optimization.

Explicit loops enable hardware-specific scheduling. Movement ops are convenient abstractions, but GPUs need loops. Rangeify bridges the gap.

Hash consing makes caching automatic. Identical computations share pointers, so cache keys are trivial. No complex graph hashing needed.

Separation of concerns keeps each stage simple. Rangeify doesn’t know about LLVM. Code generation doesn’t know about tensor semantics. Each stage does one thing well.

The result: a compilation pipeline that’s both powerful and maintainable. From tensor.realize() to machine code, every step is visible, debuggable, and extensible.

Path of the UOp: The 22-Stage Codegen Pipeline

A UOp starts as a high-level tensor expression. By the time it reaches the hardware, it has been transformed through 22 distinct stages, each with a specific purpose, each building on the last. This chapter traces that journey.

The pipeline is a proven design for tensor compilation. Understanding it means understanding how tensor expressions become machine code.


How to Read This Chapter

If you're not a compiler engineer, this chapter might seem intimidating. Here's what you need to understand before diving in.

Key Concepts

UOp (Micro-Operation)

  • Think of it as a node in a flowchart representing one computation
  • Example: ADD(a, b) means β€œadd a and b”

Pattern

  • A find-and-replace rule for code structures (not text)
  • Example: β€œIf you see ADD(x, 0), replace with x”
  • Patterns fire repeatedly until no more matches (fixpoint)

Range

  • A loop iteration: RANGE(0..10) means β€œfor i from 0 to 10”

AxisType

  • What kind of loop is this?
    • Global: Parallel across GPU blocks / CPU threads
    • Local: Parallel within a workgroup
    • Reduce: Accumulator (sum, max, etc.)
    • Loop: Sequential iteration

Stage

  • One transformation pass through the code
  • Patterns fire until fixpoint, then move to the next stage

Reading Strategy

  1. First pass: Read just the β€œWhat This Does” and β€œWhy This Matters” sections
  2. Second pass: Look at the diagrams and examples
  3. Third pass (if you want details): Read the pattern descriptions

Questions to Ask

For each stage, ask:

  • What does this stage accomplish? (High-level goal)
  • Why do we need this stage? (Motivation)
  • What would go wrong without it? (Consequences)

Overview

The 22 stages fall into four phases:

Tensor Expression
       β”‚
       β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ RANGEIFY (Stages 1-7)               β”‚
β”‚ Movement ops β†’ Explicit loops       β”‚
β”‚                                     β”‚
β”‚ [Make iteration explicit,           β”‚
β”‚  optimize ranges]                   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
       β”‚
       β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ EXPANDER (Stages 8-10)              β”‚
β”‚ UNROLL/UPCAST β†’ Explicit vectors    β”‚
β”‚                                     β”‚
β”‚ [Expand optimization primitives]    β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
       β”‚
       β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ DEVECTORIZER (Stages 11-15)         β”‚
β”‚ Vector ops β†’ Scalar code            β”‚
β”‚                                     β”‚
β”‚ [Lower to hardware-specific ops]    β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
       β”‚
       β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚ LINEARIZER (Stages 16-22)           β”‚
β”‚ IR β†’ Linear instruction sequence    β”‚
β”‚                                     β”‚
β”‚ [Serialize to executable code]      β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
       β”‚
       β–Ό
  Machine Code

Each stage applies pattern-based rewrites. Patterns fire until fixpoint, then the next stage begins.
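The fixpoint idea above can be sketched on a toy expression tree. This is an illustrative stand-in, not Morok's `patterns!` DSL: two rewrite rules (`x + 0 → x` and constant folding) are applied repeatedly until the tree stops changing.

```rust
// Toy fixpoint rewriter (illustrative, not Morok's IR or patterns! DSL).
#[derive(Clone, Debug, PartialEq)]
enum Expr {
    Const(i64),
    Var(&'static str),
    Add(Box<Expr>, Box<Expr>),
}

// One bottom-up rewrite pass: fold constants, drop additive identity.
fn rewrite(e: &Expr) -> Expr {
    if let Expr::Add(a, b) = e {
        let a = rewrite(a);
        let b = rewrite(b);
        if a == Expr::Const(0) {
            return b; // 0 + x -> x
        }
        if b == Expr::Const(0) {
            return a; // x + 0 -> x
        }
        if let (Expr::Const(x), Expr::Const(y)) = (&a, &b) {
            return Expr::Const(x + y); // constant folding
        }
        return Expr::Add(Box::new(a), Box::new(b));
    }
    e.clone()
}

// Run passes until nothing changes (fixpoint), as each stage does.
fn to_fixpoint(mut e: Expr) -> Expr {
    loop {
        let next = rewrite(&e);
        if next == e {
            return e;
        }
        e = next;
    }
}

fn main() {
    // (x + 0) + (2 + 3)  ->  x + 5
    let e = Expr::Add(
        Box::new(Expr::Add(Box::new(Expr::Var("x")), Box::new(Expr::Const(0)))),
        Box::new(Expr::Add(Box::new(Expr::Const(2)), Box::new(Expr::Const(3)))),
    );
    assert_eq!(
        to_fixpoint(e),
        Expr::Add(Box::new(Expr::Var("x")), Box::new(Expr::Const(5)))
    );
    println!("fixpoint reached");
}
```

The same loop structure applies at every stage: run the stage's pattern set to fixpoint, then hand the graph to the next stage.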


Phase 1: Rangeify

Goal: Transform high-level movement operations into explicit loop structures and optimize ranges.


Stage 1: Early Movement Ops

Stage at a Glance

  • Goal: Clean up movement operations before range assignment
  • Key Patterns: Movement on INDEX, movement through wrappers, nested INDEX simplification
  • Impact: Prevents missed optimizations later in the pipeline

What This Does: This stage cleans up movement operations by pushing index manipulations into places where they’re actually needed. Think of it as organizing your desk before filing papersβ€”move instructions closer to where the data is used.

Why This Matters: Movement operations (RESHAPE, PERMUTE, etc.) are convenient abstractions, but the hardware needs concrete index calculations. By cleaning them up early, we ensure patterns in later stages can match correctly.

Pattern: pm_mops + pm_syntactic_sugar (bottom-up)

| Pattern | Transformation | Visual | Location |
|---|---|---|---|
| Movement on INDEX | Apply movement to index expressions | INDEX(PERMUTE(arr), [i, j]) → INDEX(arr, [j, i]) | movement_op_patterns() |
| Movement through AFTER | Move RESHAPE through timing wrapper (Tinygrad-specific) | AFTER(RESHAPE(x, arg), [dep1, dep2]) → RESHAPE(AFTER(x, [dep2]), arg) | Tinygrad only |
| Movement through END | Unwrap movement from END wrapper (Tinygrad-specific) | END(RESHAPE(x), ranges) → END(x, ranges) | Tinygrad only |
| Nested INDEX simplification | Remove redundant nested INDEX (Morok) | INDEX(INDEX(ptr, [i]), [i]) → INDEX(ptr, [i]) | movement_op_patterns() |
| Nested INDEX concat | Flatten nested INDEX for PtrDType | INDEX(INDEX(ptr, i), j) → INDEX(ptr, i, j) | pm_syntactic_sugar |

Why bottom-up? Child nodes must be clean before parents can match. Movement ops nest deeply; cleaning from bottom prevents missed patterns.

Note: Tinygrad and Morok have different approaches here. Tinygrad moves movement ops through wrappers (AFTER, END) because it re-applies movement ops during bufferization. Morok removes movement ops entirely by transforming indices during bufferization, so AFTER/END patterns are not needed.

Morok: movement_op_patterns() in rangeify/patterns.rs


Stage 2: Load Collapse

Stage at a Glance

  • Goal: Eliminate REDUCE operations by detecting range-independent computation
  • Key Patterns: Bounded sum, gated load collapse, general reduce elimination
  • Impact: Converts loop iterations to arithmetic operations

What This Does: Eliminates REDUCE operations by recognizing when the computation can be done without iteration. Uses range-independent computation detection and symbolic simplification.

Why This Matters: Reducing iterations to arithmetic operations eliminates loop overhead. Instead of running a loop 1000 times, compute the answer directly.

Pattern: pm_load_collapse

// Before: Sum with bounds check
sum(1 for k in 0..64 if k >= length)

// After: Compute count directly (NO LOOP!)
count = clamp(64 - length, 0, 64)

The mechanism works by:

  1. Identifying subexpressions that don’t depend on the REDUCE range
  2. Creating DEFINE_VAR for those subexpressions (treats as loop-invariant)
  3. Substituting the range with DEFINE_VAR and running symbolic simplification
  4. If the simplified expression has no more ranges, the REDUCE is eliminated
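The effect of this collapse can be checked with plain integers. The sketch below is not Morok's IR; it just verifies that the bounded-sum example above and its closed form agree.

```rust
// Load-collapse sketch: counting k in 0..64 with k >= length
// equals clamp(64 - length, 0, 64), with no loop at all.
fn count_with_loop(length: i64) -> i64 {
    (0..64).filter(|&k| k >= length).count() as i64
}

fn count_collapsed(length: i64) -> i64 {
    (64 - length).clamp(0, 64)
}

fn main() {
    for length in 0..=64 {
        assert_eq!(count_with_loop(length), count_collapsed(length));
    }
    println!("load collapse holds for length in 0..=64");
}
```

The compiler performs the same replacement symbolically: once the REDUCE body simplifies to an expression with no remaining ranges, the loop disappears.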

Note: WHERE movement through INDEX (pm_move_where_on_load in Stage 8) is a separate optimization that places conditionals before loads to skip memory accesses, but it doesn’t eliminate REDUCE operations.

Morok: pm_load_collapse() in rangeify/patterns.rs


Stage 3: Split Ranges

Stage at a Glance

  • Goal: Enable better optimization through divmod decomposition
  • Key Patterns: Split ranges with modulo, flatten ranges
  • Impact: Inner ranges can vectorize, outer can parallelize

What This Does: Handles modulo patterns by splitting a range into outer and inner components.

Why This Matters: Splitting ranges is like dividing a large task among team members. If you have 12 items and each person does 4, you get 3 people Γ— 4 items. Inner loops (one person’s 4 items) can be fast; outer loops (3 people) can run in parallel.

Pattern: pm_split_ranges + pm_flatten_range

Before:  RANGE(end=12) % 4  // One loop with modulo (slow)
             ↓ [Split into outer Γ— inner]
After:   RANGE(end=3) * 4 + RANGE(end=4)
            ↑outer        ↑inner
            Parallel      Sequential

This enables:

  • Inner ranges can vectorize (SIMD)
  • Outer ranges can parallelize (GPU blocks / CPU threads)

pm_flatten_range merges nested ranges on REDUCE/STORE/END when beneficial.

Context: Requires dictionary context (ctx={}) to track substitutions at SINK.

Note: The split only applies when end % mod == 0 (divisibility check).
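A quick sanity check of the decomposition, in plain Rust rather than Morok's IR: the split loops visit exactly the indices of the original range, and the modulo collapses to the inner counter.

```rust
// RANGE(end=12) with `% 4` vs RANGE(end=3) x RANGE(end=4).
// Valid because 12 % 4 == 0 (the divisibility check above).
fn flat_indices() -> Vec<(i64, i64)> {
    (0..12).map(|i| (i, i % 4)).collect()
}

fn split_indices() -> Vec<(i64, i64)> {
    (0..3)
        .flat_map(|outer| (0..4).map(move |inner| (outer * 4 + inner, inner)))
        .collect()
}

fn main() {
    // Same indices, and `i % 4` is now just the inner counter.
    assert_eq!(flat_indices(), split_indices());
    println!("split covers the same (index, index % 4) pairs");
}
```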

Morok: pm_split_ranges() + pm_flatten_range() in rangeify/transforms.rs


Stage 4: Initial Symbolic

Stage at a Glance

  • Goal: Simplify expressions using algebra rules
  • Key Patterns: Constant folding, identity removal, div-mod recombine
  • Impact: Eliminates expensive operations, reduces code size

What This Does: Applies 100+ constant folding and algebraic simplification rules.

Why This Matters: Computers are fast at simple math. Dividing and taking remainders are slow operations. This stage uses algebra rules to eliminate slow operations whenever possible.

Pattern: sym + pm_flatten_range

Constant folding:

ADD(CONST(2), CONST(3)) β†’ CONST(5)
MUL(x, CONST(1)) β†’ x
ADD(x, CONST(0)) β†’ x

Div-mod recombination:

(x / c) * c + (x % c) β†’ x

Why? The left-hand side computes the same value as x, but spends three operations (divide, multiply, add) doing it. This pattern finds and removes the redundancy (common in stride calculations).
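The identity is easy to verify mechanically. This sketch uses Rust's truncating integer division and remainder, under which the recombination holds by definition:

```rust
// (x / c) * c + (x % c) == x for any nonzero c, so the three-op
// form can be rewritten to just x.
fn recombined(x: i64, c: i64) -> i64 {
    (x / c) * c + (x % c)
}

fn main() {
    for x in -100..=100 {
        for c in 1..=16 {
            assert_eq!(recombined(x, c), x);
        }
    }
    println!("div-mod recombination holds");
}
```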

Boolean algebra:

x AND x β†’ x
x OR FALSE β†’ x
NOT(NOT(x)) β†’ x

Additional categories:

  • Identity removal (self-folding, redundant operations)
  • Comparison simplification
  • Cast optimization
  • GEP pushing (move address calculations through ALUs)
  • Where folding (combine WHERE with same conditions)
  • Reduce mul chain (move multiplications outside reduce)

Morok: symbolic_patterns() in symbolic/patterns.rs


Stage 5: Simplify Ranges

Stage at a Glance

  • Goal: Merge adjacent ranges to reduce loop overhead
  • Key Patterns: Range merging with cost analysis
  • Impact: Fewer loops = less overhead

What This Does: Merges adjacent ranges when profitable.

Why This Matters: Merging ranges is like combining multiple small trips into one big one. Instead of going to the store 4 times for 4 items, go once for all 4 items. Saves the overhead of starting and stopping.

Pattern: pm_simplify_ranges

// Before: two separate ranges
RANGE(0..4), RANGE(0..8)

// After: merged (if compatible)
RANGE(0..32)

Merge criteria:

  1. Axis types must be compatible (both output, both reduce, etc.)
  2. REDUCE scope must remain consistent
  3. Cost-based: Accept only if divmod operation count does not increase

The compiler only merges if it saves operations. Merging might require division/modulo to recalculate indices. If that costs more than it saves, merge is skipped.
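The cost trade-off is visible in a plain-Rust sketch: the merged loop saves one loop's worth of overhead but reintroduces a divmod to reconstruct the original indices, which is exactly what criterion 3 weighs.

```rust
// A 4 x 8 nested loop vs one 32-iteration loop that reconstructs
// (i, j) via divmod. Same iteration space, different loop overhead.
fn nested() -> Vec<(i64, i64)> {
    let mut out = Vec::new();
    for i in 0..4 {
        for j in 0..8 {
            out.push((i, j));
        }
    }
    out
}

fn merged() -> Vec<(i64, i64)> {
    (0..32).map(|k| (k / 8, k % 8)).collect()
}

fn main() {
    assert_eq!(nested(), merged());
    println!("merged range visits an identical iteration space");
}
```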

Morok: simplify_merge_adjacent() in rangeify/transforms.rs


Stage 6: Split Store (CPU-only)

Stage at a Glance

  • Goal: Avoid branch misprediction by splitting conditional stores
  • Key Patterns: Split store ranges at comparison boundaries
  • Impact: More predictable CPU execution

What This Does: Splits store ranges at conditional boundaries when there are CMPLT(range, const) comparisons in the store’s consumer map.

Why This Matters: Branch misprediction slows down CPUs. Instead of one loop with an if statement that the CPU can’t predict, we create two loops without conditionals. Each loop does predictable work, so the CPU stays fast.

Pattern: pm_split_store

// Before: Store with conditional (branch misprediction risk)
for i in 0..100:
    if i < 50:
        output[i] = data[i]

// After: Split at the comparison boundary (predictable)
for i in 0..50:   // Predicate always true: unconditional store
    output[i] = data[i]
// i in 50..100: predicate always false, so this segment
// contains no store and is eliminated

The transformation finds constant comparison points in the store’s consumer map and creates disjoint ranges for each segment.

Skipped for GPU devicesβ€”they handle conditionals differently.

Morok: pm_split_store() in rangeify/transforms.rs


Stage 7: Apply Opts

Stage at a Glance

  • Goal: Find optimal combination of vectorization, unrolling, memory usage
  • Key Algorithm: Beam search or heuristics
  • Impact: Can significantly improve performance

What This Does: The optimization searchβ€”either beam search or heuristicβ€”explores different combinations of optimization actions.

Why This Matters: The compiler tries different combinations of optimizations (vectorize here? unroll there?) and picks the fastest. Finding the right combination can make code 10x faster.

Function: apply_opts(sink, renderer)

Optimization actions:

| Action | Effect |
|---|---|
| TC | Enable tensor core usage |
| UPCAST | Vectorize a dimension |
| LOCAL | Use local/shared memory |
| UNROLL | Unroll a loop dimension |
| GROUP | Group operations for cache |
| GROUPTOP | Group for reduce ops |
| THREAD | Thread-based parallelism |
| NOLOCALS | Disable local memory usage |
| SWAP | Swap range assignments |
| PADTO | Pad for alignment |

Optimization Search Explained:

The compiler searches for the best combination:

  • Heuristic mode (BEAM=0): Fast hand-coded optimization patterns, no compilation
  • Beam search (BEAMβ‰₯1): Compiles and runs candidates to measure actual performance

Optimization Search:
β”œβ”€β”€ Heuristic mode (BEAM=0): Hand-coded optimizations
└── Beam search (BEAMβ‰₯1):
    β”œβ”€β”€ Generate all possible actions (193 combinations)
    β”œβ”€β”€ Apply to all top-K candidates in parallel
    β”œβ”€β”€ Filter based on constraints
    β”œβ”€β”€ Compile and run each candidate β†’ Measure actual time
    └── Pick fastest

Note: NOLOCALS is a constraint that sets dont_use_locals = True, preventing further LOCAL actions and affecting shared memory usage decisions.
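The beam loop itself is independent of any compiler. Below is a generic sketch with a stand-in cost model in place of Morok's real compile-and-run timing; the action set, cost function, and all names are illustrative only.

```rust
// Generic beam search skeleton (not Morok's optimizer): candidates are
// action sequences, cost() stands in for measured kernel time, and each
// round keeps only the top-K ("beam") candidates.
fn cost(actions: &[u32]) -> u32 {
    // Hypothetical cost model: more / larger actions -> lower cost.
    1000u32.saturating_sub(actions.iter().sum::<u32>())
}

fn beam_search(all_actions: &[u32], beam: usize, rounds: usize) -> Vec<u32> {
    let mut frontier: Vec<Vec<u32>> = vec![vec![]];
    for _ in 0..rounds {
        let mut candidates: Vec<Vec<u32>> = Vec::new();
        for state in &frontier {
            for &a in all_actions {
                if !state.contains(&a) { // constraint filter
                    let mut next = state.clone();
                    next.push(a);
                    candidates.push(next);
                }
            }
        }
        candidates.sort_by_key(|s| cost(s)); // "measure" each candidate
        candidates.truncate(beam);           // keep top-K
        frontier = candidates;
    }
    frontier.into_iter().min_by_key(|s| cost(s)).unwrap()
}

fn main() {
    let best = beam_search(&[1, 2, 3, 4], 2, 3);
    assert_eq!(best.len(), 3);
    println!("best action sequence: {best:?}");
}
```

In the real pipeline, "measure" means compiling and running each candidate kernel, which is why BEAM≥1 is slower to compile but can find much faster code.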

Morok: optimizer/mod.rs, optimizer/opts.rs


Phase 2: Expander

Goal: Transform optimization primitives (UNROLL/UPCAST) into explicit operations.


Stage 8: Post-Opt Symbolic

Stage at a Glance

  • Goal: Symbolic simplification after optimization
  • Key Patterns: WHERE movement, constant folding
  • Impact: Enables better load combining and vectorization

What This Does: Symbolic simplification after optimization, plus WHERE movement.

Why This Matters: WHERE operations are like if statements. This stage moves if checks from after a load to before the load. Hardware can skip loading when the condition is false, saving memory bandwidth.

Pattern: sym + pm_move_where_on_load

// Before: WHERE guards a load
WHERE(valid, LOAD(index), alt)

// After: validity moved to INDEX
LOAD(INDEX(ptr, idx, valid=valid), alt)

Moving validity into INDEX enables better load combining and vectorization.

Note: This pattern only matches when the alternative value is 0. The transformation involves complex clause analysis: duplicate detection, range dependency checks, and data-dependent load verification.

Note: The Morok implementation uses gate= instead of valid= (the Index struct has a gate field). The concept is identical.

Morok: pm_move_where_on_load() in symbolic/patterns.rs


Stage 9: Expander

Stage at a Glance

  • Goal: Convert UNROLL/UPCAST to explicit operations
  • Key Concepts: UNROLL, CONTRACT, pattern order
  • Impact: Makes vectorization explicit and ready for hardware

What This Does: Transforms UNROLL/UPCAST optimization primitives into explicit operations.

Why This Matters: UPCAST and UNROLL mark intentβ€”what we want to do. This stage makes that intent explicit so the hardware can actually do it.

Pattern: sym + pm_pre_expander + pm_group_for_reduce + expander

⚠️ Important: Pattern Precedence

The patterns are combined and run to fixpoint. The order affects which pattern is tried first when multiple could match:

  1. sym first (symbolic simplification)
  2. pm_pre_expander second (converts UPCAST/UNROLL ranges)
  3. pm_group_for_reduce third (handles GROUP_REDUCE axis)
  4. expander last (main expansion)

Wrong precedence can cause incorrect vectorization or reduction scoping.

UNROLL and CONTRACT:

UNROLL and CONTRACT work together:

UNROLL: "Take this one thing and make N copies for different positions"
Example:  x β†’ [x_0, x_1, x_2, x_3]

CONTRACT: "Take these N things and combine them back"
Example:  [a, b, c, d] β†’ one vector containing all four

Together: UPCAST marks intent to vectorize β†’ UNROLL expands β†’ CONTRACT combines.

UPCAST range β†’ VECTORIZE:

// Before: UPCAST marks vectorization intent
RANGE(end=4, UPCAST)
      ↓ [pm_pre_expander]
// Step 1: Convert to UNROLL with constant indices
UNROLL(VCONST([0, 1, 2, 3]))
      ↓ [expander]
// Step 2: Expand operations with UNROLL sources
// Operations now have unrolled sources
      ↓ [CONTRACT or implicit]
// After: explicit VECTORIZE
VECTORIZE(op[0], op[1], op[2], op[3])

UNROLL range β†’ repeated operations:

When we say “operations duplicated,” it sounds like copy-paste. But that’s not what happens. The compiler creates a single SIMD instruction that processes all N elements together. Think of a SIMD register as a box holding 4 numbers; adding two boxes performs all four additions at once.

// Before: UPCAST marks vectorization intent
RANGE(end=3, UPCAST)
      ↓ [pm_pre_expander]
// Step 1: Convert to UNROLL
UNROLL(VCONST([0, 1, 2]))
      ↓ [expander]
// Step 2: Operations expand to handle all positions
// After: operations processed together (not duplicated)
UNROLL([op_at_0, op_at_1, op_at_2])

UNROLL/END/CONTRACT interaction:

Before: END(STORE(...), [RANGE(UPCAST)])
             ↓ [pm_pre_expander]
Step 1: END(STORE(...), [UNROLL(VCONST([0,1,2,3]))])
             ↓ [expander]
Step 2: END(CONTRACT(STORE(...Γ—4)), [])

Broadcast through AFTER/END:

// Broadcast VECTORIZE (all elements identical)
AFTER(VECTORIZE([x, x, x, x]), deps) β†’ VECTORIZE([AFTER(x, deps), AFTER(x, deps), ...])

GROUP_REDUCE Handling (pm_group_for_reduce):

GROUP_REDUCE is a special axis type for tensor core reductions:

// Before: REDUCE with GROUP_REDUCE ranges
REDUCE(src, [range(GROUP_REDUCE)])
           ↓ [pm_group_for_reduce]
// After: Shared memory reduction pattern
1. Track upstream LOCAL ranges
2. BUFFERIZE result with group ranges (AddrSpace.LOCAL)
3. INDEX into buffer with transformed ranges
4. Final REDUCE with axes (range_id+100, AxisType.REDUCE)

This enables efficient tensor core accumulation via shared memory.

Morok: expand.rs


Stage 10: Add Local Buffers

Stage at a Glance

  • Goal: Prepare buffers for fast memory (shared / L1)
  • Key Patterns: Bufferize with locals, extract hints
  • Impact: Frequently-accessed data stays in fast memory

What This Does: Prepares buffers for local memory usage and applies codegen-specific cleanups.

Why This Matters: Local buffers = fast memory close to the compute unit:

  • GPU: Shared memory (LDS) β€” 100x faster than global memory
  • CPU: L1 cache β€” 10x faster than main memory

The compiler moves frequently-accessed data to local buffers, similar to keeping important files on your desktop instead of a network drive.

Pattern: pm_add_buffers_local + rangeify_codegen

| Transform | Purpose |
|---|---|
| bufferize_to_store | Convert BUFFERIZE with allow_locals=true |
| get_contiguous | Extract optimization hints from CONTIGUOUS |
| NOOP removal | Clean up no-op operations |
| Strip arg from STORE | Remove redundant arguments |
| Fix broadcast dtype | Ensure consistent types in broadcasts |

Morok: rangeify/kernel.rs


Phase 3: Devectorizer

Goal: Lower from hardware-agnostic vectors to hardware-specific instructions.


Stage 11: Remove Reduce

Stage at a Glance

  • Goal: Convert declarative REDUCE to imperative accumulation
  • Key Patterns: Reduce to accumulator, horizontal reduction
  • Impact: Maps to hardware reduction instructions

What This Does: Converts high-level REDUCE to accumulator pattern.

Why This Matters: A declarative β€œsum these values” needs to become imperative instructions: initialize accumulator, loop, add each value.

Pattern: pm_reduce + gep_pushing

// Before: declarative reduction
REDUCE(Add, values, range)

// After: imperative accumulation
acc = DEFINE_REG(0.0)
for i in range:
    acc = ADD(acc, values[i])

Horizontal reduction:

Before we loop through a reduction dimension, we first combine neighboring values. This creates larger reductions that map better to hardware instructions.

Before:  [a, b, c, d, e, f, g, h]  // 8 values
             ↓ [Horizontal reduction]
Step 1:  [a+e, b+f, c+g, d+h]      // 4 partial sums
             ↓ [Accumulator pattern]
After:   acc = acc + (a+e) + (b+f) + (c+g) + (d+h)
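The key property is that folding the second half onto the first preserves the sum while halving the number of sequential additions. A plain-Rust sketch (not Morok's IR):

```rust
// Horizontal reduction sketch: pair elements i and i + half before the
// sequential accumulation. The pairwise adds map well to SIMD lanes.
fn horizontal_sum(values: &[f32]) -> f32 {
    let half = values.len() / 2;
    let partial: Vec<f32> = (0..half).map(|i| values[i] + values[i + half]).collect();
    partial.iter().sum()
}

fn main() {
    let v = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0];
    assert_eq!(horizontal_sum(&v), v.iter().sum::<f32>());
    println!("horizontal reduction preserves the sum");
}
```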

GEP pushing pushes GEP (get element pointer) operations through ALUs for better vectorization:

GEP(ADD(ptr_a, ptr_b), idx) β†’ ADD(GEP(ptr_a, idx), GEP(ptr_b, idx))

Why? Enables SIMD on the two GEPs (can be computed in parallel).

WMMA Tensor Core Fusion:

// Fuse tensor core accumulation inline
WMMA(a, b, c) + add β†’ WMMA(a, b, c + add)

This pattern enables efficient FMA-style accumulation on NVIDIA tensor cores.

Morok: devectorize.rs


Stage 12: Add GPU Dims

Stage at a Glance

  • Goal: Map abstract ranges to GPU thread indices
  • Key Patterns: Range to SPECIAL replacement
  • Impact: Enables parallel execution on GPU

What This Does: Replaces ranges with GPU thread indices.

Why This Matters: GPUs have hard limits: max 1024 threads per block, max 48KB shared memory. If your computation needs 2000 threads, the compiler must split it into multiple blocks. Dimension limiting handles this automatically.

Pattern: pm_add_gpudims

// Before: abstract range
RANGE(end=256, Global)

// After: GPU-specific
SPECIAL(gidx0)  // global thread index

Mapping:

| Range Type | GPU Equivalent |
|---|---|
| Global, THREAD | gidx (global index) |
| Local, WARP, GROUP_REDUCE | lidx (local/workgroup index) |
| Reduce | Loop (no mapping) |

Dimension Limiting:

GPUs have hardware limits (e.g., max 1024 threads per block). When ranges exceed these limits, the compiler:

  1. Groups adjacent dimensions: [256, 256, 256] with max [256, 256] β†’ [65536, 256]
  2. Splits large dimensions: [2048] with max [1024] β†’ [2, 1024]
  3. Reconstructs indices via divmod
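Step 2's reconstruction can be sketched directly: after splitting [2048] into [2, 1024], the block and thread indices recombine into the original flat index via `block * threads_per_block + thread`.

```rust
// Dimension-limiting sketch: a [2048] range split into [2, 1024]
// still covers every original index exactly once.
fn reconstruct(block: i64, thread: i64, threads_per_block: i64) -> i64 {
    block * threads_per_block + thread
}

fn main() {
    let mut seen = Vec::new();
    for block in 0..2 {
        for thread in 0..1024 {
            seen.push(reconstruct(block, thread, 1024));
        }
    }
    assert_eq!(seen, (0..2048).collect::<Vec<i64>>());
    println!("split [2048] -> [2, 1024] covers all indices in order");
}
```

Going the other way (flat index back to block/thread) is the divmod mentioned in step 3: `block = i / 1024`, `thread = i % 1024`.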

Store Masking:

Global stores that don’t use all local dimensions are masked:

// If STORE doesn't use lidx1, mask it:
STORE(INDEX(...), value) β†’ STORE(INDEX(..., gate=(lidx1 == 0)), value)

This ensures stores only execute when unused local indices are 0.

Morok: gpudims.rs


Stage 13: Add Loads

Stage at a Glance

  • Goal: Wrap INDEX operations in explicit LOAD
  • Key Patterns: Add LOAD, remove redundant loads
  • Impact: Makes memory operations explicit for codegen

What This Does: Wraps INDEX operations in explicit LOAD.

Why This Matters: Index operations compute addresses. LOAD actually reads memory. Making this explicit helps the code generator understand what memory accesses are needed.

Pattern: pm_add_loads

// Before: bare index
INDEX(ptr, i)

// After: explicit load
LOAD(INDEX(ptr, i))

Also removes redundant loads from stores (write-only access).

Note: Not all INDEX operations get wrapped in LOAD. Pointer types (already addresses) and image textures (special hardware) use different access methods.

Morok: devectorize.rs


Stage 14: Devectorize

Stage at a Glance

  • Goal: Convert abstract vectors to match hardware capabilities
  • Key Phases: 4 coordinated passes
  • Impact: Vectors work with actual hardware width

What This Does: Handles the transition from abstract vectors to hardware operations.

Why This Matters: Hardware supports only fixed vector widths, so abstract vectors must be split and regrouped to match. Devectorize does this in 4 conceptual phases within a single graph_rewrite:

  1. Phase 1: Create PTRCAT to group consecutive pointer accesses, devectorize ALU/WMMA/buffers, expand vector INDEX β†’ GEP(PTRCAT)
  2. Phase 2: Move GEP through LOAD/STORE
  3. Phase 3: Distribute PTRCAT through LOAD/STORE, creating CAT(LOADs), fix image buffers
  4. Phase 4: Split CAT(LOADs) into smaller chunks matching hardware width

PTRCAT Construction:

PTRCAT groups consecutive pointer accesses:

  1. Generate individual indexes for each vector element
  2. Extract (valid, root_src) β†’ [offsets] mapping
  3. Group consecutive offsets by validity and source
  4. Create PTRCAT from grouped pointers
  5. Return with GEP permutation for correct element order

This reduces memory bus transactions.

Device-Specific Fold Lengths:

| Device | Fold Lengths | Notes |
|---|---|---|
| DSP | 128, 64, 32, 16, 8, 4 | Large vectors for DSP SIMD |
| GPU (float4) | 4, 2 | Standard GPU vectorization |
| GPU (half + ALLOW_HALF8) | 8, 4, 2 | Half precision with env var |
| GPU (AMX) | 16, 8, 4, 2 | Apple AMX support |
| Image | 4 | Fixed for image textures |
| Default | 1 | Scalar fallback |

Environment Variable: DEVECTORIZE

  • 0: Skip devectorize only (keeps correct_load_store)
  • 1: Full devectorization (default)
  • β‰₯2: Skip both devectorize and correct_load_store

Pattern: devectorize + load_store_folding + correct_load_store + load_store_indexing

Split vectorized ALUs:

// If hardware doesn't support vec4 add
ADD(vec4_a, vec4_b) β†’ [ADD(a[0], b[0]), ADD(a[1], b[1]), ...]

Load/store chunk splitting: Match hardware memory width.

Image fixup: Special handling for image tensor buffers.

Morok: devectorize.rs


Stage 15: Lower Index Dtype

Stage at a Glance

  • Goal: Convert abstract Index type to concrete integers
  • Key Patterns: Operation-specific lowering based on value bounds
  • Impact: Indices use hardware-native integer types (i32 or i64)

What This Does: Converts abstract Index type to concrete integers.

Why This Matters: The Index type is abstractβ€”hardware doesn’t have it. We need to convert to i32 or i64, which the hardware actually supports.

Pattern: pm_lower_index_dtype

// Before: abstract index type
idx: Index

// After: concrete type
idx: i32  // or i64, based on bounds

Operation-Specific Lowering:

Index type lowering is NOT a single castβ€”each operation type has specific patterns:

| Operation | Before | After |
|---|---|---|
| Binary ops | ADD(Index, Index) | ADD(i32, i32) with casts |
| CONST | CONST(5): Index | CONST(5): i32 |
| WHERE | WHERE(c, Index, Index) | WHERE(c, i32, i32) |
| RANGE | RANGE(end: Index) | RANGE(end: i32) with cast |
| SPECIAL | SPECIAL(gidx) | Always i32 (GPU indices are 32-bit) |
| DEFINE_VAR | DEFINE_VAR: Index | i32 if bounds fit, else i64 |
| VECTORIZE | VECTORIZE(Index...) | Cast each to concrete scalar |
| CAST cleanup | CAST(i32, Index) | Just i32 (remove redundant cast) |

The select_concrete_dtype() function determines i32 vs i64 using vmin/vmax bounds analysis:

dtype = i32 if bounds fit in [-2^31, 2^31-1] else i64
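The selection logic can be sketched in a few lines. The function name matches the one cited above, but this signature and the string return type are illustrative, not Morok's actual API:

```rust
// Sketch of bounds-based dtype selection: i32 when [vmin, vmax] fits
// in 32 bits, i64 otherwise. Names and signature are illustrative.
fn select_concrete_dtype(vmin: i64, vmax: i64) -> &'static str {
    if vmin >= i32::MIN as i64 && vmax <= i32::MAX as i64 {
        "i32"
    } else {
        "i64"
    }
}

fn main() {
    assert_eq!(select_concrete_dtype(0, 1024), "i32");
    assert_eq!(select_concrete_dtype(0, 1 << 40), "i64");
    println!("bounds-based dtype selection works");
}
```

Picking i32 where possible matters on GPUs: 64-bit address arithmetic costs extra registers and instructions.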

Morok: symbolic/index_lowering.rs


Phase 4: Linearizer

Goal: Convert the DAG to a linear instruction sequence.


Stage 16: Post-Index Symbolic

Stage at a Glance

  • Goal: Full symbolic simplification after index lowering
  • Key Patterns: All symbolic rules (140+)
  • Impact: Final cleanup before serialization

What This Does: Full symbolic simplification after index lowering.

Why This Matters: Now that indices are concrete integers (i32/i64), arithmetic can fully simplify. This is the last chance to clean up expressions before linearization.

Pattern: symbolic

Includes GEP pushing patternsβ€”move address calculations through arithmetic:

Before:  GEP(ADD(arr_a, arr_b), idx)
              ↓ [Push GEP through ADD]
After:   ADD(GEP(arr_a, idx), GEP(arr_b, idx))

Why? Enables parallel computation of GEPs and may enable downstream vectorization. (Note: The pattern only applies when GEP’s dtype and ALU’s dtype are NOT pointers.)


Stage 17: Pre-Matcher (Optional)

Stage at a Glance

  • Goal: Backend-specific patterns before decomposition
  • Key Patterns: Renderer-specific
  • Impact: Hardware-specific optimizations

What This Does: Renderer-specific patterns applied before decomposition.

Why This Matters: Each backend can add its own patterns. For example, DSP backends use this to replace generic patterns with DSP-specific SIMD intrinsics. This allows hardware-specific optimizations without changing the generic pipeline.

Pattern: renderer.pre_matcher

Most backends (CPU, GPU) don’t need this. Only specialized hardware uses it.

Note: Morok does not currently implement this stage. The Renderer trait has only a decompositor() method. This is a future enhancement for DSP and other specialized backends.


Stage 18: Decompositions

Stage at a Glance

  • Goal: Rewrite operations the target doesn’t support
  • Key Patterns: Power-of-2, transcendental approximations
  • Impact: Maps high-level ops to hardware instructions

What This Does: Late rewrites for operations the target doesn’t support.

Why This Matters: Hardware doesn’t have every operation. For example, most CPUs don’t have a direct sin instruction. We approximate it with operations that do exist (addition, multiplication, etc.).

Pattern: symbolic_simple + get_late_rewrite_patterns

| Pattern | Example | When Used |
|---|---|---|
| MOD → AND | x % 8 → x & 7 | Power-of-2 divisor |
| MUL → SHL | x * 16 → x << 4 | Power-of-2 multiplier |
| DIV → SHR | x // 8 → x >> 3 | Power-of-2 divisor |
| FDIV → MUL | x / 2.0 → x * 0.5 | Float constant divisor |
| NEG | x * -1 → NEG(x) | When NEG supported |
| MULACC | a * b + c → MULACC(a, b, c) | When FMA supported |
| Fast integer division | x // 7 → (x * M) >> S | Non-power-of-2 divisor |
| De Morgan’s laws | (!x) & (!y) → !(x \| y) | Boolean simplification |
| Comparison negations | !(x < c) → (c-1) < x | Integer comparisons |
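The power-of-2 rewrites are easy to spot-check. Note the shift forms implement floor division, so negative operands need extra care; this sketch sticks to non-negative integers:

```rust
// Verify the power-of-2 decompositions for non-negative integers.
// (Shifts floor toward negative infinity, while Rust's `/` truncates
// toward zero, so negatives are excluded here.)
fn main() {
    for x in 0i64..1000 {
        assert_eq!(x % 8, x & 7);   // MOD -> AND
        assert_eq!(x * 16, x << 4); // MUL -> SHL
        assert_eq!(x / 8, x >> 3);  // DIV -> SHR
    }
    assert_eq!(3.0f64 / 2.0, 3.0 * 0.5); // FDIV -> MUL (exact for 2.0)
    println!("power-of-2 decompositions verified");
}
```

Bitwise AND and shifts are single-cycle on essentially all targets, while integer division can take tens of cycles, which is why these rewrites are worth doing even this late in the pipeline.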

Transcendental function approximations (SIN, EXP, LOG, etc.) are implemented via the decompositor() pathway (see ir/src/decompositions/transcendentals.rs).

Morok: optimizer/mod.rs


Stage 19: Final Rewrite

Stage at a Glance

  • Goal: Prepare for linearization
  • Key Patterns: CONST vectorization, GEP resolution, END splitting
  • Impact: Clean representation ready for linearization

What This Does: Prepare for linearization.

Why This Matters: Some patterns are easier to apply after decomposition. This stage does final cleanup before converting to a linear sequence.

Pattern: pm_decomp + pm_render + extra_matcher + pm_split_ends

CONST vectorization:

// Make vector constants explicit
CONST(1.0) used as vec4 β†’ VECTORIZE(1.0, 1.0, 1.0, 1.0)

CAT to VECTORIZE (via gep_pushing in symbolic):

CAT(a, b, c, d) β†’ VECTORIZE(a, b, c, d)

CAT cannot be rendered directly; explicit VECTORIZE is required for codegen.

GEP resolution: Convert remaining GEP operations.

Split multi-range ENDs:

// Before: END closing multiple ranges
END(op, [range_a, range_b])

// After: nested single ENDs
END(END(op, range_a), range_b)

extra_matcher: Each backend can add its own final patterns. This allows hardware-specific optimizations without changing the generic pipeline.

Morok: devectorize.rs, linearize/mod.rs, optimizer/mod.rs


Stage 20: Add Control Flow

Stage at a Glance

  • Goal: Build control flow graph and add range dependencies
  • Key Concept: Three relationship types (nested, dependent, independent)
  • Impact: Correct instruction ordering

What This Does: Builds the control flow graph and adds range dependencies.

Why This Matters: Operations must execute in a valid order. If a load uses a RANGE’s value, the RANGE must come first. This stage tracks and enforces these dependencies.

Pattern: pm_add_control_flow (bottom-up)

// Analyze which END operations depend on which
END(computation, [RANGE_A]) and END(other_computation, [RANGE_B]) are siblings
β†’ Creates edge: RANGE_B.src += END(computation)

// Add explicit dependency
RANGE_B waits for RANGE_A to complete

Three relationship types:

| Relationship | Example | Meaning |
|---|---|---|
| Nested | RANGE_A inside RANGE_B | A must complete before B starts |
| Dependent | LOAD_A uses RANGE_A | RANGE_A must precede LOAD_A |
| Independent | RANGE_X and RANGE_Y don’t interact | Can run in parallel |

Bottom-up traversal ensures dependencies flow correctly from leaves to roots.

Morok: schedule/src/linearize/mod.rs


Stage 21: Linearize

Stage at a Glance

  • Goal: Convert DAG to linear instruction sequence
  • Key Algorithm: Priority-aware topological sort
  • Impact: Valid execution order

What This Does: Converts the DAG to a linear instruction sequence via priority-aware topological sort.

Why This Matters: The graph structure doesn’t specify execution order. We need to flatten it while respecting dependencies. Priorities ensure sensible ordering (definitions before uses, loads before computation, stores after).

Function: linearize(sink)

| Operation | Priority | Why |
|---|---|---|
| DEFINE_GLOBAL | -20 | Arguments must be defined first |
| DEFINE_VAR | -19 | Variables must be defined first |
| DEFINE_LOCAL | -18 | Allocations first |
| DEFINE_REG | -17 | Registers first |
| CONST | -10 | Constants early for reuse |
| LOAD | -1 | Loads before use |
| END | -5 | Closes ranges |
| STORE | +1 | Stores after computation |
| RANGE | +5 | Ranges open before use |

Lower priority = earlier in sequence. This ensures:

  • Definitions come first
  • Loads happen before computation
  • Stores happen last
  • Ranges open before their contents, close after

run_count ordering: Operations are sorted primarily by execution frequency (run_count), then by priority. Operations with lower execution frequency (outside inner loops) are scheduled first, while operations in inner loops (higher run_count) are scheduled later. Example: A CONST executed 100 times appears before a CONST executed 1M times.

run_count Calculation:

run_count = prod(int(r.vmax) + 1 for r in u.ranges)

This computes how many times an operation executes based on its enclosing ranges.
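The combined (run_count, priority) ordering can be sketched with a plain sort. The `order` function and tuple layout here are illustrative, not Morok's linearizer:

```rust
// Sketch of priority-aware ordering: sort by (run_count, priority),
// so rarely-executed ops come first and inner-loop ops come later,
// with the priority table breaking ties. Illustrative only.
fn order(ops: &mut Vec<(&'static str, u64, i32)>) {
    // Each op is (name, run_count, priority).
    ops.sort_by_key(|&(_, run_count, priority)| (run_count, priority));
}

fn main() {
    let mut ops = vec![
        ("STORE", 100, 1),
        ("DEFINE_GLOBAL", 1, -20),
        ("LOAD", 100, -1),
        ("CONST", 1, -10),
    ];
    order(&mut ops);
    let names: Vec<_> = ops.iter().map(|o| o.0).collect();
    assert_eq!(names, ["DEFINE_GLOBAL", "CONST", "LOAD", "STORE"]);
    println!("{names:?}");
}
```

The real linearizer must additionally respect the DAG's dependency edges; the sort key only decides order among operations whose dependencies allow either order.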

Morok: schedule/src/linearize/mod.rs


Stage 22: Cleanup IF/ENDIF

Stage at a Glance

  • Goal: Final cleanup of linear instruction list
  • Key Transformation: Gated INDEX → IF/STORE/ENDIF
  • Impact: Handles hardware without predicated stores

What This Does: Final cleanup of the linear instruction list.

Why This Matters: Some hardware (modern GPUs) supports β€œpredicated stores”—write to memory only if a condition is true. Older hardware doesn’t. For those targets, we wrap the store in an IF statement. This stage ONLY runs when the hardware lacks predicated store support.

Pattern: pm_linearize_cleanups (via line_rewrite, not graph_rewrite)

// Gated INDEX in STORE becomes conditional store
STORE(INDEX(ptr, idx, valid=cond), value)
β†’ IF(cond) { STORE(INDEX(ptr, idx), value) } ENDIF

Note: This stage uses line_rewrite instead of graph_rewrite because it operates on the already-linearized instruction list rather than a DAG.

At this point, the instruction list is ready for code generation.

Morok: schedule/src/linearize/mod.rs (predicated stores path)


Worked Example: Tracing Through All 22 Stages

Let’s trace c = a + b (where a, b are [100, 100] tensors) through the pipeline.

Initial Tensor Graph

[ADD]
β”œβ”€β”€ [BUFFER(a)] : Float32
└── [BUFFER(b)] : Float32

After Stage 1: Early Movement Ops

(No changeβ€”no movement ops in this example)

After Stage 2: Load Collapse

(No changeβ€”no reductions in this example)

After Stage 3: Split Ranges

(No changeβ€”no modulo operations)

After Stage 4: Initial Symbolic

(No changeβ€”no simplification needed)

After Stage 5: Simplify Ranges

(No changeβ€”no adjacent ranges yet)

After Stage 6: Split Store

(Not applicableβ€”GPU backend)

After Stage 7: Apply Opts

Optimization actions applied:

  • UPCAST j dimension by 4 (vectorization)
  • LOCAL for input buffers (if beneficial)

After Stage 8: Post-Opt Symbolic

No changesβ€”symbolic already clean.

After Stage 9: Expander

UPCAST β†’ UNROLL β†’ CONTRACT:

[VECTORIZE]
β”œβ”€β”€ [ADD]
β”‚   β”œβ”€β”€ [LOAD(a)]
β”‚   β”‚   └── [INDEX]
β”‚   β”‚       β”œβ”€β”€ [BUFFER(a)]
β”‚   β”‚       β”œβ”€β”€ [RANGE(i, Global, 0..100)]
β”‚   β”‚       └── [UNROLL(VCONST([0,1,2,3]))]  // Converted from RANGE(j, UPCAST)
β”‚   └── [LOAD(b)]
β”‚       └── [INDEX]
β”‚           β”œβ”€β”€ [BUFFER(b)]
β”‚           β”œβ”€β”€ [RANGE(i)]  // Same RANGE via hash consing
β”‚           └── [UNROLL(VCONST([0,1,2,3]))]  // Same UNROLL via hash consing

After Stage 10: Add Local Buffers

(If LOCAL opt was chosen)

After Stage 11: Remove Reduce

(No changeβ€”no reductions)

After Stage 12: Add GPU Dims

[SPECIAL(gidx0)] : Index  // replaces RANGE(i)

After Stage 13: Add Loads

(No changeβ€”loads already present)

After Stage 14: Devectorize

Vector split to match hardware width:

[VECTORIZE] : <4 x Float32>
β”œβ”€β”€ [ADD(a[0], b[0])]
β”œβ”€β”€ [ADD(a[1], b[1])]
β”œβ”€β”€ [ADD(a[2], b[2])]
└── [ADD(a[3], b[3])]

After Stage 15: Lower Index Dtype

[SPECIAL(gidx0)] : i32  // concrete type

After Stage 16: Post-Index Symbolic

No changes needed.

After Stage 17: Pre-Matcher

(No patterns for standard backends)

After Stage 18: Decompositions

No decompositions neededβ€”all ops supported.

After Stage 19: Final Rewrite

No changes needed.

After Stage 20: Add Control Flow

Dependencies trackedβ€”no issues.

After Stage 21: Linearize

Linear instruction sequence (simplified):

1. DEFINE_GLOBAL(0)  // Output buffer c
2. DEFINE_GLOBAL(1)  // Input buffer a
3. DEFINE_GLOBAL(2)  // Input buffer b
4. RANGE(i, 0..100, Global)  // gidx0
5. RANGE(j, 0..25, Loop)  // Unrolled /4
6. LOAD(a, i, j*4+0)  // Vector load
7. LOAD(b, i, j*4+0)  // Vector load
8. ADD(vec_a, vec_b)  // Vector add
9. STORE(c, i, j*4+0, result)
10. END(RANGE(j))
11. END(RANGE(i))

After Stage 22: Cleanup IF/ENDIF

No changes neededβ€”no gated stores.

Result: Ready for code generation! The LLVM/CUDA/other backend will compile this to actual machine code.


Pattern Application Strategy

Each stage uses one of two rewrite strategies:

Top-down (default): Process parents before children. Use when transformations create new matchable subterms.

Bottom-up: Process children before parents. Use when child state affects parent matching (stages 1, 20).

Both iterate to fixpointβ€”patterns fire until no more match.


Debugging the Pipeline

When a kernel produces wrong results, the bug lives in one of these 22 stages. Use environment variables to extract IR at each stage:

# See IR after each transformation
MOROK_DEBUG=ir cargo test failing_test

Quick Reference

| Symptom | Likely Stages | What to Check |
|---|---|---|
| Wrong values in output | 4, 9, 11, 18 | Symbolic simplification, expansion, devectorization |
| Slow performance | 7, 9, 14, 21 | Optimization, expansion, devectorization, linearization |
| Crashes/panics | 11, 12 | Reduce, GPU dims |
| Wrong loop count | 3, 5, 12 | Split ranges, simplify ranges, GPU dims |
| Missing vectorization | 9, 14 | Expander, devectorize |

Common Issues

  1. Stage 3-4: Range splitting/symbolic may lose constraints
  2. Stage 9: Expansion order affects vectorization correctness
  3. Stage 11: Accumulator initialization must match reduction identity
  4. Stage 14: Hardware width mismatchβ€”check vector fold length
  5. Stage 18: Missing decompositionβ€”check supported_ops list for backend
  6. Stage 21: Priority bugs cause data racesβ€”verify dependencies

Summary

The 22-stage pipeline transforms tensor expressions into machine code through systematic refinement:

  1. Stages 1-7: Make iteration explicit, optimize ranges
  2. Stages 8-10: Expand optimization primitives
  3. Stages 11-15: Lower to hardware-specific operations
  4. Stages 16-22: Serialize to executable instructions

Each stage has a single responsibility. Each builds on the last. The result: high-level tensor code runs at near-optimal speed on diverse hardware.


Tinygrad vs Morok: Architectural Differences

This chapter describes the β€œideal” 22-stage pipeline based on Tinygrad’s implementation. Morok now closely follows this design with minimal differences.

Remaining Architectural Differences

| Stage | Tinygrad | Morok | Notes |
|---|---|---|---|
| 1: Early Movement Ops | Moves movement ops through AFTER/END wrappers | Removes movement ops during bufferization | Both approaches achieve functional equivalence; Morok’s is cleaner |

Aligned Stages (Previously Different)

The following stages were aligned with Tinygrad as of this implementation:

| Stage | What Changed |
|---|---|
| 15: Index Dtype Lowering | Morok now has pm_lower_index_dtype() with full pattern coverage: binary ops, CONST, WHERE, VECTORIZE, SPECIAL, DEFINE_VAR, RANGE, CAST cleanup |
| 18: Decompositions | Added fast_division_patterns(), pm_div_to_shr(), pm_fdiv_to_mul(), pm_comparison_negations(), De Morgan’s laws |
| 19: Final Rewrite | pm_render() moved from codegen to Stage 19 in the schedule pipeline |

Tinygrad-Only Patterns

Morok intentionally does not implement these Tinygrad-specific patterns:

| Pattern | Purpose | Why Morok Doesn’t Need It |
|---|---|---|
| to_bufferview | Avoid disk buffer copies for DISK/TINYFS devices | Morok doesn’t support DISK/TINYFS; in-memory backends don’t need this |
| AFTER/END movement patterns | Move movement ops through timing wrappers | Morok removes movement ops during bufferization instead |

Morok Enhancements

Morok has some patterns/enhancements not in Tinygrad:

| Enhancement | Location | Purpose |
|---|---|---|
| Nested INDEX flattening with identical indices | movement_op_patterns() | Removes redundant INDEX(INDEX(ptr, [i]), [i]) |
| CAT β†’ VECTORIZE | pm_render | Converts CAT to explicit VECTORIZE (can’t render CAT directly) |
| PTRCAT([x]) unwrap | pm_render | Removes single-element PTRCAT wrappers |
| GEP through CAST/BITCAST | gep_pushing_patterns() | Pushes GEP through type casts for better optimization |
| Image dtype guard | pm_add_loads() | Skips LOAD wrapping for Image dtype (handled in codegen) |

Glossary

| Term | Simple Definition | Example |
|---|---|---|
| Accumulator | Variable holding running total | acc = acc + value (in reduction) |
| Axis | One dimension of a tensor | Shape [100, 200] has 2 axes |
| AxisType | How a loop executes | Global=parallel, Reduce=accumulate |
| Buffer | Allocated memory holding data | A tensor’s data lives in a buffer |
| Bufferize | Store result in memory instead of computing on-demand | Materialize intermediate value |
| CONTRACT | Combine multiple values into one vector | [a, b, c, d] β†’ vec4(a,b,c,d) |
| Devectorize | Split vectors to match hardware | vec8 β†’ vec4, vec4 |
| Divmod | Division and remainder operations | x // 7, x % 7 |
| Fixpoint | When applying patterns no longer changes anything | Patterns fire until fixpoint |
| GEP | Get Element Pointer: compute address from indices | arr[i][j] β†’ base + i*stride + j |
| Hash consing | Reuse identical expressions | ADD(x, 0) + ADD(x, 0) shares memory |
| Index | Integer type for array indices | i32 or i64, depending on device |
| Load | Read from memory | value = arr[i] |
| Pattern | Find-and-replace rule for code | ADD(x, 0) β†’ x |
| Predicated store | Write to memory conditionally | Write if valid else skip |
| Range | Loop iteration specification | for i in 0..100 |
| Reduction | Combine many values into one | Sum, max, min |
| Store | Write to memory | arr[i] = value |
| Symbolic | Simplify using algebra rules | (x/4)*4 β†’ x (when x%4=0) |
| Tensor core | Hardware for fast matrix multiply | NVIDIA GPUs only |
| Topological sort | Order nodes respecting dependencies | A before B if B uses A’s result |
| UNROLL | Expand one op into multiple positions | x β†’ [x_0, x_1, x_2, x_3] |
| UPCAST | Mark intent to vectorize | RANGE(0..4, UPCAST) |
| Vectorize | Process multiple values together | SIMD: add 4 numbers at once |
| WHERE | Conditional selection | WHERE(cond, x, y) = x if cond else y |

One IR to Rule Them All

You’re debugging a slow model. The profiler says β€œkernel X takes 200ms” but you have no idea what kernel X actually does. You trace through PyTorch’s dispatcher, then ATen, then TorchInductor, then Triton IR, and finally land in LLVM IR. Five different representations, five different mental models, five different debugging tools.

This is the reality of modern ML compilation. TensorFlow’s XLA has a similar story: Python β†’ Graph β†’ XLA HLO β†’ MLIR β†’ LLVM IR. Each layer was added to solve a real problem, but the accumulated complexity is staggering.

Morok takes a different approach, borrowed from Tinygrad: one IR from tensors to machine code.

β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”   β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚    TensorFlow    β”‚   β”‚     PyTorch     β”‚   β”‚     Morok     β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€   β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€   β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚   Python API     β”‚   β”‚   Python API    β”‚   β”‚  Rust/Python  β”‚
β”‚   TF Graph       β”‚   β”‚   FX Graph      β”‚   β”‚       ↓       β”‚
β”‚   XLA HLO        β”‚   β”‚   Inductor IR   β”‚   β”‚    UOp IR     β”‚
β”‚   MLIR dialects  β”‚   β”‚   Triton IR     β”‚   β”‚       ↓       β”‚
β”‚   LLVM IR        β”‚   β”‚   LLVM/PTX      β”‚   β”‚  Machine code β”‚
β”‚   Machine code   β”‚   β”‚   Machine code  β”‚   β”‚               β”‚
β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€   β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€   β”œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€
β”‚      5 IRs       β”‚   β”‚      4 IRs      β”‚   β”‚     1 IR      β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜   β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

The simplest architecture often wins. This chapter explains how one carefully designed IR can replace an entire compiler stack.


UOp: The Universal Node

A UOp (micro-operation) is a node in a computation graph. But unlike nodes in other IRs, a UOp can represent operations at any abstraction levelβ€”from high-level tensor reshapes down to individual CPU instructions.

Here’s the key insight: instead of having separate IRs for β€œtensor operations” and β€œloop structures” and β€œmemory accesses”, we put them all in one enum:

#![allow(unused)]
fn main() {
pub enum Op {
    // High-level tensor operations
    Reshape { src: Arc<UOp>, new_shape: Arc<UOp> },
    Permute { src: Arc<UOp>, axes: Vec<usize> },
    ReduceAxis { src: Arc<UOp>, reduce_op: ReduceOp, axes: Vec<usize> },

    // Loop-level control flow
    Range { end: Arc<UOp>, axis_id: AxisId, axis_type: AxisType },
    End { computation: Arc<UOp>, ranges: SmallVec<[Arc<UOp>; 4]> },

    // Memory operations
    Load { buffer: Arc<UOp>, index: Arc<UOp> },
    Store { buffer: Arc<UOp>, index: Arc<UOp>, value: Arc<UOp>, ... },

    // ALU operations (same as hardware)
    Binary(BinaryOp, Arc<UOp>, Arc<UOp>),  // Add, Mul, etc.
    Unary(UnaryOp, Arc<UOp>),              // Sqrt, Exp, etc.
}
}

The enum has ~80 variants organized by abstraction level:

| Category | Examples | What It Represents |
|---|---|---|
| Movement | RESHAPE, PERMUTE, EXPAND, PAD | Tensor shape transformations |
| Reduction | REDUCE_AXIS, REDUCE | Mathematical aggregations |
| Control | RANGE, END, IF, BARRIER | Loop and branch structure |
| Memory | LOAD, STORE, INDEX, BUFFER | Hardware memory access |
| ALU | ADD, MUL, SQRT, EXP, WHERE | CPU/GPU instructions |
| Advanced | WMMA, CONTRACT, UNROLL | Tensor cores, vectorization |

When you print a UOp graph, you see its tree structure:

[42] STORE : Void
β”œβ”€β”€ [10] DEFINE_GLOBAL(0) : Ptr<Float32>
β”œβ”€β”€ [35] INDEX : Ptr<Float32>
β”‚   β”œβ”€β”€ [10] β†’ (same as above)
β”‚   └── [30] RANGE(axis=0, Reduce) : Index
β”‚       └── [5] CONST(4) : Index
└── [40] REDUCE(Add) : Float32
    β”œβ”€β”€ [38] MUL : Float32
    β”‚   β”œβ”€β”€ [36] LOAD : Float32
    β”‚   └── [37] LOAD : Float32
    └── [30] β†’ (same RANGE as above)

Notice the arrows pointing to β€œsame as above”? That’s not just pretty-printingβ€”it’s a fundamental property called hash consing.


Hash Consing: Structural Sharing

When you create the same expression twice in Morok, you get the same pointer. Not equal valuesβ€”the same memory address.

#![allow(unused)]
fn main() {
let a = UOp::binary(Add, x.clone(), y.clone());
let b = UOp::binary(Add, x.clone(), y.clone());

assert!(Arc::ptr_eq(&a, &b));  // Same pointer!
}

This works through a global cache. When constructing a UOp, we first check if an identical one exists:

#![allow(unused)]
fn main() {
pub fn new(op: Op, dtype: DType) -> Arc<Self> {
    let key = UOpKey::new(&op, dtype);

    // Check cache first
    if let Some(existing) = CACHE.get(&key) {
        return existing;
    }

    // Create new and cache it
    let uop = Arc::new(UOp { op, dtype, ... });
    CACHE.insert(key, uop.clone());
    uop
}
}

Why does this matter for ML engineers?

  • Pointer equality is semantic equality. To check if two subexpressions are identical, just compare pointers: Arc::ptr_eq(&a, &b). No tree traversal needed.

  • Pattern matching is O(1). When the optimizer asks β€œhave I seen this pattern before?”, pointer comparison gives an instant answer.

  • Memory efficiency. Common subexpressions (think: shared computations in attention, gradient graphs) are stored once, not duplicated.

  • Thread safety. The same computation from different threads produces the same objectβ€”no synchronization bugs.

The tree printout shows this: when you see [10] β†’ (same as above), that’s not a copyβ€”it’s the same node referenced from multiple places.


Explicit Loops: The RANGE Operation

Most ML IRs hide loops inside operations. In ONNX, a reduction looks like:

ReduceSum(data, axes=[1], keepdims=0)

Where’s the loop? It’s implicitβ€”somewhere inside the runtime’s implementation of ReduceSum. You can’t see it, can’t modify it, can’t reason about it.

Morok makes loops explicit using RANGE operations. The same reduction becomes:

[REDUCE(Add)]
β”œβ”€β”€ [LOAD]
β”‚   └── [INDEX]
β”‚       β”œβ”€β”€ [BUFFER]
β”‚       β”œβ”€β”€ [RANGE(axis=0, Global)]   # outer loop (parallelized)
β”‚       β”‚   └── [CONST(128)]
β”‚       └── [RANGE(axis=1, Reduce)]   # reduction loop
β”‚           └── [CONST(64)]
└── [RANGE(axis=1, Reduce)]           # same RANGE via hash consing

Each RANGE has an AxisType that tells the code generator how to compile it:

| AxisType | CPU | CUDA | Meaning |
|---|---|---|---|
| Global | Thread pool | blockIdx | Outer parallel dimension |
| Local | (N/A) | threadIdx | Workgroup parallelism |
| Loop | for loop | for loop | Sequential iteration |
| Reduce | Accumulator | Warp reduce | Reduction dimension |
| Upcast | SIMD vector | Register tile | Vectorization |
| Unroll | Unrolled | Unrolled | Loop unrolling |

The AxisType hierarchy (Global β†’ Local β†’ Loop β†’ Reduce β†’ Upcast β†’ Unroll) maps directly to GPU programming models. A RANGE with AxisType::Global becomes blockIdx.x in CUDA. A RANGE with AxisType::Local becomes threadIdx.x.

Why explicit loops matter:

  • Optimization is visible. You can see which loops will be parallelized, which will be unrolled, which will use SIMD.

  • Scheduling is graph rewriting. Changing loop order, tiling, or unrolling is just a pattern transformationβ€”no special β€œscheduling pass”.

  • Same IR at every stage. The RANGE that represents β€œiterate over batch dimension” at the tensor level is the same RANGE that becomes for (int i = 0; i < N; i++) in generated code.


Graph Rewriting: One Transformation Mechanism

Traditional compilers have dozens of specialized passes: constant folding, dead code elimination, loop unrolling, operator fusion. Each pass has custom logic, custom data structures, custom bugs.

Morok uses one mechanism: pattern-based graph rewriting.

#![allow(unused)]
fn main() {
patterns! {
    // Identity folding: x + 0 β†’ x
    Add[x, @zero] ~> x,

    // Constant folding: 3 + 4 β†’ 7
    Add(a @const(a_val), b @const(b_val))
        => eval_add(a_val, b_val).map(|r| UOp::const_(a.dtype(), r)),

    // Self-folding: x / x β†’ 1
    Idiv(x, x) ~> UOp::one(x.dtype()),

    // Dead code: if(true) { x } else { y } β†’ x
    Where(@true, t, _f) ~> t,
}
}

The DSL is expressive:

  • [x, y] β€” commutative. Try both orderings (for ADD, MUL, etc.)
  • (x, y) β€” ordered. Match exactly this order.
  • @zero, @one, @true β€” semantic constants. Works for any dtype.
  • @const(val) β€” extract value. For compile-time computation.
  • x, x β€” same operand. Detects pointer equality.
  • ~> vs => β€” infallible vs fallible rewrite.

The rewrite engine applies patterns bottom-up until no more matches:

Original:       Add(Mul(x, 1), 0)
After Mul:      Add(x, 0)         # Mul(x, 1) β†’ x
After Add:      x                 # Add(x, 0) β†’ x

This single mechanism handles:

  • Algebraic simplification β€” constant folding, identity removal
  • Rangeify transformation β€” movement ops β†’ explicit loops
  • Kernel optimization β€” vectorization, unrolling, tensor cores
  • Code generation β€” lowering to hardware primitives

Same patterns, same engine, different pattern sets for each stage.


Worked Example: Matmul Journey

Let’s trace C = A @ B (a 4Γ—4 matrix multiply) through the entire pipeline.

Stage 1: Tensor Construction

When you write A.matmul(&B), Morok builds a high-level UOp graph:

[REDUCE_AXIS(Add, axes=[2])]
β”œβ”€β”€ [MUL]
β”‚   β”œβ”€β”€ [EXPAND]           # A: [4,4] β†’ [4,4,4]
β”‚   β”‚   └── [BUFFER(A)]
β”‚   └── [EXPAND]           # B: [4,4] β†’ [4,4,4]
β”‚       └── [PERMUTE]      # transpose for broadcasting
β”‚           └── [BUFFER(B)]

This is pure math: β€œexpand A and B to align dimensions, multiply elementwise, sum along the contracted axis.”

Stage 2: Rangeify

The rangeify pass converts movement ops (EXPAND, PERMUTE) into explicit index computations with RANGE loops:

[STORE]
β”œβ”€β”€ [DEFINE_GLOBAL(C)]
β”œβ”€β”€ [INDEX]
β”‚   β”œβ”€β”€ [DEFINE_GLOBAL(C)]
β”‚   β”œβ”€β”€ [RANGE(i, Global)]     # i ∈ [0, 4)
β”‚   β”‚   └── [CONST(4)]
β”‚   └── [RANGE(j, Global)]     # j ∈ [0, 4)
β”‚       └── [CONST(4)]
└── [REDUCE(Add)]
    β”œβ”€β”€ [MUL]
    β”‚   β”œβ”€β”€ [LOAD(A)]
    β”‚   β”‚   └── [INDEX]
    β”‚   β”‚       β”œβ”€β”€ [RANGE(i)]     # same i (hash consing)
    β”‚   β”‚       └── [RANGE(k, Reduce)]
    β”‚   └── [LOAD(B)]
    β”‚       └── [INDEX]
    β”‚           β”œβ”€β”€ [RANGE(k)]     # same k
    β”‚           └── [RANGE(j)]     # same j
    └── [RANGE(k, Reduce)]         # k ∈ [0, 4)
        └── [CONST(4)]

Now we see the loop structure: i and j are Global (parallelized), k is Reduce (accumulated).

Stage 3: Symbolic Simplification

Pattern rewrites clean up redundant operations, fold constants, and simplify index arithmetic.

Stage 4: Code Generation

The final IR translates directly to loops:

// GPU kernel (conceptual)
__global__ void matmul(float* C, float* A, float* B) {
    int i = blockIdx.x;   // from RANGE(i, Global)
    int j = blockIdx.y;   // from RANGE(j, Global)
    float acc = 0.0f;
    for (int k = 0; k < 4; k++) {  // from RANGE(k, Reduce)
        acc += A[i*4 + k] * B[k*4 + j];
    }
    C[i*4 + j] = acc;
}

The key observation: structure is visible at every stage. There’s no magic fusion pass that turns three nested loops into something unrecognizable. The RANGE structure you see in Stage 2 is exactly what becomes loops in Stage 4.


Comparison: How Other IRs Differ

Different IRs make different tradeoffs. Here’s how they stack up:

| Aspect | ONNX | XLA HLO | Triton | Morok |
|---|---|---|---|---|
| Purpose | Model interchange | Backend optimization | GPU kernel DSL | Full compilation |
| Operators | ~200 high-level | ~100–150 high-level | Tile operations | ~80 multi-level |
| Loop model | Implicit | Implicit | Tile-based | Explicit RANGE |
| Memory | Pure values | Pure values β†’ buffers | Explicit pointers | Explicit LOAD/STORE |
| Optimization | None | Specialized passes | MLIR patterns | Unified rewriting |
| Targets | Runtime engines | CPU/GPU/TPU | GPU only | CPU/GPU |

ONNX maximizes portability. Operations like Conv and MatMul hide all implementation details. Great for model exchange, but you can’t optimize what you can’t see.

XLA HLO is functional and pureβ€”no side effects, immutable tensors. This enables algebraic optimization but requires a separate β€œbuffer assignment” phase before code generation. The transition from HLO to LMHLO (buffer-based) is a fundamental boundary.

Triton exposes more than ONNX but less than Morok. You write β€œtile-level” codeβ€”operations on blocks of dataβ€”and the compiler handles thread-level details. Explicit memory (tl.load, tl.store) but implicit parallelization within tiles.

Morok exposes everything: loops are explicit (RANGE), memory is explicit (LOAD/STORE), parallelization is explicit (AxisType). This means more to learn, but nothing is hidden.


Why This Matters: Practical Benefits

Morok’s transparent IR has practical benefits for ML engineers:

Debugging is direct. Print the graph at any stage:

#![allow(unused)]
fn main() {
println!("{}", tensor.uop().tree());
}

You’ll see exactly what operations exist, how they connect, and where the computation happens. No β€œkernel X” mysteries.

Performance tuning is informed. See which loops are parallelized:

[RANGE(batch, Global)]    # parallelized across GPU blocks
[RANGE(channel, Local)]   # parallelized within blocks
[RANGE(pixel, Loop)]      # sequential β€” might be slow!

If something should be parallel but isn’t, you can see it.

The mental model is simple. There’s one IR, one transformation mechanism, one set of operations. You don’t need to learn XLA HLO and MLIR and Triton and LLVM. Just UOps.

Optimization is composable. Want a custom rewrite? Add a pattern:

#![allow(unused)]
fn main() {
patterns! {
    // Your custom optimization
    MyPattern(x, y) ~> better_version(x, y),
}
}

It works with the same engine as constant folding, fusion, and everything else.


The Deeper Insight

Morok/Tinygrad proves that compiler complexity is often accidental, not essential. The multi-layer IR stacks in TensorFlow and PyTorch accumulated organicallyβ€”each layer solved a real problem, but the combined system is harder to understand than any individual part.

One well-designed IR, one transformation mechanism, and principled composition can replace thousands of lines of specialized passes. It’s the Unix philosophy applied to compilers: do one thing well, and compose.

The cost is explicitnessβ€”you see loops, memory accesses, and parallelization hints that other IRs hide. But visibility is a feature, not a bug. When your model is slow, you want to see why, not hope the compiler figures it out.

That’s the bet Morok makes: transparent complexity beats hidden complexity.

Pattern-Based Optimization

Open any production ML compiler and you’ll find dozens of optimization passes: constant folding, dead code elimination, operator fusion, loop tiling, vectorization, memory layout optimization. Each pass has its own data structures, its own traversal logic, its own bugs.

Morok takes a different approach: one mechanism for everything.

Traditional Compiler:              Morok:
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”       β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  Constant Folding       β”‚       β”‚                         β”‚
β”‚  Dead Code Elimination  β”‚       β”‚   patterns! {           β”‚
β”‚  Loop Unrolling         β”‚       β”‚       Add[x, @zero] ~> xβ”‚
β”‚  Operator Fusion        β”‚       β”‚       Mul[x, @zero] ~> 0β”‚
β”‚  Vectorization          β”‚       β”‚       // ...more        β”‚
β”‚  Memory Planning        β”‚       β”‚   }                     β”‚
β”‚  ...20 more passes      β”‚       β”‚                         β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜       β”‚   graph_rewrite(...)    β”‚
     Custom logic each            β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
                                       One mechanism

Every optimization in Morok is expressed as a pattern: β€œwhen you see this structure, replace it with that structure.” The same graph_rewrite() function applies constant folding, converts movement ops to loops, optimizes memory access patterns, and lowers to hardware primitives.

This chapter explains how pattern-based optimization works and why it’s powerful.


The patterns! DSL

Morok provides a domain-specific language for writing optimization patterns. Here’s what it looks like:

#![allow(unused)]
fn main() {
patterns! {
    // Identity folding: x + 0 β†’ x
    Add[x, @zero] ~> |x| x.clone(),

    // Constant folding: 3 + 4 β†’ 7
    Add(a @const(a_val), b @const(b_val))
        => |a, a_val, b_val| eval_add(a_val, b_val).map(|r| UOp::const_(a.dtype(), r)),

    // Self-folding: x / x β†’ 1
    Idiv(x, x) ~> |x| UOp::one(x.dtype()),

    // Dead code elimination: if(true) { t } else { f } β†’ t
    Where(@true, t, _f) ~> |t| t.clone(),
}
}

The macro compiles these patterns into efficient Rust code. Let’s break down the syntax:

| Syntax | Meaning | Example |
|---|---|---|
| (x, y) | Ordered. Match in exact order. | Sub(x, @zero) ~> x |
| [x, y] | Commutative. Try both orderings. | Add[x, @zero] ~> x |
| @zero | Zero constant. Matches 0 or 0.0. | Mul[_, z @ @zero] ~> z |
| @one | One constant. Matches 1 or 1.0. | Mul[x, @one] ~> x |
| @const(val) | Extract constant. Binds the value. | Add(@const(a), @const(b)) |
| x, x | Same operand. Auto-generates ptr_eq check. | Idiv(x, x) ~> UOp::one(...) |
| ~> | Infallible. Always succeeds, returns Arc&lt;UOp&gt;. | Add[x, @zero] ~> x |
| => | Fallible. May fail, returns Option&lt;Arc&lt;UOp&gt;&gt;. | => eval(...).map(...) |
| for op in binary [...] | Template. Generate patterns for multiple ops. | See below |
| @context Type | Stateful. Access mutable context in patterns. | See below |

Template Expansion

Instead of writing the same pattern for every binary operation, use a for-loop:

#![allow(unused)]
fn main() {
patterns! {
    for op in binary [Add, Mul, Sub, Idiv, Fdiv, Max] {
        op(a @const(a_val), b @const(b_val))
            => |a, a_val, b_val| eval_binary(op, a_val, b_val)
                .map(|r| UOp::const_(a.dtype(), r))
    }
}
}

This expands to six separate patterns at compile timeβ€”one for each operation.

Stateful Patterns

Some optimizations need context (e.g., which kernel we’re in, what ranges are active). Declare a context type:

#![allow(unused)]
fn main() {
patterns! {
    @context KernelContext;

    ReduceAxis { src } => |reduce, src, ctx| {
        ctx.record_reduction(reduce);
        transform_reduce(reduce, src, ctx)
    }
}
}

The context is passed as the last argument to pattern closures.


How Pattern Matching Works

The patterns! macro generates a SimplifiedPatternMatcher that dispatches patterns in O(1) time.

The OpKey Index

Every UOp has an operation type (Add, Mul, Load, etc.). The #[derive(PatternEnum)] macro generates an OpKey enum that maps operations to hashable keys:

#![allow(unused)]
fn main() {
pub enum OpKey {
    Binary(BinaryOp),    // Add, Mul, Sub, ...
    Unary(UnaryOp),      // Neg, Sqrt, Exp, ...
    Ternary(TernaryOp),  // Where, MulAcc
    Const,
    Load,
    Store,
    // ... one variant per operation category
}
}

The Matcher Structure

#![allow(unused)]
fn main() {
pub struct SimplifiedPatternMatcher<C = ()> {
    indexed: HashMap<OpKey, Vec<PatternClosure<C>>>,  // O(1) lookup
    wildcards: Vec<PatternClosure<C>>,                 // patterns matching any op
}
}

When matching a UOp:

  1. Extract OpKey from the UOp’s operation
  2. Lookup in the HashMapβ€”O(1)
  3. Try each closure until one matches
  4. Fall back to wildcards if no indexed pattern matches

This is 5-10x faster than scanning all patterns linearly.

Commutative Handling

For patterns like Add[x, @zero], the macro generates code that tries both orderings:

#![allow(unused)]
fn main() {
// Try (x, @zero)
if let Some(result) = try_match_ordered(&children[0], &children[1]) {
    return result;
}
// Try (@zero, x)
if let Some(result) = try_match_ordered(&children[1], &children[0]) {
    return result;
}
}

Duplicate Detection

When you write Idiv(x, x), the pattern should only match if both operands are the same UOp (pointer equality, not structural equality). The macro automatically generates this check:

#![allow(unused)]
fn main() {
// Generated code for Idiv(x, x)
let x = &children[0];
let x_dup = &children[1];
if !Arc::ptr_eq(x, x_dup) {
    return NoMatch;
}
// ... rest of pattern
}

This leverages hash consingβ€”identical subexpressions share the same pointer.


The Rewrite Engine: Two-Stage Algorithm

Pattern matching alone isn’t enough. Consider this expression:

WHERE(Lt(3, 5), t, f)

To simplify it, we need two steps:

  1. Lt(3, 5) β†’ true (constant folding)
  2. WHERE(true, t, f) β†’ t (dead code elimination)

But the WHERE pattern won’t match until its child is simplified. The rewrite engine solves this with a two-stage algorithm.

Stage 0: Pattern Application

#![allow(unused)]
fn main() {
fn rewrite_stage0(&mut self, uop: &Arc<UOp>) -> RewriteResult {
    match self.matcher.try_match(uop) {
        Some(replacement) => RewriteResult::Rewritten(replacement),
        None => RewriteResult::Gate(uop.clone()),  // process children
    }
}
}

If no pattern matches, return Gateβ€”a signal to process children first.

Stage 1: Source Reconstruction

After children are rewritten, rebuild the node with new children and try patterns again:

#![allow(unused)]
fn main() {
fn rewrite_stage1(&mut self, uop: &Arc<UOp>, new_children: Vec<Arc<UOp>>) {
    // Rebuild with optimized children
    let rebuilt = uop.with_sources(new_children);

    // Try patterns againβ€”might match now!
    match self.matcher.try_match(&rebuilt) {
        Some(replacement) => replacement,
        None => rebuilt,
    }
}
}

The Magic: Cascading Optimizations

Stage 0: WHERE(Lt(3, 5), t, f)     β†’ Gate (no match, process children)
         └── Lt(3, 5)              β†’ true (constant folding matches!)

Stage 1: WHERE(true, t, f)         β†’ t (dead code elimination matches!)

The reconstruction stage re-applies patterns, enabling multi-step optimizations in a single traversal.

Safety Limits

To prevent infinite loops, the engine has limits:

  • 1000 iterations per node maximum
  • 100,000 iterations total maximum
  • Panics with diagnostic info if limits exceeded

In practice, well-formed patterns converge quickly.


The Full Optimization Pipeline

Pattern matching is one part of a larger pipeline. When you call tensor.realize(), here’s what happens:

Tensor.realize()
    β”‚
    β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  RANGEIFY                                               β”‚
β”‚  Convert movement ops (RESHAPE, PERMUTE, EXPAND)        β”‚
β”‚  into explicit RANGE loops with INDEX operations        β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    β”‚
    β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  KERNEL SPLITTING                                       β”‚
β”‚  Split computation graph at STORE boundaries            β”‚
β”‚  Each STORE becomes a separate kernel                   β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    β”‚
    β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  FOR EACH KERNEL:                                       β”‚
β”‚                                                         β”‚
β”‚  1. Symbolic Simplification (algebraic patterns)        β”‚
β”‚                                                         β”‚
β”‚  2. Scheduler Creation                                  β”‚
β”‚     └── Convert LOOP β†’ GLOBAL for GPU parallelization   β”‚
β”‚                                                         β”‚
β”‚  3. Kernel Optimization (heuristic OR beam search)      β”‚
β”‚     β”œβ”€β”€ Tensor Cores (WMMA) for matmul                  β”‚
β”‚     β”œβ”€β”€ Vectorization (UPCAST)                          β”‚
β”‚     β”œβ”€β”€ Loop Unrolling (UNROLL)                         β”‚
β”‚     β”œβ”€β”€ GPU Local Memory (LOCAL)                        β”‚
β”‚     β”œβ”€β”€ Grouped Reductions (GROUP)                      β”‚
β”‚     └── Threading (THREAD) for CPU                      β”‚
β”‚                                                         β”‚
β”‚  4. Post-Optimization Passes                            β”‚
β”‚     β”œβ”€β”€ Devectorize (memory coalescing)                 β”‚
β”‚     β”œβ”€β”€ Expand (UNROLL β†’ vector operations)             β”‚
β”‚     β”œβ”€β”€ FMA Decomposition (a*b+c β†’ MulAcc)              β”‚
β”‚     └── Bool Storage (cast bool↔uint8 for memory)       β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜
    β”‚
    β–Ό
β”Œβ”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”
β”‚  CODE GENERATION                                        β”‚
β”‚  Render optimized AST to LLVM IR, compile, execute      β”‚
β””β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”€β”˜

Each box uses pattern-based rewriting. The difference is which patterns are applied:

  • Rangeify: Movement op β†’ BUFFERIZE + INDEX patterns
  • Symbolic: Algebraic simplification patterns
  • Post-opt: Memory access optimization patterns

After symbolic simplification, each kernel needs scheduling decisions: how to tile loops, where to parallelize, whether to use tensor cores. Morok offers two strategies.

Heuristics (Default)

The heuristic optimizer applies optimizations in a fixed order:

#![allow(unused)]
fn main() {
pub fn hand_coded_optimizations(scheduler: &mut Scheduler) {
    // 1. Tensor cores (if matmul pattern detected)
    if let Some(tc) = detect_tensor_core_pattern(scheduler) {
        apply_tensor_core(scheduler, tc);
        return;  // TC handles everything
    }

    // 2. Grouped reductions (two-stage for large reductions)
    apply_grouped_reduction_if_needed(scheduler);

    // 3. Vectorization (UPCAST output dimensions)
    apply_upcast(scheduler, 4);

    // 4. GPU local memory (workgroup dimensions)
    apply_local_dims(scheduler);

    // 5. CPU threading
    apply_threading(scheduler);
}
}

Pros: Fast (~50ms per kernel), predictable, no hardware measurement needed.

Cons: May miss optimization opportunities, fixed heuristics don’t adapt to workload.

Beam Search (Optional)

For production workloads, beam search finds better schedules:

#![allow(unused)]
fn main() {
pub fn beam_search(scheduler: Scheduler, config: BeamConfig) -> Scheduler {
    let mut beam = vec![scheduler];

    for iteration in 0..config.max_iterations {
        let mut candidates = vec![];

        for state in &beam {
            // Generate all valid next actions
            for action in generate_actions(state) {
                if let Ok(next) = state.apply(action) {
                    candidates.push(next);
                }
            }
        }

        // Compile and time each candidate (in parallel, e.g. via rayon)
        let mut timed: Vec<_> = candidates.into_par_iter()
            .map(|c| { let t = measure_kernel_time(&c); (c, t) })
            .collect();

        // Keep top K by execution time
        timed.sort_by_key(|(_, time)| *time);
        beam = timed.into_iter()
            .take(config.beam_width)
            .map(|(c, _)| c)
            .collect();
    }

    beam.into_iter().next().unwrap()
}
}

The action space includes ~500 predefined actions:

  • UPCAST(axis, amount) β€” vectorize output dimension
  • UNROLL(axis, amount) β€” unroll reduction loop
  • LOCAL(axis, amount) β€” use GPU shared memory
  • GROUP(axis, amount) β€” two-stage reduction
  • THREAD(axis, amount) β€” CPU parallelization
  • SWAP(axis1, axis2) β€” reorder global dimensions

Pros: Finds near-optimal schedules, adapts to hardware.

Cons: Minutes per kernel (but results are cached by AST hash).
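The search itself can be demonstrated on a toy schedule space (everything here is a hypothetical stand-in: a "schedule" is a list of tile factors, and `cost` replaces real kernel timing):

```rust
// Toy cost: pretend the ideal schedule's tile factors multiply to 64.
fn cost(schedule: &[u32]) -> u64 {
    let p: u64 = schedule.iter().map(|&x| x as u64).product();
    (p as i64 - 64).unsigned_abs()
}

// Beam search: expand every state by every action, keep the top `width`.
fn beam_search(width: usize, iterations: usize) -> Vec<u32> {
    let actions = [2u32, 4, 8]; // toy action space: append one tile factor
    let mut beam: Vec<Vec<u32>> = vec![vec![]];
    for _ in 0..iterations {
        let mut candidates: Vec<Vec<u32>> = Vec::new();
        for state in &beam {
            for &a in &actions {
                let mut next = state.clone();
                next.push(a);
                candidates.push(next);
            }
        }
        // Keep top `width` by cost (a stand-in for wall-clock timing).
        candidates.sort_by_key(|c| cost(c));
        candidates.truncate(width);
        beam = candidates;
    }
    beam.into_iter().next().unwrap()
}

fn main() {
    let best = beam_search(4, 3);
    // After 3 steps the best schedule's factors multiply to exactly 64.
    assert_eq!(cost(&best), 0);
}
```

The real search differs in the obvious ways: candidates are compiled and timed on hardware, and results are cached by AST hash so the cost is paid once.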

Configuration

# Disable optimization (debugging)
MOROK_NOOPT=1 cargo run

# Enable beam search with width 8
MOROK_BEAM=8 cargo run

Or programmatically:

#![allow(unused)]
fn main() {
let config = OptimizerConfig::builder()
    .strategy(OptStrategy::Beam { width: 8 })
    .build();

tensor.realize_with(config)?;
}

Comparison: How Other Compilers Optimize

Different ML compilers take different approaches to optimization:

| Aspect | XLA | TVM/Ansor | Triton | Morok |
|---|---|---|---|---|
| Philosophy | Fixed heuristics | Search-based | Programmer-guided | Pattern-based |
| Fusion | Conservative rules | Tile-and-fuse | Block-level | Graph rewriting |
| Auto-tuning | None | Evolutionary + cost model | Grid search | Beam search |
| Tuning cost | 0 | Hours | Minutes | Minutes (cached) |
| Flexibility | Low | High | Medium | High |
| Transparency | Low (C++ passes) | Medium (Python) | Medium (DSL) | High (patterns!) |

XLA β€” Production Conservative

XLA uses fixed heuristics for fusion decisions. Safe and predictable, but leaves performance on the table. The fusion rules are hard-coded in C++β€”extending them requires deep compiler knowledge.

TVM/Ansor β€” Maximum Auto-Tuning

TVM separates what to compute from how to compute it. Ansor uses evolutionary search with a learned cost model to explore the schedule space. Can achieve best-in-class performance, but tuning takes hours per model.

Triton β€” Programmer-Guided

Triton exposes a Python-like DSL where you write blocked algorithms explicitly. The compiler handles register allocation and memory management. Good balance of control and automation, but requires GPU programming expertise.

Morok β€” Pattern Composition

Morok’s insight: express optimizations as composable patterns. Each pattern is local and verifiable. Complex optimizations emerge from composition. Beam search adds auto-tuning when needed, with results cached for reuse.


Why This Matters: Practical Benefits

Pattern-based optimization has concrete advantages for developers:

Debugging is direct. Patterns are readable code. Add a println! to any pattern to trace when it fires:

#![allow(unused)]
fn main() {
patterns! {
    Add[x, @zero] ~> |x| {
        println!("Folding add-zero: {:?}", x);
        x.clone()
    }
}
}

Extensibility is easy. Adding a custom optimization is two lines:

#![allow(unused)]
fn main() {
patterns! {
    // Your domain-specific optimization
    MyOp(x, y) if is_special_case(x, y) ~> transform(x, y)
}
}

No need to understand compiler internals, write visitors, or modify pass managers.

Correctness is local. Each pattern is a small theorem: β€œif this structure appears, replacing it with that structure preserves semantics.” Verify each pattern independently. Composition of correct patterns yields correct programs.

Performance is tunable. O(1) pattern dispatch is fast by default. Enable beam search for production workloads. Cache results by AST hashβ€”tune once, benefit forever.


The Deeper Insight

Pattern matching trades generality for composability.

A general-purpose optimization pass can do anythingβ€”but that’s exactly the problem. It’s hard to verify, hard to extend, hard to compose with other passes. Ordering matters. Interactions are subtle.

A pattern is constrained: it matches a specific structure and produces a specific replacement. But constraints enable composition. Run patterns in any orderβ€”the result converges to the same fixed point. Add new patterns without breaking existing ones. Delete patterns without cascading failures.

Each pattern is a theorem about semantic equivalence. The rewrite engine is a theorem prover, finding derivations from input to optimized output. Correctness follows from the correctness of individual steps.

This is the Unix philosophy applied to compilers: small, focused tools that compose. Pattern-based optimization won’t solve every problemβ€”but for the problems it solves, it solves them elegantly.

Op Bestiary: A Field Guide to UOp Operations

When debugging Morok IR dumps, you’ll encounter operations that aren’t obvious from their names. This chapter documents non-trivial operations with signatures, field explanations, and examples.

What’s covered: Operations that require explanationβ€”loop control, reductions, memory operations, kernel structure, vectorization, tensor cores.

What’s NOT covered: Trivial ALU operations (Add, Mul, Sqrt, etc.) that work exactly as you’d expect.


Loop Control: RANGE and END

RANGE β€” Loop Scope Opener

#![allow(unused)]
fn main() {
Range {
    end: Arc<UOp>,           // loop bound (exclusive)
    axis_id: AxisId,         // identifier for deduplication
    axis_type: AxisType,     // scheduling behavior
}
}

Fields:

| Field | Type | Purpose |
|---|---|---|
| `end` | `Arc<UOp>` | Upper bound (exclusive), typically a CONST |
| `axis_id` | `AxisId` | `Unrenumbered(n)` before kernel splitting, `Renumbered(n)` after |
| `axis_type` | `AxisType` | Determines how the loop is scheduled (see below) |

AxisType Hierarchy:

| Type | Priority | GPU Mapping | Purpose |
|---|---|---|---|
| Outer | -2 | β€” | Kernel boundary marker |
| Loop | -1 | for loop | Sequential iteration |
| Global | 0 | blockIdx | Grid parallelism |
| Thread | 0 | thread pool | CPU parallelism |
| Warp | 1 | warp/wavefront | Sub-group parallelism |
| Local | 2 | threadIdx | Workgroup parallelism |
| GroupReduce | 2 | shared memory | Two-stage reduction |
| Upcast | 3 | SIMD | Vectorization |
| Reduce | 4 | accumulator | Reduction dimension |
| Unroll | 5 | unrolled | Loop unrolling |

Priority determines loop nesting orderβ€”lower values are outer loops.

Example:

RANGE(end=128, axis_id=R0, type=Global)
└── CONST(128) : Index

END β€” Loop Scope Closer

#![allow(unused)]
fn main() {
End {
    computation: Arc<UOp>,              // value computed inside loop
    ranges: SmallVec<[Arc<UOp>; 4]>,    // ranges being closed
}
}

END closes one or more RANGE scopes and removes them from the active set. Multiple ranges can be closed simultaneously.

Example:

END
β”œβ”€β”€ STORE(...)           β€” computation
β”œβ”€β”€ RANGE(R0, Global)    β€” first range closed
└── RANGE(R1, Local)     β€” second range closed

Reduction: REDUCE vs REDUCE_AXIS

Two operations with similar names serve different purposes.

REDUCE_AXIS β€” Tensor Dimension Reduction (High-Level)

#![allow(unused)]
fn main() {
ReduceAxis {
    src: Arc<UOp>,           // input tensor
    reduce_op: ReduceOp,     // Add, Mul, Max, Min
    axes: Vec<usize>,        // axes to reduce
}
}

Used before rangeify. Operates on tensor dimensions like NumPy’s .sum(axis=0).

Example:

REDUCE_AXIS(Add, axes=[1])
└── BUFFER[10, 20] : Float32

This reduces a [10, 20] tensor to [10] by summing along axis 1.

REDUCE β€” Range Iteration Reduction (Low-Level)

#![allow(unused)]
fn main() {
Reduce {
    src: Arc<UOp>,                      // value to accumulate
    ranges: SmallVec<[Arc<UOp>; 4]>,    // ranges being reduced
    reduce_op: ReduceOp,                // Add, Mul, Max, Min
}
}

Used after rangeify. Accumulates values across RANGE iterations and closes the specified ranges.

ReduceOp Variants:

| Op | Identity | Operation | Tinygrad |
|---|---|---|---|
| Add | 0 | acc + value | βœ“ |
| Mul | 1 | acc * value | βœ“ |
| Max | -∞ | max(acc, value) | βœ“ |
| Min | +∞ | min(acc, value) | Morok-only |

Compatibility: Tinygrad’s spec restricts REDUCE_AXIS to {Add, Mul, Max}. Morok extends this with Min.

Example:

REDUCE(Add)
β”œβ”€β”€ MUL                      β€” value to accumulate
β”‚   β”œβ”€β”€ LOAD(A, ...)
β”‚   └── LOAD(B, ...)
└── RANGE(R2, Reduce)        β€” range being reduced
    └── CONST(64)

ALLREDUCE β€” Cross-Device Reduction

#![allow(unused)]
fn main() {
AllReduce {
    src: Arc<UOp>,           // local partial result
    device: Arc<UOp>,        // device specification
    reduce_op: ReduceOp,     // reduction operation
}
}

Performs distributed reduction across multiple devices. Used for multi-GPU training.


Buffer Operations

BUFFER β€” Buffer Declaration

#![allow(unused)]
fn main() {
Buffer {
    unique: Arc<UOp>,        // UNIQUE op for identity
    device: Arc<UOp>,        // DEVICE op
    size: usize,             // total element count
}
}

Declares a buffer for tensor storage. The unique field ensures distinct buffers even with identical size/device.

BUFFERIZE β€” Materialization Marker

#![allow(unused)]
fn main() {
Bufferize {
    compute: Arc<UOp>,                  // computation to materialize
    ranges: SmallVec<[Arc<UOp>; 4]>,    // output dimensions
    opts: BufferizeOpts,                // address space, device
}
}

Marks where computation should materialize to memory. Triggers kernel splitting.

BufferizeOpts:

| Field | Type | Purpose |
|---|---|---|
| `device` | `Option<DeviceSpec>` | Target device, `None` for local |
| `addrspace` | `AddrSpace` | `Global` (device) or `Local` (shared) |

Example:

BUFFERIZE(opts={addrspace=Global})
β”œβ”€β”€ REDUCE(Add, ...)         β€” computation
β”œβ”€β”€ RANGE(R0, Global)        β€” output dim 0
└── RANGE(R1, Global)        β€” output dim 1

INDEX β€” Multi-Dimensional Buffer Access

#![allow(unused)]
fn main() {
Index {
    buffer: Arc<UOp>,                   // BUFFER or DEFINE_GLOBAL
    indices: SmallVec<[Arc<UOp>; 4]>,   // index per dimension
    gate: Option<Arc<UOp>>,             // optional predicate
}
}

Computes memory address from multi-dimensional indices. Returns element dtype (not pointer).

Example:

INDEX : Float32
β”œβ”€β”€ DEFINE_GLOBAL(0)
β”œβ”€β”€ RANGE(R0, Global)        β€” index for dim 0
β”œβ”€β”€ RANGE(R1, Loop)          β€” index for dim 1
└── MUL(...)                 β€” index for dim 2

POINTER_INDEX β€” Low-Level Pointer Arithmetic

#![allow(unused)]
fn main() {
PointerIndex {
    ptr: Arc<UOp>,           // base pointer
    offset: Arc<UOp>,        // byte offset
}
}

Direct pointer arithmetic. Used after linearization when indices are flattened.

Compatibility: Tinygrad uses INDEX with a ptr=True flag instead of a separate operation.

LOAD β€” Memory Read

#![allow(unused)]
fn main() {
Load {
    buffer: Arc<UOp>,        // buffer or pointer
    index: Arc<UOp>,         // INDEX op
}
}

Read value from buffer at index. For gated loads, use an INDEX with a gate (INDEX has an optional gate field).

Example:

LOAD : Float32
β”œβ”€β”€ DEFINE_GLOBAL(1)
└── INDEX
    β”œβ”€β”€ DEFINE_GLOBAL(1)
    β”œβ”€β”€ RANGE(R0)
    └── RANGE(R2)

STORE β€” Memory Write

#![allow(unused)]
fn main() {
Store {
    buffer: Arc<UOp>,                   // output buffer
    index: Arc<UOp>,                    // INDEX op
    value: Arc<UOp>,                    // value to write
    ranges: SmallVec<[Arc<UOp>; 4]>,    // ranges being closed
}
}

Write value to buffer. STORE closes the specified ranges, which represent output iteration dimensions. The ranges field is used for output upcasting: when a Range(Upcast) is included, it becomes UNROLL during expansion, then contracted via CONTRACT.

For gated stores, use an INDEX with a gate (INDEX has an optional gate field).

Compatibility: Morok’s STORE has an explicit index field (sources: buffer=0, index=1, value=2, ranges=3+). Tinygrad’s STORE combines buffer and value differently (range_start=2).

Example:

STORE
β”œβ”€β”€ DEFINE_GLOBAL(0)         β€” output buffer
β”œβ”€β”€ INDEX[R0, R1]            β€” write address
β”œβ”€β”€ REDUCE(Add, ...)         β€” value
β”œβ”€β”€ RANGE(R0, Global)        β€” output dim 0 (closed)
└── RANGE(R1, Global)        β€” output dim 1 (closed)

Kernel Structure

KERNEL β€” Kernel Wrapper

#![allow(unused)]
fn main() {
Kernel {
    sources: SmallVec<[Arc<UOp>; 4]>,   // arguments
    ast: Arc<UOp>,                       // computation (usually SINK)
}
}

Wraps a complete kernel for code generation. Sources are kernel arguments (DefineGlobal, DefineLocal, DefineVar).

Example:

KERNEL
β”œβ”€β”€ DEFINE_GLOBAL(0)         β€” output buffer arg
β”œβ”€β”€ DEFINE_GLOBAL(1)         β€” input A arg
β”œβ”€β”€ DEFINE_GLOBAL(2)         β€” input B arg
└── SINK                     β€” computation
    └── STORE(...)

SINK β€” Multiple Root Collector

#![allow(unused)]
fn main() {
Sink {
    sources: SmallVec<[Arc<UOp>; 4]>,
}
}

Collects multiple outputs into a single root. Every kernel’s ast is typically a SINK containing STORE operations.

Example:

SINK
β”œβ”€β”€ STORE(output_0, ...)
β”œβ”€β”€ STORE(output_1, ...)
└── STORE(output_2, ...)

AFTER β€” Dependency Marker

#![allow(unused)]
fn main() {
After {
    passthrough: Arc<UOp>,              // value that flows through
    deps: SmallVec<[Arc<UOp>; 4]>,      // operations that must complete
}
}

Expresses execution dependencies between kernels without data dependency. The passthrough value is returned unchanged, but only after all deps complete.

Example:

SINK
β”œβ”€β”€ AFTER
β”‚   β”œβ”€β”€ DEFINE_GLOBAL(0)     β€” passthrough (buffer reference)
β”‚   └── KERNEL(...)          β€” must complete first
└── KERNEL(...)              β€” can use buffer after AFTER

BARRIER β€” Synchronization Fence

#![allow(unused)]
fn main() {
Barrier {
    src: Arc<UOp>,                      // value passing through
    deps: SmallVec<[Arc<UOp>; 4]>,      // operations to wait for
}
}

GPU workgroup synchronization. Ensures all threads in a workgroup reach the barrier before continuing.


Vector Operations

VECTORIZE β€” Create Vector from Scalars

#![allow(unused)]
fn main() {
Vectorize {
    elements: SmallVec<[Arc<UOp>; 4]>,
}
}

Combines N scalar values into a vector of size N. All elements must have the same base dtype.

Example:

VECTORIZE : <4 x Float32>
β”œβ”€β”€ CONST(1.0)
β”œβ”€β”€ CONST(2.0)
β”œβ”€β”€ CONST(3.0)
└── CONST(4.0)

GEP β€” Get Element Pointer (Vector Extract)

#![allow(unused)]
fn main() {
Gep {
    vector: Arc<UOp>,        // source vector
    indices: Vec<usize>,     // positions to extract
}
}

Extracts elements from a vector:

  • Single index β†’ scalar
  • Multiple indices β†’ smaller vector

Example:

GEP([0, 2]) : <2 x Float32>
└── VECTORIZE : <4 x Float32>
    └── ...

VConst β€” Vector Constant

#![allow(unused)]
fn main() {
VConst {
    values: Vec<ConstValue>,
}
}

Vector of compile-time constants. More efficient than VECTORIZE of CONST nodes.

CAT β€” Concatenate Vectors

#![allow(unused)]
fn main() {
Cat {
    sources: SmallVec<[Arc<UOp>; 4]>,
}
}

Concatenates vectors into a larger vector. Output vcount = sum of input vcounts.

Example:

CAT : <8 x Float32>
β”œβ”€β”€ VECTORIZE : <4 x Float32>
└── VECTORIZE : <4 x Float32>

PtrCat β€” Concatenate Pointers

#![allow(unused)]
fn main() {
PtrCat {
    sources: SmallVec<[Arc<UOp>; 4]>,
}
}

Groups memory accesses for vectorized load/store. Used by the devectorizer pass.


Expansion: UNROLL and CONTRACT

UNROLL β€” Expand Computation Across Iterations

#![allow(unused)]
fn main() {
Unroll {
    src: Arc<UOp>,                       // computation to expand
    unroll_axes: Vec<(usize, usize)>,    // (axis_index, factor) pairs
}
}

Creates multiple versions of computation for different iteration values. Used for loop unrolling optimization.

Example: UNROLL(unroll_axes=[(0, 4)]) expands computation 4 times with different index values.

CONTRACT β€” Collapse Unrolled Values to Vector

#![allow(unused)]
fn main() {
Contract {
    src: Arc<UOp>,                       // unrolled computation
    upcast_ranges: Vec<(usize, usize)>,  // (axis_index, factor) pairs
}
}

The inverse of UNROLLβ€”collects expanded scalar values into a vector. Output vector size = product of factors.

Example:

CONTRACT(upcast_ranges=[(0, 4)]) : <4 x Float32>
└── UNROLL(unroll_axes=[(0, 4)])
    └── LOAD(...)

This pattern vectorizes a load: expand 4 iterations, then pack results into a 4-element vector.
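The same transformation can be mimicked on the host (illustrative only, with a plain array standing in for `<4 x Float32>`):

```rust
// UNROLL then CONTRACT, by hand: expand four iterations of a load,
// then pack the four scalars into one 4-wide "vector".
fn unroll4_load(buffer: &[f32], base: usize) -> [f32; 4] {
    // UNROLL(unroll_axes=[(0, 4)]): four copies of the body, one per index.
    let e0 = buffer[base];
    let e1 = buffer[base + 1];
    let e2 = buffer[base + 2];
    let e3 = buffer[base + 3];
    // CONTRACT(upcast_ranges=[(0, 4)]): collect the scalars into a vector,
    // which a backend can lower to a single vectorized load.
    [e0, e1, e2, e3]
}

fn main() {
    let buffer = [10.0f32, 20.0, 30.0, 40.0, 50.0, 60.0, 70.0, 80.0];
    assert_eq!(unroll4_load(&buffer, 4), [50.0, 60.0, 70.0, 80.0]);
}
```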


Tensor Cores: WMMA

WMMA β€” Warp Matrix Multiply-Accumulate

#![allow(unused)]
fn main() {
Wmma {
    a: Arc<UOp>,             // matrix A fragment
    b: Arc<UOp>,             // matrix B fragment
    c: Arc<UOp>,             // accumulator C fragment
    metadata: WmmaMetadata,  // hardware configuration
}
}

Hardware tensor core operation: D = A Γ— B + C. Requires specific matrix shapes and data layouts.

WmmaMetadata Fields:

| Field | Type | Purpose |
|---|---|---|
| `name` | `String` | Instruction name (e.g., `"__hmma..."`) |
| `dims` | `(N, M, K)` | Matrix dimensions (e.g., `(16, 16, 16)`) |
| `dtype_in` | `DType` | Input matrix precision (e.g., Float16) |
| `dtype_out` | `DType` | Output precision (e.g., Float32) |
| `device` | `String` | Target device string |
| `threads` | `usize` | Threads per warp (typically 32) |
| `upcast_axes` | `Vec<(usize, usize)>` | Vectorization for output |
| `reduce_axes` | `Vec<(usize, usize)>` | Contraction axes |

Example:

WMMA(dims=(16, 16, 16), dtype_in=Float16, dtype_out=Float32)
β”œβ”€β”€ A fragment : <8 x Float16>
β”œβ”€β”€ B fragment : <8 x Float16>
└── C accumulator : <8 x Float32>

Control Flow

IF / ENDIF β€” Conditional Execution

#![allow(unused)]
fn main() {
If {
    condition: Arc<UOp>,                // boolean predicate
    body: SmallVec<[Arc<UOp>; 4]>,      // operations to execute
}

EndIf {
    if_op: Arc<UOp>,         // corresponding IF op
}
}

Execute body only when condition is true. Used for boundary checks and sparse operations.

Example:

IF
β”œβ”€β”€ LT(idx, bound)           β€” condition (src[0])
β”œβ”€β”€ STORE(...)               β€” body[0]
└── STORE(...)               β€” body[1]

ENDIF
└── IF(...)                  β€” references IF op

Definition Operations

DEFINE_GLOBAL β€” Device Memory Argument

#![allow(unused)]
fn main() {
DefineGlobal(usize)          // argument index
}

Kernel argument for device (global) memory. Index refers to position in kernel argument list.

DEFINE_LOCAL β€” Shared Memory Allocation

#![allow(unused)]
fn main() {
DefineLocal(usize)           // local memory index
}

GPU shared memory (LDS) allocation. Visible within a workgroup.

DEFINE_VAR β€” Symbolic Runtime Variable

#![allow(unused)]
fn main() {
DefineVar {
    name: String,            // variable name
    min_val: i64,            // minimum bound
    max_val: i64,            // maximum bound
}
}

Runtime variable with known bounds. Used for dynamic shapes where bounds are known.

Example:

DEFINE_VAR(name="batch_size", min=1, max=128) : Index
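Known bounds are what make symbolic variables useful: comparisons against them can fold at compile time even though the value is unknown. A sketch (illustrative, not Morok's API):

```rust
// A bounded symbolic variable, as in DEFINE_VAR.
struct Var { min_val: i64, max_val: i64 }

// Fold `var < c` when the bounds decide it either way.
fn fold_lt(var: &Var, c: i64) -> Option<bool> {
    if var.max_val < c {
        Some(true)          // always true for any runtime value
    } else if var.min_val >= c {
        Some(false)         // never true
    } else {
        None                // depends on the runtime value; keep the op
    }
}

fn main() {
    let batch = Var { min_val: 1, max_val: 128 };
    assert_eq!(fold_lt(&batch, 200), Some(true));
    assert_eq!(fold_lt(&batch, 1), Some(false));
    assert_eq!(fold_lt(&batch, 64), None);
}
```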

DEFINE_REG β€” Register Allocation

#![allow(unused)]
fn main() {
DefineReg {
    size: usize,             // register size
}
}

Allocates a register for intermediate storage. Used in code generation.

BIND β€” Variable Binding

#![allow(unused)]
fn main() {
Bind {
    var: Arc<UOp>,           // DEFINE_VAR
    value: Arc<UOp>,         // concrete value
}
}

Binds a symbolic variable to a concrete value at runtime.


Special Operations

SPECIAL β€” Hardware-Provided Values

#![allow(unused)]
fn main() {
Special {
    end: Arc<UOp>,           // upper bound for this dimension
    name: String,            // e.g., "blockIdx.x", "threadIdx.y"
}
}

Accesses hardware-provided values (thread/block indices). Not a loopβ€”the hardware provides the value directly.

Example:

SPECIAL(name="blockIdx.x", end=128) : Index
└── CONST(128)

UNIQUE β€” Identity Marker

#![allow(unused)]
fn main() {
Unique(usize)                // unique identifier
}

Creates a unique identity for buffer disambiguation. Two buffers with different UNIQUE values are distinct even if otherwise identical.

DEVICE β€” Device Specification

#![allow(unused)]
fn main() {
Device(DeviceSpec)           // device specification
}

Specifies target device for computation.


Movement Operations

High-level tensor shape transformations. These are converted to explicit INDEX operations during rangeify.

| Operation | Signature | Purpose |
|---|---|---|
| Reshape | `{ src, new_shape }` | Change shape, same elements |
| Permute | `{ src, axes: Vec<usize> }` | Transpose/reorder axes |
| Expand | `{ src, new_shape }` | Broadcast to larger shape |
| Pad | `{ src, begin_pads, end_pads }` | Add padding |
| Shrink | `{ src, begins, ends }` | Extract sub-region |
| Flip | `{ src, axes: Vec<bool> }` | Reverse along axes |

Example: RESHAPE

RESHAPE(new_shape=[6, 4]) : Shape[6, 4]
β”œβ”€β”€ BUFFER[2, 3, 4] : Float32
└── CONST([6, 4]) : Shape

Quick Reference

By Category

| Category | Operations |
|---|---|
| Loop Control | RANGE, END |
| Reduction | REDUCE_AXIS, REDUCE, ALLREDUCE |
| Memory | BUFFER, BUFFERIZE, INDEX, POINTER_INDEX, LOAD, STORE |
| Kernel | KERNEL, SINK, AFTER, BARRIER |
| Vector | VECTORIZE, GEP, VCONST, CAT, PTRCAT |
| Expansion | UNROLL, CONTRACT |
| Hardware | WMMA, SPECIAL |
| Control | IF, ENDIF |
| Definition | DEFINE_GLOBAL, DEFINE_LOCAL, DEFINE_VAR, DEFINE_REG, BIND, UNIQUE, DEVICE |
| Movement | RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, FLIP |
| ALU | Unary(...), Binary(...), Ternary(...), Cast, BitCast |

Range-Ending Operations

Operations that close RANGE scopes (remove ranges from active set):

| Operation | Range Start Index |
|---|---|
| BUFFERIZE | 1 (compute=0, ranges=1+) |
| REDUCE | 1 (src=0, ranges=1+) |
| STORE | 3 (buffer=0, index=1, value=2, ranges=3+) |
| WMMA | 3 (a=0, b=1, c=2) |
| END | 1 (computation=0, ranges=1+) |

Expandable Operations

Operations that propagate UNROLL through the computation graph:

  • ALU: Unary, Binary, Ternary
  • Type: Cast, BitCast
  • Vector: Gep, Vectorize
  • Memory: Load, Store, Index, PointerIndex
  • Control: Reduce, End, After
  • Buffer: Bufferize
  • Hardware: Wmma