# Morok

⚠️ Pre-alpha software. APIs are unstable and may change without notice. Not recommended for production use. 🚧

Morok is a Rust-based ML compiler inspired by Tinygrad. It features lazy tensor evaluation with a UOp-based IR, pattern-driven optimization, and multi-backend code generation.
## Highlights
| Feature | Description |
|---|---|
| Declarative Optimization | `patterns!` DSL for graph rewrites with Z3-verified correctness |
| Lazy Evaluation | Tensors build computation graphs, compiled only at `realize()` |
| CUDA Support | Unified memory, D2D copy, LRU buffer caching |
| Provenance Tracking | `#[track_caller]` traces every UOp to its source location |
| 80+ IR Operations | Arithmetic, memory, control flow, WMMA tensor cores |
| 20+ Optimizations | Constant folding, tensor cores, vectorization, loop unrolling |
## Quick Example
```rust
use morok_tensor::Tensor;

// Build a lazy computation graph
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
let b = Tensor::from_slice(&[4.0, 5.0, 6.0], &[3])?;
let c = (a + b).sum();

// Compile and execute
let result = c.realize()?;
```
## License
MIT
# Execution Pipeline

From tensor definition to kernel execution, Morok follows a multi-stage compilation pipeline inspired by Tinygrad.
## Stage Overview

Tensor API → UOp DAG → Rangeify → Kernel Split → Schedule → Codegen → Execute
## Stage 0: Tensor Creation

**Input:** Rust slice or data
**Output:** `Tensor { uop: Rc<UOp> }`
```rust
let a = Tensor::from_slice(&[1.0, 2.0, 3.0], &[3])?;
```
The tensor creates a BUFFER UOp and registers the actual device buffer in a thread-local registry. No computation happens yet.
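As a rough illustration of that registry, here is a minimal sketch; `DeviceBuffer`, `BUFFER_REGISTRY`, and `register_buffer` are hypothetical names, not Morok's actual API:

```rust
use std::cell::RefCell;
use std::collections::HashMap;
use std::rc::Rc;

// Hypothetical device-buffer handle; the real type lives in the runtime crate.
struct DeviceBuffer {
    ptr: *mut u8,
    len: usize,
}

thread_local! {
    // Maps a BUFFER UOp's stable id to its live allocation.
    static BUFFER_REGISTRY: RefCell<HashMap<u64, Rc<DeviceBuffer>>> =
        RefCell::new(HashMap::new());
}

// Hypothetical helper called when Tensor::from_slice creates a BUFFER UOp.
fn register_buffer(uop_id: u64, buf: Rc<DeviceBuffer>) {
    BUFFER_REGISTRY.with(|r| r.borrow_mut().insert(uop_id, buf));
}
```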
## Stage 1: Lazy Operation Building

**Input:** Tensor operations
**Output:** UOp DAG (directed acyclic graph)
```rust
let c = a.try_add(&b)?; // Builds UOp graph, no execution
let d = c.sum();        // Adds a REDUCE node to the graph
```
Each operation appends nodes to the UOp graph. The graph captures:
- Arithmetic operations (ADD, MUL, etc.)
- Movement operations (RESHAPE, PERMUTE, EXPAND, PAD, SHRINK, FLIP)
- Reductions (SUM, MAX, etc.)
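The real `Op` enum is much larger, but based on the `Op::Binary` and `Op::Reduce` variants shown later in this document, these categories plausibly map onto variants along these lines (an illustrative subset, not the actual definition):

```rust
use std::rc::Rc;

struct UOp { id: u64, op: Op }

enum BinaryOp { Add, Mul }
enum ReduceOp { Sum, Max }

// Illustrative subset only; the real enum has 80+ variants.
enum Op {
    // Arithmetic
    Binary(BinaryOp, Rc<UOp>, Rc<UOp>),
    // Movement ops describe shape changes without touching data
    Reshape { src: Rc<UOp>, shape: Vec<usize> },
    Permute { src: Rc<UOp>, axes: Vec<usize> },
    // Reductions collapse ranges with an accumulator op
    Reduce { src: Rc<UOp>, ranges: Vec<Rc<UOp>>, reduce_op: ReduceOp },
}
```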
## Stage 2: Rangeify

**Input:** UOp DAG with movement ops
**Output:** UOp DAG with BUFFERIZE + INDEX + explicit RANGE loops

File: `schedule/src/rangeify/transform.rs`
This is the core transformation that converts high-level tensor operations into explicit loop nests:
- **Range Assignment**: Create RANGE UOps for each dimension
- **Movement Op Transformation**: Convert movement ops to index expressions
  - SHRINK → offset adjustment
  - PERMUTE → axis reordering
  - EXPAND → broadcast (index becomes 0)
  - PAD → conditional with WHERE
  - RESHAPE → axis combinatorics
- **Buffer Simplification**: Remove unnecessary intermediate buffers
- **Symbolic Simplification**: Apply algebraic identities
```
# Before rangeify:
BUFFER.reshape([2,3]).expand([4,2,3]).sum(axis=0)

# After rangeify:
RANGE(4) -> RANGE(2) -> RANGE(3) ->
  LOAD(buffer, index_expr) -> REDUCE(ADD) -> STORE
```
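To make the movement-op rewrites concrete, here is a toy sketch of the index arithmetic behind SHRINK and EXPAND; `IndexExpr` and both helpers are hypothetical stand-ins for the UOp index expressions the real pass builds:

```rust
// Hypothetical symbolic index expression (the real code builds UOp nodes).
#[derive(Clone, Debug, PartialEq)]
enum IndexExpr {
    Var(&'static str), // a RANGE loop variable
    Const(i64),
    Add(Box<IndexExpr>, Box<IndexExpr>),
}

// SHRINK(begin..end) on an axis: reading shrunk[i] reads original[i + begin].
fn shrink_index(idx: IndexExpr, begin: i64) -> IndexExpr {
    IndexExpr::Add(Box::new(idx), Box::new(IndexExpr::Const(begin)))
}

// EXPAND broadcasts a size-1 axis: every output index reads source index 0.
fn expand_index(_idx: IndexExpr) -> IndexExpr {
    IndexExpr::Const(0)
}

fn main() {
    let i = IndexExpr::Var("ridx0");
    assert_eq!(
        shrink_index(i.clone(), 2),
        IndexExpr::Add(Box::new(i), Box::new(IndexExpr::Const(2)))
    );
    assert_eq!(expand_index(IndexExpr::Var("ridx1")), IndexExpr::Const(0));
}
```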
## Stage 3: Kernel Splitting

**Input:** Rangeified UOp DAG
**Output:** Multiple KERNEL UOps

File: `schedule/src/rangeify/pipeline.rs`
Splits the graph at STORE boundaries into separate kernels:
- BUFFERIZE → STORE: Convert buffer materializations to explicit stores
- STORE → KERNEL: Group related stores into kernel operations
- Dependency Tracking: AFTER operations mark cross-kernel dependencies
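A minimal sketch of the STORE-boundary split over toy types; the real pass in `pipeline.rs` also rewrites BUFFERIZE first and records AFTER dependencies:

```rust
use std::rc::Rc;

// Toy node: just enough structure to show the split.
struct UOp {
    op: &'static str,
    src: Vec<Rc<UOp>>,
}

struct Kernel {
    root: Rc<UOp>,
}

// Each STORE hanging off the sink becomes the root of its own kernel.
fn split_kernels(sink: &UOp) -> Vec<Kernel> {
    sink.src
        .iter()
        .filter(|u| u.op == "STORE")
        .map(|store| Kernel { root: Rc::clone(store) })
        .collect()
}
```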
## Stage 4: Schedule Creation

**Input:** KERNEL UOps
**Output:** `Vec<ScheduleItem>` with ordered kernels
Collects kernels in execution order, gathering:
- Kernel AST (the computation graph)
- Buffer arguments (inputs/outputs)
- Symbolic variable values
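A plausible shape for `ScheduleItem`, mirroring the three pieces listed above; the field names and types here are assumptions:

```rust
use std::collections::HashMap;
use std::rc::Rc;

// Stand-ins for the real IR and runtime types.
struct UOp;
struct Buffer;

// One execution-ordered entry in the schedule.
struct ScheduleItem {
    ast: Rc<UOp>,                     // the kernel's computation graph
    buffers: Vec<Rc<Buffer>>,         // input/output buffer arguments
    var_values: HashMap<String, i64>, // values bound to symbolic variables
}
```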
## Stage 5: Codegen

**Input:** Kernel AST
**Output:** LLVM IR

File: `codegen/src/llvm/renderer.rs`

Morok currently renders directly to LLVM IR. At this stage Tinygrad applies 16+ optimization passes that Morok is still developing:
- Range splitting/flattening
- GPU dimension assignment
- Load/store optimization
- Devectorization
- Control flow insertion
## Stage 6: Execution

**Input:** LLVM IR
**Output:** Computed result

File: `runtime/src/llvm.rs`

JIT-compiles the LLVM IR and executes it with the buffer pointers.
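In outline, the final call looks something like the sketch below; the kernel signature and helper are hypothetical, since the real signature depends on what the renderer emits:

```rust
// Hypothetical signature of a JIT-compiled kernel: the renderer decides
// the parameter list (here: one output and two input buffers).
type KernelFn = unsafe extern "C" fn(*mut f32, *const f32, *const f32);

// Invoking the looked-up symbol with raw buffer pointers.
fn run_kernel(kernel: KernelFn, out: &mut [f32], a: &[f32], b: &[f32]) {
    // SAFETY: the compiled kernel must match KernelFn's signature
    // and must stay within the bounds of the three buffers.
    unsafe { kernel(out.as_mut_ptr(), a.as_ptr(), b.as_ptr()) }
}
```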
## Gap Analysis vs Tinygrad
| Stage | Morok Status | Missing |
|---|---|---|
| Tensor Creation | Basic | numpy, URL, disk loading |
| Lazy Ops | Partial | Many advanced ops |
| Rangeify | 7 passes | 13+ in Tinygrad |
| Kernel Split | Basic | Dependency-aware BFS |
| Schedule | Basic | Memory planning, var tracking |
| Codegen | Direct render | 16+ optimization passes |
| Execute | LLVM only | CUDA, Metal, WebGPU |
# IR Design Philosophy

Morok uses a single unified IR (UOp) for all compilation stages, following Tinygrad's design philosophy.
## The UOp Structure
```rust
pub struct UOp {
    pub id: u64,             // Unique stable ID
    pub(crate) op: Op,       // Operation (Rust tagged union)
    pub(crate) dtype: DType, // Data type

    // Cached properties (computed lazily)
    pub(crate) shape_cache: OnceCell<...>,
    pub(crate) vmin_vmax_cache: OnceCell<...>,
}
```
Key fields:
- `op`: An `Op` enum with 80+ operations spanning all abstraction levels
- `dtype`: Type information (scalars, vectors, pointers)
- `id`: Unique identifier for caching and provenance tracking
## Why One IR?
Traditional ML compilers use multiple IRs:
- TensorFlow: Graph → XLA HLO → MLIR → LLVM IR
- PyTorch: Python AST → TorchScript → FX Graph → Inductor IR → Triton
- JAX: Python → Jaxpr → StableHLO → MHLO → platform IR
Morok/Tinygrad uses one IR that represents all abstraction levels:
### Same UOp at Different Stages
```
# High-level (after tensor ops):
BUFFER.reshape([2,3]).reduce(ADD, axis=0)

# Mid-level (after rangeify):
RANGE(2) -> RANGE(3) -> LOAD -> REDUCE -> STORE -> END

# Low-level (after codegen passes):
DEFINE_GLOBAL -> INDEX -> LOAD -> ADD -> STORE
```
## Enabling Factors
1. **Graph Rewriting as Universal Mechanism**

   Pattern matching plus rewriting handles all transformations:

   ```rust
   let optimized = graph_rewrite(&patterns, uop_graph, &mut ctx);
   ```

2. **Hash Consing (Structural Sharing)**

   Identical subgraphs share memory via a thread-local cache:

   ```rust
   thread_local! {
       static CACHE: RefCell<HashMap<UOpKey, Weak<UOp>>> = ...;
   }
   ```

   Benefits:
   - O(1) equality checking (pointer comparison)
   - No duplicate subgraphs in memory
   - Pattern matching can use pointer identity

3. **Lazy Property Computation**

   Expensive analyses are computed once and cached:

   ```rust
   pub(crate) shape_cache: OnceCell<Result<Option<Shape>>>,
   pub(crate) vmin_vmax_cache: OnceCell<(ConstValue, ConstValue)>,
   ```

4. **Operation Hierarchy**

   Ops are organized by level to support progressive lowering:

   ```rust
   impl Op {
       pub fn is_movement(&self) -> bool { ... }
       pub fn is_buffer(&self) -> bool { ... }
       pub fn is_alu(&self) -> bool { ... }
   }
   ```
## Trade-offs
| Aspect | Single IR | Multi-IR |
|---|---|---|
| Complexity | Lower | Higher |
| Translation bugs | None | Possible |
| Cross-level optimization | Natural | Requires bridging |
| Compile-time safety | Runtime checks | Per-IR guarantees |
| Codebase size | ~15k lines | 100k+ lines |
## Morok vs Tinygrad UOp
| Aspect | Tinygrad | Morok |
|---|---|---|
| Language | Python dataclass | Rust struct |
| Children | `src: tuple[UOp, ...]` | Encoded in `Op` variants |
| Type safety | Runtime | Compile-time |
| Extra data | `arg: Any` (untyped) | Typed per variant |
| Memory | Weakref + GC | `Rc<UOp>` + explicit cleanup |
Morok's Rust implementation adds compile-time guarantees:
```rust
// Each Op variant encodes its exact structure
Op::Binary(BinaryOp::Add, lhs, rhs)   // vs Tinygrad's (Ops.ADD, (lhs, rhs))
Op::Reduce { src, ranges, reduce_op } // Named fields, typed
```
# Optimization System

Morok's optimizer is built on pattern matching and graph rewriting.
## UPat: Universal Pattern
```rust
pub enum UPat {
    Match {
        op: Option<Vec<OpFilter>>, // Operations to match
        dtype: Option<Vec<DType>>, // Types to match
        src: Option<SrcPattern>,   // Child patterns
        arg: Option<ArgPattern>,   // Argument constraints
        name: Option<String>,      // Binding name
    },
    Any(Vec<UPat>),                // OR-pattern
}
```
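For example, the identity pattern `Add[x, @zero]` could be built by hand against this enum as shown below; `OpFilter::Binary`, `ArgPattern::Predicate`, and `ArgPredicate::IsZero` are assumed variant names (the `patterns!` DSL described later generates such values for you):

```rust
// Builds "ADD[x, @zero]": an ADD whose operands, in either order, are
// some bound node `x` and a zero constant. Variant names for OpFilter,
// ArgPattern, and ArgPredicate are assumptions for illustration.
let x = UPat::Match {
    op: None,
    dtype: None,
    src: None,
    arg: None,
    name: Some("x".to_string()),
};
let zero = UPat::Match {
    op: None,
    dtype: None,
    src: None,
    arg: Some(ArgPattern::Predicate(ArgPredicate::IsZero)),
    name: None,
};
let add_x_zero = UPat::Match {
    op: Some(vec![OpFilter::Binary(BinaryOp::Add)]),
    dtype: None,
    src: Some(SrcPattern::Permute(vec![x, zero])), // commutative match
    arg: None,
    name: None,
};
```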
## Source Pattern Variants
```rust
pub enum SrcPattern {
    Tuple(Vec<UPat>),     // Fixed arity: Add(x, y)
    Repeat(Box<UPat>),    // All match: Sink(stores..)
    Fork(Vec<Vec<UPat>>), // OR over arities
    Permute(Vec<UPat>),   // Commutative: Add[x, y]
}
```
## Pattern Matching Algorithm
1. Check the operation type against `OpFilter`
2. Check the dtype against the allowed list
3. Check the argument via `ArgPredicate` (`IsZero`, `IsOne`, etc.)
4. Bind or verify named variables (pointer equality)
5. Match sources recursively based on `SrcPattern`
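A self-contained toy matcher that makes these steps concrete; the types are heavily simplified, step 2 (dtype) is omitted, and all names are illustrative:

```rust
use std::rc::Rc;

// Toy node and pattern types, trimmed to make the steps visible.
struct UOp {
    op: &'static str,
    src: Vec<Rc<UOp>>,
    arg: Option<i64>,
}

struct Pat {
    op: Option<&'static str>, // step 1: operation filter
    arg_is_zero: bool,        // step 3: argument predicate
    name: Option<usize>,      // step 4: binding slot
    src: Vec<Pat>,            // step 5: child patterns (empty = wildcard)
}

fn matches(p: &Pat, u: &Rc<UOp>, bind: &mut Vec<Option<Rc<UOp>>>) -> bool {
    if p.op.map_or(false, |op| op != u.op) {
        return false; // step 1: wrong operation
    }
    if p.arg_is_zero && u.arg != Some(0) {
        return false; // step 3: argument predicate failed
    }
    if let Some(slot) = p.name {
        if let Some(prev) = bind[slot].clone() {
            if !Rc::ptr_eq(&prev, u) {
                return false; // step 4: repeated name must be the same node
            }
        } else {
            bind[slot] = Some(Rc::clone(u)); // step 4: first occurrence binds
        }
    }
    if !p.src.is_empty() {
        if p.src.len() != u.src.len() {
            return false;
        }
        // step 5: recurse into children
        return p.src.iter().zip(&u.src).all(|(cp, cu)| matches(cp, cu, bind));
    }
    true
}
```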
### Commutative Matching

For `Add[x, @zero]`, both orderings are tried:
```rust
// Fast path for binary (n = 2)
if patterns[0].matches(children[0]) && patterns[1].matches(children[1]) { return true; }
if patterns[0].matches(children[1]) && patterns[1].matches(children[0]) { return true; }
```
## PatternMatcher Indexing

Patterns are indexed by operation for O(1) lookup:
```rust
struct PatternMatcher<C> {
    patterns: Vec<(UPat, VarIntern, RewriteFn<C>)>,
    pdict: HashMap<OpKey, Vec<usize>>, // op -> pattern indices
    wildcard_indices: Vec<usize>,      // patterns matching any op
}
```
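A self-contained toy of the dispatch, with patterns reduced to plain functions; only rules registered for the node's op, plus wildcards, are attempted:

```rust
use std::collections::HashMap;
use std::rc::Rc;

struct UOp {
    op: &'static str,
}

// Rules reduced to plain functions for this sketch.
type RewriteFn = fn(&Rc<UOp>) -> Option<Rc<UOp>>;

struct PatternMatcher {
    patterns: Vec<RewriteFn>,
    pdict: HashMap<&'static str, Vec<usize>>, // op name -> pattern indices
    wildcard_indices: Vec<usize>,
}

impl PatternMatcher {
    // Dispatch cost does not grow with the total number of rules:
    // only this op's bucket and the wildcard bucket are scanned.
    fn rewrite_once(&self, uop: &Rc<UOp>) -> Option<Rc<UOp>> {
        let indexed = self.pdict.get(uop.op).into_iter().flatten();
        for &i in indexed.chain(self.wildcard_indices.iter()) {
            if let Some(rewritten) = (self.patterns[i])(uop) {
                return Some(rewritten);
            }
        }
        None
    }
}
```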
## Rewrite Engine

Fixed-point iteration with a two-stage algorithm:
```rust
enum Stage { BottomUp, SourceReconstruction }

fn rewrite(&mut self, root: Rc<UOp>) -> Rc<UOp> {
    let mut stack = vec![(root, Stage::BottomUp)];
    while let Some((uop, stage)) = stack.pop() {
        match stage {
            Stage::BottomUp => {
                // Apply patterns, push children
            }
            Stage::SourceReconstruction => {
                // Rebuild with rewritten children
                // Apply patterns again (enables multi-stage opts)
            }
        }
    }
}
```
### Multi-Stage Example
```
WHERE(Lt(3, 5), t, f)
  → [constant fold Lt] → WHERE(true, t, f)
  → [DCE] → t
```
The reconstruction stage re-applies patterns, enabling cascading optimizations.
## The `patterns!` DSL

A proc-macro generates efficient Rust code:
```rust
let matcher = patterns! {
    // Commutative identity
    Add[x, @zero] ~> x,

    // Constant folding with a for-loop
    for op in binary [Add, Mul, Sub] {
        op(a @const(av), b @const(bv))
            => eval_binary_op(op, av, bv).map(|r| UOp::const_(a.dtype(), r)),
    },

    // Self-pattern (auto ptr_eq)
    And(x, x) ~> Rc::clone(x),
};
```
### Compile-Time Optimizations

- **Variable Index Resolution**: Names → u8 indices at macro expansion
- **Duplicate Detection**: `Add(x, x)` generates an `Rc::ptr_eq` check
- **Binding Storage**: `SmallVec<[(u8, Rc<UOp>); 4]>` (stack-allocated for ≤4 bindings)
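For instance, the duplicate-detection point means a rule like `And(x, x) ~> Rc::clone(x)` expands to roughly the following shape (illustrative, not the macro's literal output):

```rust
use std::rc::Rc;

struct UOp;

// Roughly what the macro emits for And(x, x): the second occurrence of
// `x` compiles to a pointer-equality check rather than a second binding.
fn and_self_fold(lhs: &Rc<UOp>, rhs: &Rc<UOp>) -> Option<Rc<UOp>> {
    if Rc::ptr_eq(lhs, rhs) {
        Some(Rc::clone(lhs))
    } else {
        None
    }
}
```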
## Optimization Categories
| Category | Patterns | Examples |
|---|---|---|
| Constant Folding | 22 | `Add(3, 5) → 8` |
| Identity | 8 | `x + 0 → x`, `x * 1 → x` |
| Zero Propagation | 4 | `x * 0 → 0` |
| Self-Folding | 6 | `x / x → 1`, `x & x → x` |
| ALU Folding | 4 | `(x + c1) + c2 → x + (c1 + c2)` |
| Division | 5 | `(a*b)/b → a` |
| DCE | 6 | `WHERE(true, t, f) → t` |
| Tensor Core | 3 | TC matching, swizzle, apply |
| Vectorization | - | Upcasting to float4, etc. |
| Loop Unrolling | - | Reductions ≤ 32 |