Morok

⚠️ Pre-alpha software. APIs are unstable and may change without notice. Not recommended for production use. 🚧💀

Morok is a Rust-based ML compiler inspired by Tinygrad. It features lazy tensor evaluation on a UOp-based IR, pattern-driven optimization, and multi-backend code generation.

Highlights

| Feature | Description |
| --- | --- |
| Declarative Optimization | patterns! DSL for graph rewrites with Z3-verified correctness |
| Lazy Evaluation | Tensors build computation graphs, compiled only at realize() |
| CUDA Support | Unified memory, D2D copy, LRU buffer caching |
| Provenance Tracking | #[track_caller] traces every UOp to source location |
| 80+ IR Operations | Arithmetic, memory, control flow, WMMA tensor cores |
| 20+ Optimizations | Constant folding, tensor cores, vectorization, loop unrolling |

Quick Example

use morok_tensor::Tensor;

// Build lazy computation graph
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
let b = Tensor::from_slice(&[4.0, 5.0, 6.0], &[3])?;
let c = (a + b).sum();

// Compile and execute
let result = c.realize()?;

License

MIT

Execution Pipeline

From tensor definition to kernel execution, Morok follows a multi-stage compilation pipeline inspired by Tinygrad.

Stage Overview

Tensor API → UOp DAG → Rangeify → Kernel Split → Schedule → Codegen → Execute

Stage 0: Tensor Creation

Input: Rust slice or data
Output: Tensor { uop: Rc<UOp> }

let a = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;

The tensor creates a BUFFER UOp and registers the actual device buffer in a thread-local registry. No computation happens yet.

Stage 1: Lazy Operation Building

Input: Tensor operations
Output: UOp DAG (Directed Acyclic Graph)

let c = a.try_add(&b)?;  // Builds UOp graph, no execution
let d = c.sum();         // Adds REDUCE node to graph

Each operation appends nodes to the UOp graph. The graph captures:

  • Arithmetic operations (ADD, MUL, etc.)
  • Movement operations (RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, FLIP)
  • Reductions (SUM, MAX, etc.)
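
As a rough illustration of this laziness, here is a minimal standalone sketch. The `Node` enum and the `add`, `sum`, and `realize` functions are hypothetical stand-ins, not Morok's actual types. The only point is that building the graph performs no arithmetic until evaluation is requested explicitly.

```rust
use std::rc::Rc;

/// Toy graph node; a stand-in for a UOp, not Morok's real definition.
#[derive(Debug)]
enum Node {
    Buffer(Vec<f32>),
    Add(Rc<Node>, Rc<Node>),
    Sum(Rc<Node>),
}

/// Appends an ADD node; no arithmetic is performed here.
fn add(a: &Rc<Node>, b: &Rc<Node>) -> Rc<Node> {
    Rc::new(Node::Add(Rc::clone(a), Rc::clone(b)))
}

/// Appends a reduction node over its input.
fn sum(a: &Rc<Node>) -> Rc<Node> {
    Rc::new(Node::Sum(Rc::clone(a)))
}

/// Walks the graph and computes values only when asked.
fn realize(node: &Node) -> Vec<f32> {
    match node {
        Node::Buffer(data) => data.clone(),
        Node::Add(a, b) => {
            let (a, b) = (realize(a), realize(b));
            a.iter().zip(&b).map(|(x, y)| x + y).collect()
        }
        Node::Sum(a) => vec![realize(a).iter().sum()],
    }
}
```

Calling `sum(&add(&a, &b))` on two buffer nodes only allocates graph nodes; the element-wise addition and the reduction both run inside `realize`.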

Stage 2: Rangeify

Input: UOp DAG with movement ops
Output: UOp DAG with BUFFERIZE + INDEX + explicit RANGE loops

File: schedule/src/rangeify/transform.rs

This is the core transformation that converts high-level tensor operations into explicit loop nests:

  1. Range Assignment: Create RANGE UOps for each dimension
  2. Movement Op Transformation: Convert movement ops to index expressions
    • SHRINK → offset adjustment
    • PERMUTE → axis reordering
    • EXPAND → broadcast (index becomes 0)
    • PAD → conditional with WHERE
    • RESHAPE → axis combinatorics
  3. Buffer Simplification: Remove unnecessary intermediate buffers
  4. Symbolic Simplification: Apply algebraic identities
# Before rangeify:
BUFFER.reshape([2,3]).expand([4,2,3]).sum(axis=0)

# After rangeify:
RANGE(4) -> RANGE(2) -> RANGE(3) ->
  LOAD(buffer, index_expr) -> REDUCE(ADD) -> STORE
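
To make the index expression concrete, the following sketch hand-writes the loop nest for the example above in plain Rust. It is not output from Morok. It assumes a row-major layout for the reshaped `[2, 3]` view, and the function name is illustrative. The expanded axis gets stride 0, which is what "index becomes 0" means for EXPAND.

```rust
/// Evaluates BUFFER.reshape([2,3]).expand([4,2,3]).sum(axis=0) as explicit loops.
fn reshape_expand_sum(buffer: &[f32; 6]) -> [[f32; 3]; 2] {
    // Strides of the [4, 2, 3] view into the flat 6-element buffer.
    // The expanded axis has stride 0, so every i reads the same element.
    let strides = [0usize, 3, 1];
    let mut out = [[0.0f32; 3]; 2];
    for i in 0..4 {
        for j in 0..2 {
            for k in 0..3 {
                let index = i * strides[0] + j * strides[1] + k * strides[2];
                // LOAD -> REDUCE(ADD) -> STORE
                out[j][k] += buffer[index];
            }
        }
    }
    out
}
```

Each output element is the corresponding input element added four times, once per iteration of the broadcast axis.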

Stage 3: Kernel Splitting

Input: Rangeified UOp DAG
Output: Multiple KERNEL UOps

File: schedule/src/rangeify/pipeline.rs

Splits the graph at STORE boundaries into separate kernels:

  1. BUFFERIZE → STORE: Convert buffer materializations to explicit stores
  2. STORE → KERNEL: Group related stores into kernel operations
  3. Dependency Tracking: AFTER operations mark cross-kernel dependencies

Stage 4: Schedule Creation

Input: KERNEL UOps
Output: Vec<ScheduleItem> with ordered kernels

Collects kernels in execution order, gathering:

  • Kernel AST (the computation graph)
  • Buffer arguments (inputs/outputs)
  • Symbolic variable values
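
One way to picture "execution order" is a topological sort over cross-kernel dependencies. The sketch below is an assumption about the general technique, not a description of Morok's scheduler. Kernels are plain indices, `deps[i]` is a hypothetical list of the kernels that kernel `i` must wait for, and each dependency is assumed to be listed only once.

```rust
/// Orders kernels so that each one runs after everything it depends on.
/// `deps[i]` lists the indices of kernels that kernel `i` must wait for.
fn execution_order(deps: &[Vec<usize>]) -> Vec<usize> {
    // Number of unfinished dependencies per kernel.
    let mut pending: Vec<usize> = deps.iter().map(Vec::len).collect();
    // Kernels with no dependencies can run immediately.
    let mut ready: Vec<usize> = (0..deps.len()).filter(|&i| pending[i] == 0).collect();
    let mut order = Vec::with_capacity(deps.len());
    while let Some(done) = ready.pop() {
        order.push(done);
        // Release every kernel that was waiting on the one just scheduled.
        for (i, d) in deps.iter().enumerate() {
            if d.contains(&done) {
                pending[i] -= 1;
                if pending[i] == 0 {
                    ready.push(i);
                }
            }
        }
    }
    order
}
```

If the dependency graph contained a cycle, the returned order would be shorter than `deps.len()`, which a caller could treat as an error.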

Stage 5: Codegen

Input: Kernel AST
Output: LLVM IR

File: codegen/src/llvm/renderer.rs

Morok currently renders the kernel AST directly to LLVM IR. Tinygrad runs 16+ optimization passes at this stage; Morok is still developing its equivalents, including:

  • Range splitting/flattening
  • GPU dimension assignment
  • Load/store optimization
  • Devectorization
  • Control flow insertion

Stage 6: Execution

Input: LLVM IR
Output: Computed result

File: runtime/src/llvm.rs

The runtime JIT-compiles the LLVM IR and executes it with the buffer pointers.

Gap Analysis vs Tinygrad

| Stage | Morok Status | Missing |
| --- | --- | --- |
| Tensor Creation | Basic | numpy, URL, disk loading |
| Lazy Ops | Partial | Many advanced ops |
| Rangeify | 7 passes | 13+ in Tinygrad |
| Kernel Split | Basic | Dependency-aware BFS |
| Schedule | Basic | Memory planning, var tracking |
| Codegen | Direct render | 16+ optimization passes |
| Execute | LLVM only | CUDA, Metal, WebGPU |

IR Design Philosophy

Morok uses a single unified IR (UOp) for all compilation stages, following Tinygrad's design philosophy.

The UOp Structure

pub struct UOp {
    pub id: u64,                    // Unique stable ID
    pub(crate) op: Op,              // Operation (Rust tagged union)
    pub(crate) dtype: DType,        // Data type
    // Cached properties (computed lazily)
    pub(crate) shape_cache: OnceCell<...>,
    pub(crate) vmin_vmax_cache: OnceCell<...>,
}

Key fields:

  • op: An Op enum with 80+ operations spanning all abstraction levels
  • dtype: Type information (scalars, vectors, pointers)
  • id: Unique identifier for caching and provenance tracking
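
The "cached properties" fields rely on a compute-once cell. The sketch below shows that pattern using the standard library's `std::cell::OnceCell`; whether Morok uses this type or the `once_cell` crate is not stated here. The `Node` struct, its `numel` property, and the `computations` counter are all hypothetical and exist only to make the caching observable.

```rust
use std::cell::{Cell, OnceCell};

/// Toy node with one lazily computed property (not Morok's UOp).
struct Node {
    dims: Vec<usize>,
    numel_cache: OnceCell<usize>,
    computations: Cell<u32>, // counts how often the analysis actually runs
}

impl Node {
    fn new(dims: Vec<usize>) -> Self {
        Node { dims, numel_cache: OnceCell::new(), computations: Cell::new(0) }
    }

    /// Computes the element count on first use, then returns the cached value.
    fn numel(&self) -> usize {
        *self.numel_cache.get_or_init(|| {
            self.computations.set(self.computations.get() + 1);
            self.dims.iter().product()
        })
    }
}
```

Because `get_or_init` takes `&self`, the cache can be filled through a shared reference, which suits nodes that are shared behind `Rc`.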

Why One IR?

Traditional ML compilers use multiple IRs:

  • TensorFlow: Graph → XLA HLO → MLIR → LLVM IR
  • PyTorch: Python AST → TorchScript → FX Graph → Inductor IR → Triton
  • JAX: Python → Jaxpr → StableHLO → MHLO → platform IR

Morok/Tinygrad uses one IR that represents all abstraction levels:

Same UOp at Different Stages

# High-level (after tensor ops):
BUFFER.reshape([2,3]).reduce(ADD, axis=0)

# Mid-level (after rangeify):
RANGE(2) -> RANGE(3) -> LOAD -> REDUCE -> STORE -> END

# Low-level (after codegen passes):
DEFINE_GLOBAL -> INDEX -> LOAD -> ADD -> STORE

Enabling Factors

  1. Graph Rewriting as Universal Mechanism

    Pattern matching + rewriting handles all transformations:

    let optimized = graph_rewrite(&patterns, uop_graph, &mut ctx);
  2. Hash Consing (Structural Sharing)

    Identical subgraphs share memory via a thread-local cache:

    thread_local! {
        static CACHE: RefCell<HashMap<UOpKey, Weak<UOp>>> = ...;
    }

    Benefits:

    • O(1) equality checking (pointer comparison)
    • No duplicate subgraphs in memory
    • Pattern matching can use pointer identity
  3. Lazy Property Computation

    Expensive analyses computed once and cached:

    pub(crate) shape_cache: OnceCell<Result<Option<Shape>>>,
    pub(crate) vmin_vmax_cache: OnceCell<(ConstValue, ConstValue)>,
  4. Operation Hierarchy

    Ops organized by level to support progressive lowering:

    impl Op {
        pub fn is_movement(&self) -> bool { ... }
        pub fn is_buffer(&self) -> bool { ... }
        pub fn is_alu(&self) -> bool { ... }
    }
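
The hash-consing idea from item 2 can be sketched with a thread-local map of weak references. This is a toy under stated assumptions: `Node`, `Key`, `intern`, `constant`, and `add` are hypothetical names, and keying children by pointer address is one possible choice, not necessarily how Morok builds its `UOpKey`.

```rust
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::{Rc, Weak};

/// Structural key: children are identified by pointer address,
/// which is sound because children are themselves interned.
#[derive(Debug, PartialEq, Eq, Hash)]
enum Key {
    Const(i64),
    Add(usize, usize),
}

#[allow(dead_code)]
#[derive(Debug)]
enum Node {
    Const(i64),
    Add(Rc<Node>, Rc<Node>),
}

thread_local! {
    static CACHE: RefCell<HashMap<Key, Weak<Node>>> = RefCell::new(HashMap::new());
}

/// Returns the live node for `key` if one exists, otherwise builds and caches it.
fn intern(key: Key, build: impl FnOnce() -> Node) -> Rc<Node> {
    CACHE.with(|cache| {
        let mut cache = cache.borrow_mut();
        if let Some(existing) = cache.get(&key).and_then(Weak::upgrade) {
            return existing;
        }
        let node = Rc::new(build());
        cache.insert(key, Rc::downgrade(&node));
        node
    })
}

fn constant(v: i64) -> Rc<Node> {
    intern(Key::Const(v), || Node::Const(v))
}

fn add(a: &Rc<Node>, b: &Rc<Node>) -> Rc<Node> {
    let key = Key::Add(Rc::as_ptr(a) as usize, Rc::as_ptr(b) as usize);
    intern(key, || Node::Add(Rc::clone(a), Rc::clone(b)))
}
```

Building `add(&constant(1), &constant(2))` twice yields the same allocation, so `Rc::ptr_eq` is enough to decide structural equality. The cache holds `Weak` references so it does not keep unused nodes alive. This sketch never prunes dead entries from the map.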

Trade-offs

| Aspect | Single IR | Multi-IR |
| --- | --- | --- |
| Complexity | Lower | Higher |
| Translation bugs | None | Possible |
| Cross-level optimization | Natural | Requires bridging |
| Compile-time safety | Runtime checks | Per-IR guarantees |
| Codebase size | ~15k lines | 100k+ lines |

Morok vs Tinygrad UOp

| Aspect | Tinygrad | Morok |
| --- | --- | --- |
| Language | Python dataclass | Rust struct |
| Children | src: tuple[UOp, ...] | Encoded in Op variants |
| Type safety | Runtime | Compile-time |
| Extra data | arg: Any (untyped) | Typed per variant |
| Memory | Weakref + GC | Rc<UOp> + explicit cleanup |

Morok's Rust implementation adds compile-time guarantees:

// Each Op variant encodes its exact structure
Op::Binary(BinaryOp::Add, lhs, rhs)  // vs Tinygrad's (Ops.ADD, (lhs, rhs))
Op::Reduce { src, ranges, reduce_op } // Named fields, typed

Optimization System

Morok's optimizer is built on pattern matching and graph rewriting.

UPat: Universal Pattern

pub enum UPat {
    Match {
        op: Option<Vec<OpFilter>>,      // Operations to match
        dtype: Option<Vec<DType>>,      // Types to match
        src: Option<SrcPattern>,        // Child patterns
        arg: Option<ArgPattern>,        // Argument constraints
        name: Option<String>,           // Binding name
    },
    Any(Vec<UPat>),  // OR-pattern
}

Source Pattern Variants

pub enum SrcPattern {
    Tuple(Vec<UPat>),      // Fixed arity: Add(x, y)
    Repeat(Box<UPat>),     // All match: Sink(stores..)
    Fork(Vec<Vec<UPat>>),  // OR over arities
    Permute(Vec<UPat>),    // Commutative: Add[x, y]
}

Pattern Matching Algorithm

  1. Check operation type against OpFilter
  2. Check dtype against allowed list
  3. Check argument via ArgPredicate (IsZero, IsOne, etc.)
  4. Bind or verify named variables (pointer equality)
  5. Match sources recursively based on SrcPattern
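
Steps 4 and 5 can be sketched in isolation. In the toy matcher below, `Node`, `Pat`, `Bindings`, and `matches` are hypothetical and much simpler than `UPat`, and the op, dtype, and argument checks of steps 1 to 3 are omitted. The first occurrence of a name binds it; later occurrences must refer to the same node by pointer.

```rust
use std::collections::HashMap;
use std::rc::Rc;

#[allow(dead_code)]
#[derive(Debug)]
enum Node {
    Const(i64),
    And(Rc<Node>, Rc<Node>),
}

enum Pat {
    Var(&'static str), // binds a name, or re-checks an existing binding
    And(Box<Pat>, Box<Pat>),
}

type Bindings = HashMap<&'static str, Rc<Node>>;

fn matches(pat: &Pat, node: &Rc<Node>, binds: &mut Bindings) -> bool {
    match pat {
        Pat::Var(name) => {
            // Step 4: verify an existing binding by pointer equality...
            if let Some(bound) = binds.get(name) {
                return Rc::ptr_eq(bound, node);
            }
            // ...or bind the name on first sight.
            binds.insert(*name, Rc::clone(node));
            true
        }
        // Step 5: recurse into the sources.
        Pat::And(pl, pr) => match &**node {
            Node::And(l, r) => matches(pl, l, binds) && matches(pr, r, binds),
            _ => false,
        },
    }
}
```

With the pattern `And(x, x)`, a node whose two children are the same `Rc` matches, while two separately allocated but equal constants do not. This is why pointer-based binding checks pair naturally with hash consing.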

Commutative Matching

For Add[x, @zero], both orderings are tried:

// Fast path for binary (n=2)
if patterns[0].matches(children[0]) && patterns[1].matches(children[1]) { return true; }
if patterns[0].matches(children[1]) && patterns[1].matches(children[0]) { return true; }

PatternMatcher Indexing

Patterns are indexed by operation for O(1) lookup:

struct PatternMatcher<C> {
    patterns: Vec<(UPat, VarIntern, RewriteFn<C>)>,
    pdict: HashMap<OpKey, Vec<usize>>,  // op -> pattern indices
    wildcard_indices: Vec<usize>,       // patterns matching any op
}
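
A minimal sketch of how such an index might be built and queried follows. `PatternIndex`, the three-variant `OpKey`, and both methods are hypothetical simplifications. Merging the two lists in registration order is an assumption about pattern priority, not something the struct above guarantees.

```rust
use std::collections::HashMap;

#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
enum OpKey { Add, Mul, Where }

/// Toy index from operation to the patterns that could match it.
struct PatternIndex {
    pdict: HashMap<OpKey, Vec<usize>>,
    wildcard_indices: Vec<usize>,
}

impl PatternIndex {
    /// `patterns[i]` is `Some(op)` for an op-specific pattern, `None` for a wildcard.
    fn build(patterns: &[Option<OpKey>]) -> Self {
        let mut pdict: HashMap<OpKey, Vec<usize>> = HashMap::new();
        let mut wildcard_indices = Vec::new();
        for (i, op) in patterns.iter().enumerate() {
            match op {
                Some(op) => pdict.entry(*op).or_default().push(i),
                None => wildcard_indices.push(i),
            }
        }
        PatternIndex { pdict, wildcard_indices }
    }

    /// Candidate pattern indices for `op`, merged back into registration order.
    fn candidates(&self, op: OpKey) -> Vec<usize> {
        let mut out: Vec<usize> = self.pdict.get(&op).cloned().unwrap_or_default();
        out.extend(&self.wildcard_indices);
        out.sort_unstable();
        out
    }
}
```

Only the patterns registered for the node's operation, plus the wildcards, are ever tried. The hash lookup replaces a scan over the full pattern list.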

Rewrite Engine

Fixed-point iteration with a two-stage algorithm:

enum Stage { BottomUp, SourceReconstruction }

fn rewrite(&mut self, root: Rc<UOp>) -> Rc<UOp> {
    let mut stack = vec![(root, Stage::BottomUp)];
    while let Some((uop, stage)) = stack.pop() {
        match stage {
            Stage::BottomUp => {
                // Apply patterns, push children
            }
            Stage::SourceReconstruction => {
                // Rebuild with rewritten children
                // Apply patterns again (enables multi-stage opts)
            }
        }
    }
}

Multi-Stage Example

WHERE(Lt(3, 5), t, f)
  → [constant fold Lt] → WHERE(true, t, f)
  → [DCE] → t

The reconstruction stage re-applies patterns, enabling cascading optimizations.
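
The cascade can be reproduced with a small recursive rewriter. This sketch uses plain recursion rather than the explicit stack described above, and `Expr`, `apply_rules`, and `rewrite` are hypothetical names. It shows only why re-applying rules after children are rebuilt lets one fold enable the next.

```rust
#[derive(Debug, Clone, PartialEq)]
enum Expr {
    Int(i64),
    Bool(bool),
    Var(&'static str),
    Lt(Box<Expr>, Box<Expr>),
    Where(Box<Expr>, Box<Expr>, Box<Expr>),
}

/// Applies one local rule at the root, if any matches.
fn apply_rules(e: &Expr) -> Option<Expr> {
    match e {
        // Constant folding: Lt(c1, c2) -> bool
        Expr::Lt(a, b) => match (&**a, &**b) {
            (Expr::Int(x), Expr::Int(y)) => Some(Expr::Bool(x < y)),
            _ => None,
        },
        // Dead-branch elimination: WHERE(const, t, f) -> t or f
        Expr::Where(c, t, f) => match &**c {
            Expr::Bool(true) => Some((**t).clone()),
            Expr::Bool(false) => Some((**f).clone()),
            _ => None,
        },
        _ => None,
    }
}

/// Rewrites children first, rebuilds the node, then re-applies rules until none fire.
fn rewrite(e: &Expr) -> Expr {
    let rebuilt = match e {
        Expr::Lt(a, b) => Expr::Lt(Box::new(rewrite(a)), Box::new(rewrite(b))),
        Expr::Where(c, t, f) => Expr::Where(
            Box::new(rewrite(c)),
            Box::new(rewrite(t)),
            Box::new(rewrite(f)),
        ),
        leaf => leaf.clone(),
    };
    match apply_rules(&rebuilt) {
        Some(next) => rewrite(&next),
        None => rebuilt,
    }
}
```

On `WHERE(Lt(3, 5), t, f)`, the condition folds to `true` while the children are rewritten, and the rule applied to the rebuilt `WHERE` then selects `t`.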

The patterns! DSL

Proc-macro generates efficient Rust code:

let matcher = patterns! {
    // Commutative identity
    Add[x, @zero] ~> x,

    // Constant folding with for-loop
    for op in binary [Add, Mul, Sub] {
        op(a @const(av), b @const(bv))
          => eval_binary_op(op, av, bv).map(|r| UOp::const_(a.dtype(), r)),
    },

    // Self-pattern (auto ptr_eq)
    And(x, x) ~> Rc::clone(x),
};

Compile-Time Optimizations

  1. Variable Index Resolution: Names → u8 indices at macro expansion
  2. Duplicate Detection: Add(x, x) generates Rc::ptr_eq check
  3. Binding Storage: SmallVec<[(u8, Rc<UOp>); 4]> (stack-allocated for ≤4 bindings)

Optimization Categories

| Category | Patterns | Examples |
| --- | --- | --- |
| Constant Folding | 22 | Add(3, 5) → 8 |
| Identity | 8 | x + 0 → x, x * 1 → x |
| Zero Propagation | 4 | x * 0 → 0 |
| Self-Folding | 6 | x / x → 1, x & x → x |
| ALU Folding | 4 | (x + c1) + c2 → x + (c1+c2) |
| Division | 5 | (a*b)/b → a |
| DCE | 6 | WHERE(true, t, f) → t |
| Tensor Core | 3 | TC matching, swizzle, apply |
| Vectorization | - | Upcasting to float4, etc. |
| Loop Unrolling | - | Reductions ≤ 32 |