Перейти к основному содержимому

Паттерны алгебраического упрощения

Символьный упрощитель Morok перезаписывает вычислительные графы UOp с помощью 140+ алгебраических паттернов, определённых в schedule/src/symbolic/patterns.rs. Эти паттерны срабатывают в нескольких точках пайплайна:

ГдеМатчерКонтекст
Пре-оптимизацияsymbolic()После rangeify + range splitting, перед оптимизацией ядра
Пост-оптимизация (Стадия 8)symbolic()После оптимизационных действий, перед раскрытием
Пост-индекс (Стадия 16)symbolic()После снижения типа индексов, финальная очистка
Декомпозиция+Рендер (Стадия 18-19)symbolic_simple()Совместно с поздними перезаписями и паттернами рендера

symbolic() = symbolic_simple() + паттерны проталкивания GEP. Все стадии, кроме финального прохода декомпозиция+рендер, запускают полный набор symbolic().

Анализ диапазонов: Каждый UOp отслеживает минимальное (vmin) и максимальное (vmax) значения, которые он может принять во время выполнения. Эти границы вычисляются жадно при создании узла на основе границ входов. Многие паттерны используют эти границы для доказательства условий на этапе компиляции (например, «x всегда неотрицателен» или «x < n для всех значений»).

Нотация: OP[a, b] обозначает коммутативный паттерн (пробуются оба порядка операндов). OP(a, b) — упорядоченный. @zero/@one/@const сопоставляются с константными значениями. Когда одно и то же имя переменной появляется дважды (например, Idiv(x, x)), оба операнда должны быть одним и тем же узлом (Arc::ptr_eq — т.е. структурно дедуплицированы через hash consing).

Ссылка на Tinygrad: tinygrad/uop/symbolic.py, tinygrad/uop/divandmod.py


Рабочий пример: каскад оптимизаций

Простое выражение, демонстрирующее, как паттерны компонуются:

Before:
ADD
├── MUL
│ ├── ADD
│ │ ├── x
│ │ └── CONST(0) <- identity
│ └── CONST(1) <- identity
└── ADD
├── CONST(3)
└── CONST(4) <- constant fold

Step 1 (identity): ADD(x, 0) -> x
Step 2 (identity): MUL(x, 1) -> x
Step 3 (const fold): ADD(3, 4) -> CONST(7)
Step 4 (result): ADD(x, 7)

After:
ADD
├── x
└── CONST(7)

Движок перезаписи применяет паттерны снизу вверх: сначала упрощаются потомки, затем родитель повторно сопоставляется. Это позволяет многошаговым каскадам сработать за один обход.


Порядок паттернов

Матчер symbolic_simple() компонует группы паттернов в определённом порядке. Внутри группы паттерны пробуются последовательно до первого совпадения. Группы конкатенируются оператором +:

propagate_invalid -- MUST be first (before x*0=0)
fold_invalid_load_store
constant_folding_dsl_patterns
vconst_folding_patterns
identity_and_zero_patterns
commutative_canonicalization
self_folding_dsl_patterns
zero_folding_dsl_patterns
division_dsl_patterns
cast_dsl_patterns
cast_where_dsl_patterns
term_combining_dsl_patterns
alu_folding_dsl_patterns
advanced_division_dsl_patterns
div_mod_recombine_dsl_patterns
comparison_dsl_patterns
boolean_dsl_patterns
minmax_dsl_patterns
where_bound_patterns
power_dsl_patterns
negation_dsl_patterns
range_based_mod_div_patterns
dce_dsl_patterns
dead_loop_patterns
after_simplification_patterns
pm_move_where_on_load -- WHERE->INDEX embedding for masked loads

1. Свёртка констант

Вычисление операций над константами времени компиляции с использованием арифметики, учитывающей dtype. Результаты соблюдают границы типов (например, Int32 оборачивается на 32 битах).

Tinygrad: symbolic.py:40-118

Скалярные константы

КатегорияОперацииПаттерн
Унарные (7)Neg, Sqrt, Exp2, Log2, Sin, Reciprocal, Truncop(CONST(c)) -> CONST(eval(op, c))
Бинарные (13)Add, Mul, Sub, Mod, Max, Pow, Idiv, Fdiv, And, Or, Xor, Shl, Shrop(CONST(a), CONST(b)) -> CONST(eval(op, a, b))
Тернарные (2)Where, MulAccop(CONST(a), CONST(b), CONST(c)) -> CONST(eval(op, a, b, c))

Векторные константы

ПаттернРезультат
op(VCONST(a), VCONST(b))VCONST(eval(op, a, b)) поэлементно
op(CONST(a), VCONST(b))VCONST(eval(op, broadcast(a), b))
op(VCONST(a), CONST(b))VCONST(eval(op, a, broadcast(b)))
unary_op(VCONST(v))VCONST(eval(op, v)) поэлементно

Свёртка VConst покрывает 11 бинарных операций (исключая Pow и Fdiv) и все 7 унарных.


2. Свёртка тождеств и нулей

ПаттернРезультатПримечание
ADD[x, 0]xКоммутативно
MUL[x, 1]xКоммутативно
OR[x, 0]xКоммутативно
XOR[x, 0]xКоммутативно
SUB(x, 0)xУпорядочено
IDIV(x, 1)xУпорядочено
FDIV(x, 1)xУпорядочено
MOD(x, 1)0Остаток от деления на 1 всегда равен нулю
Floor/Ceil/Trunc/Round(x)xТолько когда x целочисленное (округление — нет операции)
MUL[x, 0]0Только когда НЕ float
AND[_, 0]0Коммутативно

:::caution IEEE 754: MUL на ноль MUL[x, 0] не упрощается для чисел с плавающей точкой, поскольку IEEE 754 требует:

  • NaN * 0 = NaN
  • Inf * 0 = NaN

Защита !x.dtype().is_float() предотвращает эту оптимизацию для типов с плавающей точкой. :::


3. Свёртка по совпадению операндов

Паттерны, где один и тот же операнд появляется с обеих сторон. Используются проверки Arc::ptr_eq (hash consing гарантирует, что структурно равные подвыражения разделяют один указатель).

ПаттернРезультатПримечание
IDIV(x, x)1
IDIV(x, -1)NEG(x)Проверка константы правого операнда
MOD(MOD(x, y), y)MOD(x, y)Идемпотентный mod
AND(x, x)x
OR(x, x)x

4. Свёртка в ноль

ПаттернРезультатПримечание
MOD(x, x)0
LT(x, x)falseНЕ для float (NaN < NaN — false, но защита нужна для корректности)
NE(x, x)falseТолько целочисленные — NaN != NaN даёт true в IEEE 754

5. Упрощение деления

ПаттернРезультатПримечание
FDIV(0.0, 0.0)NaNНеопределённая форма IEEE 754
FDIV(MUL[_, 0], 0)NaNЛюбое нулевое выражение / ноль
FDIV(x, x)1.0Деление float на себя
FDIV(MUL(x, y), y)xСокращение (float)
IDIV(MUL(x, y), y)xСокращение (целочисленное)

:::caution Приоритет паттернов FDIV(0, 0) -> NaN должен стоять перед FDIV(x, x) -> 1 в матчере для получения приоритета. Порядок внутри division_dsl_patterns() это гарантирует. :::


6. Оптимизация CAST

ПаттернРезультатПримечание
CAST(CONST(c), dtype)CONST(c.cast(dtype))Свёртка приведения на этапе компиляции
CAST(x, dtype)xКогда x.dtype() == dtype (нет операции)
CAST(CAST(x, a), b)xКогда x.dtype() == b и a сохраняет все значения b
CAST(CAST(x, a), b)CAST(x, b)Когда a не сужает x (цепочка расширений)
CAST(WHERE(s, a, b), dtype)WHERE(s, CAST(a, dtype), CAST(b, dtype))Проталкивание приведения через ветви

Функция can_safe_cast(to, from) определяет, может ли промежуточный тип вместить все значения. Проверяются разрядность, знаковость и категория float/int.

:::caution Усечение ломает обратные преобразования CAST(CAST(x, i8), i64) НЕ сворачивается до x, когда x имеет тип i64. Промежуточный i8 усекает значения — can_safe_cast(i64, i8) возвращает false, потому что i8 не может вместить все значения i64.

Безопасный пример: CAST(CAST(x, i32), bool) -> CAST(x, bool), когда x имеет тип bool, поскольку i32 может представить и true, и false. :::


7. Объединение термов

ПаттернРезультат
ADD(x, x)MUL(2, x)
ADD(MUL(c1, x), MUL(c2, x))MUL(c1+c2, x)
ADD(MUL(x, c1), MUL(x, c2))MUL(x, c1+c2)

Сопоставляются оба упорядоченных варианта (константа слева или справа от MUL).


8. Свёртка цепочек ALU

Свёртка констант в цепочках ассоциативных операций и вынос констант наружу для канонической формы.

Свёртка констант

ПаттернРезультатПримечание
ADD[ADD[x, c1], c2]ADD(x, c1+c2)Коммутативный внешний Add
MUL[MUL[x, c1], c2]MUL(x, c1*c2)Коммутативный внешний Mul
ADD[SUB(x, c1), c2]ADD(x, c2-c1) или SUB(x, c1-c2)С нормализацией знака
SUB(ADD(x, c1), c2)ADD(x, c1-c2) или SUB(x, c2-c1)С нормализацией знака
SUB(SUB(x, c1), c2)SUB(x, c1+c2)

Вынос констант

ПаттернРезультатПримечание
ADD[ADD[x, c], y]ADD(ADD(x, y), c)Выносит константу наружу; y не должен быть константой

Вынос констант критически важен для извлечения индексов. Он гарантирует, что константы всплывают на самый внешний уровень, позволяя нижестоящим паттернам (вроде упрощения div-mod) видеть чистые формы переменная + смещение.

Канонизация Sub

ПаттернРезультатПримечание
SUB(a, SUB(b, x))ADD(x, SUB(a, b))Экспонирует внутреннюю переменную

Morok сохраняет SUB как полноценную операцию IR (в отличие от Tinygrad, который канонизирует a-b в ADD(a, NEG(b))). Этот паттерн гарантирует, что вложенные SUB не блокируют дальнейшее упрощение.


9. Булева логика

ПаттернРезультатПримечание
NOT(NOT(x))xУстранение двойного отрицания
XOR(x, x)0Самоуничтожение
OR[x, NOT(x)]trueТавтология (только bool)
AND[x, NOT(x)]falseПротиворечие (только bool)
OR[true, x]trueПоглощающий элемент
AND[false, x]falseПоглощающий элемент
AND[true, x]xТождество
OR[false, x]xТождество
AND[NOT(x), NOT(y)]NOT(OR(x, y))Закон де Моргана
OR[NOT(x), NOT(y)]NOT(AND(x, y))Закон де Моргана

Все паттерны с [] коммутативны (пробуются оба порядка операндов).


10. Упрощение сравнений

Самосравнение (не float, ptr_eq)

ОперацияРезультат
LT(x, x), GT(x, x), NE(x, x)false
LE(x, x), GE(x, x), EQ(x, x)true

:::caution Самосравнение для float Паттерны самосравнения защищены проверкой !x.dtype().is_float(). Для чисел с плавающей точкой NaN != NaN даёт true, а NaN == NaN даёт false, поэтому эти тождества не выполняются. :::

На основе констант и диапазонов

ПаттернРезультатПримечание
op(CONST(a), CONST(b))CONST(eval(op, a, b))Прямая свёртка констант
op(x, y) когда границы доказываютtrue или falseComparisonAnalyzer использует vmin/vmax

ComparisonAnalyzer проверяет: если x.vmax < y.vmin, то LT(x, y) доказуемо равно true.

Алгебраические преобразования

ПаттернРезультатПримечание
LT(ADD[c0, x], c1)LT(x, c1-c0)Устранение смещения
LT(NEG(x), NEG(y))LT(y, x)Инверсия знака
LT(IDIV(x, d), c)LT(x, c*d)Подъём деления (d > 0)

Подъём деления для LT(x//d, c) обрабатывает как положительные, так и неположительные c:

  • c > 0: эквивалентно x < c*d
  • c <= 0: эквивалентно x < c*d - (d-1)

11. Устранение Min/Max

ПаттернРезультатПримечание
MAX(x, x)xMax от себя — тождество
MAX(x, y)xКогда x.vmin >= y.vmax (границы доказывают доминирование)
MAX(x, y)yКогда y.vmin >= x.vmax

Используется VminVmaxProperty для анализа диапазонов. Отдельных паттернов для MIN нет — Morok снижает MIN(a,b) до NEG(MAX(NEG(a), NEG(b))) до срабатывания этих паттернов.


12. Оптимизация WHERE

Устранение условия

ПаттернРезультатПримечание
WHERE(cond, t, f)tКогда cond.vmin == cond.vmax == true
WHERE(cond, t, f)fКогда cond.vmin == cond.vmax == false
WHERE(LT(x, c), t, f)tКогда x.vmax < c.vmin (всегда true)
WHERE(LT(x, c), t, f)fКогда x.vmin >= c.vmax (всегда false)

Упрощение ветвей

ПаттернРезультатПримечание
WHERE(_, t, t)tОдинаковые ветви
WHERE(x, true, false)xТождество bool
WHERE(x, false, true)NOT(x)Отрицание bool
WHERE(NOT(cond), t, f)WHERE(cond, f, t)Инверсия условия
WHERE(a, WHERE(b, c, d), d)WHERE(AND(a, b), c, d)Слияние ветвей (ptr_eq на d)

:::caution Защита от Invalid при инверсии условия WHERE(NOT(cond), t, f) -> WHERE(cond, f, t) не применяется, когда f содержит Invalid. Паддинг создаёт структуры WHERE(valid, idx, Invalid), и перестановка переместила бы Invalid в ветвь true, где нижестоящие паттерны не смогут его сопоставить. Проверяются как скалярный Invalid, так и векторизованный VECTORIZE(Invalid, ...).

У Tinygrad та же защита: symbolic.py:201-202. :::


13. Распространение Invalid

Invalid — это сигнальное значение Morok для областей тензора за пределами допустимых границ, создаваемых операциями паддинга. Эти паттерны должны выполняться перед паттернами тождеств вроде x*0=0, иначе маркеры валидности будут уничтожены.

Пример приоритета паттернов

Without ordering: MUL(0, WHERE(cond, x, Invalid)) -> 0 (x*0=0 fires, loses Invalid)
With ordering: MUL(0, WHERE(cond, x, Invalid))
-> WHERE(cond, MUL(0, x), Invalid) (Invalid propagation fires first)
-> WHERE(cond, 0, Invalid) (then x*0=0 is safe)

Слияние WHERE-Invalid

ПаттернРезультат
WHERE(c1, WHERE(c2, x, Inv), Inv)WHERE(AND(c1, c2), x, Inv)
WHERE(c1, WHERE(c2, x, Inv), y)WHERE(AND(c1, c2), x, y)

Многомерный паддинг создаёт вложенные WHERE-Invalid после распространения через линеаризованную индексную арифметику. Слияние до одного уровня обеспечивает потребление pm_lower_index_dtype за один шаг.

Проталкивание операций через WHERE-Invalid

ПаттернРезультатОперации
CAST(WHERE(c, x, Inv))WHERE(c, CAST(x), Inv)
op(WHERE(c, x, Inv), y)WHERE(c, op(x, y), Inv)13 бинарных операций (не сравнения)
op(y, WHERE(c, x, Inv))WHERE(c, op(y, x), Inv)13 бинарных операций (не сравнения)
cmp(WHERE(c, x, Inv), y)cmp(x, y)Lt, Le, Eq, Ne, Gt, Ge
cmp(y, WHERE(c, x, Inv))cmp(y, x)Lt, Le, Eq, Ne, Gt, Ge

Для сравнений WHERE-Invalid отбрасывается — область Invalid уже ограждена ниже по потоку.

Распространение голого Invalid

ПаттернРезультатЗащита
op(Invalid, y)Invalidy.dtype() == DType::Index, только левая позиция

Согласование с Tinygrad: symbolic.py:37. Правосторонний голый Invalid НЕ распространяется, чтобы не загрязнять неиндексные вычисления.

Мёртвые загрузки/записи из Invalid-индексов

ПаттернРезультат
LOAD(INDEX(buf, Invalid))CONST(0)
LOAD(CAST(INDEX(buf, Invalid)))CONST(0)
STORE(INDEX(buf, Invalid), val)NOOP
STORE(CAST(INDEX(buf, Invalid)), val)NOOP

14. Удаление мёртвого кода

Мёртвые диапазоны

ПаттернРезультатПримечание
RANGE(end) где vmax < 0CONST(0)Пустой диапазон (никогда не выполняется)
RANGE(CONST) где vmin == vmaxCONST(vmin)Тривиальный диапазон (одно значение)
END(computation, ranges)END(computation, live_ranges)Фильтрация мёртвых диапазонов из END
END(computation, [])computationВсе диапазоны мертвы — разворачиваем

Мёртвые редукции

Операция редукцииНейтральный элемент
Add0
Mul1
Max-inf (минимум dtype)
Min+inf (максимум dtype)

Когда ВСЕ диапазоны REDUCE мертвы (пусты), REDUCE заменяется своим нейтральным элементом.

Упрощение зависимостей

ПаттернРезультат
AFTER(x, [])x

Отсутствие зависимостей означает отсутствие ограничений порядка.


15. Степень и отрицание

ПаттернРезультат
POW(x, 0)1
POW(x, 1)x
NEG(NEG(x))x

16. Проталкивание GEP

GEP (Get Element Pointer) извлекает элементы из векторов. Эти паттерны проталкивают GEP через другие операции, чтобы достичь источника вектора, позволяя скалярное упрощение после девекторизации.

Включены только в symbolic() (Стадия 4), но не в symbolic_simple() (Стадии 8, 16).

Композиция и извлечение

ПаттернРезультатПримечание
GEP(GEP(x, inner), outer)GEP(x, inner[outer])Композиция вложенных
GEP(VECTORIZE(x,x,x,x), [i])xЧерез broadcast (все ptr_eq)
GEP(VECTORIZE(elems), [i])elems[i]Через VECTORIZE
GEP(scalar, [i])scalarТождество для скаляра (vcount == 1)
GEP(VCONST(vals), [i])CONST(vals[i])Через VConst
GEP(x, [0,1,...,n-1])xУдаление тождества

Проталкивание через операции

ПаттернРезультатЗащита
GEP(op(a, b), idx)op(GEP(a, idx), GEP(b, idx))Бинарные, только Index dtype
GEP(unary(x), idx)unary(GEP(x, idx))Унарные, только Index dtype
GEP(WHERE(c, t, f), idx)WHERE(GEP(c, idx), GEP(t, idx), GEP(f, idx))Только Index dtype
GEP(MULACC(a, b, c), idx)MULACC(GEP(a, idx), GEP(b, idx), GEP(c, idx))Только Index dtype

:::caution Защита по Index dtype предотвращает взрыв графа Проталкивание GEP через ALU-операции ограничено Index dtype (Tinygrad: symbolic.py:167). Без этой защиты комбинация проталкивания GEP с no_vectorized_alu вызывает экспоненциальный рост графа на многомерных ядрах. :::

Проталкивание через структурные операции

ПаттернРезультат
GEP(CAT([a<4>, b<4>]), [5])GEP(b, [1])
GEP(PTRCAT([a, b, c]), [1, 2])PTRCAT([b, c])
GEP(CAST(x, dtype), idx)CAST(GEP(x, idx), scalar_dtype)
GEP(BITCAST(x, dtype), idx)BITCAST(GEP(x, idx), scalar_dtype)
GEP(WMMA(a, b, c), idx)WMMA(GEP(a, ...), GEP(b, ...), GEP(c, ...))
GEP(UNROLL(x, ...), idx)GEP(x, idx)
GEP(void_node, _)void_node

Паттерн WMMA отображает тайловые индексы через оси upcast для извлечения соответствующих подгрупп входных данных.

Обратная сборка

ПаттернРезультат
VECTORIZE(GEP(x,[0]), GEP(x,[1]), ..., GEP(x,[N-1]))GEP(x, [0,1,...,N-1])

Это сворачивает структуры VECTORIZE, созданные no_vectorized_alu, обратно в один GEP, который затем удаляется паттерном тождества.


17. WHERE на LOAD (только Стадия 8)

Функция: pm_move_where_on_load()

Преобразует маскированные загрузки, встраивая условие в операцию INDEX:

Before: WHERE(cond, INDEX(buf, idx), 0)
After: INDEX(buf, WHERE(combined_cond, idx, Invalid))

Это позволяет аппаратную предикацию для маскированных загрузок и устраняет накладные расходы WHERE.

Как это работает

  1. Разделить условие на AND-клаузы
  2. Разбить клаузы на переносимые и остающиеся:
    • Переносимые: все зависимости RANGE в области INDEX, нет внешних зависимостей INDEX
    • Остающиеся: всё остальное
  3. Встроить переносимые клаузы как WHERE(cond, idx, Invalid) в indices[0]
  4. Обернуть во внешний WHERE, если остающиеся клаузы существуют

Поддерживается частичный перенос клауз — переносятся только клаузы, чьи диапазоны находятся в области индекса. Существующие клаузы валидности в indices[0] дедуплицируются.

Инвертированный паттерн WHERE(cond, 0, INDEX(buf, idx)) также обрабатывается через отрицание условия.


18. Коммутативная канонизация

Для коммутативных бинарных операций над Index dtype операнды сортируются по идентификатору UOp (меньший id слева):

ОперацииЗащита
Add, Mul, Max, Eq, Ne, And, Or, Xordtype == DType::Index && b.id < a.id

Без этого математически эквивалентные выражения вроде R1*8000 + R2*16 и R2*16 + R1*8000 не будут дедуплицированы через hash consing, нарушая группировку в expand_vector_index.

Применяется только к Index dtype, чтобы не нарушить слияние векторной математики. Tinygrad: symbolic.py:178-182.


19. Упрощение Div-Mod

Быстрые пути на основе диапазонов

ПаттернРезультатУсловие
MOD(x, n)x0 <= vmin(x) и vmax(x) < n
IDIV(x, n)kВсе значения в диапазоне делятся до одного k
MOD(ADD[MUL[a, m], b], n)MOD(b, n)m == n (вынос кратных)
IDIV(ADD[MUL[a, m], b], n)a + IDIV(b, n)m == n
IDIV(ADD[MUL[a, m], b], n)am == n и 0 <= b < n

Унифицированный движок Div-Mod (fold_divmod_general)

Для IDIV и MOD над Index dtype унифицированный движок пробует правила упрощения в порядке приоритета. Основан на fold_divmod_general из Tinygrad (divandmod.py:8-93).

ПриоритетПравилоОписание
1cancel_divmodДиапазон лежит в одном интервале знаменателя
2remove_nested_mod(a%4 + b)%2 -> (a+b)%2 когда `2
3fold_binary_numeratorЕдинственный терм с диапазоном из 2 значений
4fold_divmod_congruenceМодулярная арифметика конгруэнтности множителей
5gcd_with_remainderВынос общего GCD из числителя
6divide_by_gcdGCD-факторизация с переменным знаменателем
7factor_remainder(d*x+y)//d -> x + y//d (крайний вариант)

Рекомбинация Div-Mod

Паттерны, рекомбинирующие разделённые операции div и mod обратно в исходное выражение:

ПаттернРезультатЗащита
ADD[MOD(x, n), MUL[IDIV(x, n), n]]xptr_eq на x, n
ADD[MOD(IDIV(x, a), c), MUL[IDIV(x, b), c]]IDIV(x, a)a * c == b
ADD[MUL[MOD(x, c1), c2], MUL[IDIV(x, c1), c3]]MUL(x, c2)c1 * c2 == c3
ADD[ADD[y, MOD(x, n)], MUL[IDIV(x, n), n]]ADD(y, x)ptr_eq на x, n
IDIV(ADD[IDIV(a, c1), c2], c3)IDIV(ADD(a, c1*c2), c1*c3)Вложенное деление

Продвинутое деление

ПаттернРезультатПримечание
IDIV(IDIV(a, b), c)IDIV(a, b*c)Композиция вложенного деления
IDIV(expr, d)expr.divides(d)Обобщённое точное деление
IDIV(ADD(a, b), c)IDIV(a, c) + IDIV(b, c)Когда оба делятся нацело
IDIV(SUB(a, b), c)IDIV(a, c) - IDIV(b, c)Когда оба делятся нацело
MUL(c, ADD(a, b))ADD(MUL(c, a), MUL(c, b))Распределение умножения

Перекрёстные ссылки

  • Пайплайн исполнения — стадии, на которых выполняются эти паттерны
  • Движок паттернов — как работает движок сопоставления паттернов
  • Rangeify — контекст Стадии 4 (паттерны выполняются после снижения movement-операций)
  • Expander — контекст Стадии 8 (паттерны выполняются после оптимизационных действий)
  • Linearizer — контекст Стадии 16 (финальная очистка)