Индексная арифметика
Тензорные компиляторы тратят основную часть оптимизационного бюджета на индексную арифметику. Обращение tensor[i, j] с формой [H, W] превращается в i * W + j. После тайлинга, векторизации и преобразований циклов эти выражения накапливают вложенные деления и модули. Их упрощение критически важно — одна лишняя операция idiv стоит 20–40 тактов против 1 такта за эквивалентный сдвиг (приблизительно, на современном x86-64).
Эта страница документирует паттерны, упрощающие индексные выражения. Это НЕ оптимизации в традиционном смысле — это алгебра, которая обеспечивает эффективную индексацию тензоров.
Ключевая концепция — анализ диапазонов значений: Каждый UOp отслеживает минимальное (vmin) и максимальное (vmax) значения, которые он может принять во время выполнения. Эти границы вычисляются жадно при создании узла на основе границ его входов. Многие индексные паттерны используют эти границы для доказательства упрощений на этапе компиляции (например, «x всегда в [0, N)» позволяет x % N → x).
Эти паттерны выполняются на нескольких стадиях пайплайна кодогенерации:
- Стадия 4 (начальное символьное, в ходе rangeify)
- Стадия 8 (пост-оптимизационное символьное)
- Стадия 15 (снижение типа индексов через
pm_lower_index_dtype) - Стадия 16 (пост-индексное символьное)
Исходники Morok: schedule/src/symbolic/patterns.rs, schedule/src/symbolic/index_lowering.rs
Исходники Tinygrad: tinygrad/uop/divandmod.py, tinygrad/uop/symbolic.py
1. Тождество Div-Mod
Фундаментальная теорема целочисленного деления:
$$ x = \lfloor x / n \rfloor \cdot n + (x \bmod n) $$
Пять вариантов эксплуатируют это тождество в наборе паттернов:
| # | Паттерн | Условие | Название |
|---|---|---|---|
| 1 | x%n + (x//n)*n -> x | -- | Базовое тождество |
| 2 | ((x//a) % c) + (x//b)*c -> x//a | a*c == b | Составной делитель |
| 3 | (x%c1)*c2 + (x//c1)*c3 -> x*c2 | c1*c2 == c3 | Масштабированное |
| 4 | y + (x%n) + (x//n)*n -> y + x | -- | Трёхтермовое |
| 5 | (a//c1 + c2) // c3 -> (a + c1*c2) // (c1*c3) | c1>0, c3>0 | Вложенное деление |
Доказательство #1. По алгоритму деления для целых x и n > 0 существуют единственные целые q и r такие, что x = q*n + r, где 0 <= r < n. По определению q = x // n и r = x % n. Подставляя: (x % n) + (x // n) * n = r + q*n = x. QED.
Почему #2–#5 — следствия.
Вариант #2 компонует два уровня деления. Поскольку b = a*c, имеем x // b = (x // a) // c. Применяя базовое тождество на внутреннем уровне: ((x//a) % c) + ((x//a) // c) * c = x // a. Но (x//a) // c = x // (a*c) = x // b, что даёт паттерн.
Вариант #3 масштабирует обе стороны базового тождества на c2. Из x = (x % c1) + (x // c1) * c1, умножая на c2: x * c2 = (x % c1) * c2 + (x // c1) * c1 * c2. Поскольку c1 * c2 = c3, получаем (x % c1) * c2 + (x // c1) * c3 = x * c2.
Вариант #4 прибавляет независимый терм y к обеим сторонам #1.
Вариант #5 выравнивает вложенное деление. Для (a // c1 + c2) // c3 умножаем c2 на внутренний делитель для получения эквивалентного одноуровневого деления: (a + c1*c2) // (c1*c3). Это выполняется когда a >= 0 и c2 >= 0 (или оба неположительные), что гарантирует сохранение семантики деления с округлением вниз.
Все пять паттернов используют проверки Arc::ptr_eq на дублирующиеся имена переменных (например, x дважды означает, что оба должны быть одним и тем же узлом hash consing).
Реализация
// From schedule/src/symbolic/patterns.rs — div_mod_recombine_dsl_patterns()
// #1: x%n + (x//n)*n -> x
Add[Mod(x, n), Mul[Idiv(x, n), n]] ~> |x| Arc::clone(x),
// #2: ((x//a) % c) + (x // b) * c -> x // a when a*c == b
Add[Mod(Idiv(x, a), c), Mul[Idiv(x, _b), c]]
=> |x, a, a_val, c_val, b_val| { /* guard: a_int * c_int == b_int */ },
// #5: (a//c1 + c2) // c3 -> (a + c1*c2) // (c1*c3)
Idiv(Add[Idiv(a, c1), _c2], _c3)
=> |a, c1, c1_val, c2_val, c3_val| { /* guard: c1>0, c3>0, same-sign */ },
2. Mod/Div на основе диапазонов
Анализ диапазонов значений (vmin/vmax) позволяет упрощения, невидимые для чисто синтаксического сопоставления паттернов. Каждый UOp хранит кэшированные границы, вычисленные при создании.
| Паттерн | Защита | Пример |
|---|---|---|
x % n -> x | 0 <= vmin(x) и vmax(x) < n | RANGE(3) % 3 -> RANGE(3) |
(a*m + b) % n -> b % n | m == n | (row*512 + col) % 512 -> col % 512 |
(a*m + b) / n -> a + b/n | m == n и 0 <= b < n | (row*512 + col) / 512 -> row |
x / n -> k | все значения попадают в корзину [k*n, (k+1)*n) | RANGE(3) / 3 -> 0 |
(x + c) // d -> x // d | max_remainder + c < d | (R*4 + 1) // 8 -> R*4 // 8 |
(x + c) // d -> (x + c%d) // d + c//d | c >= d | (x + 70) // 8 -> (x + 6) // 8 + 8 |
Первый паттерн — основная рабочая лошадка. После разбиения диапазонов RANGE(n) порождает значения в [0, n), поэтому RANGE(n) % n тривиально упрощается до RANGE(n). Одно это правило устраняет большинство модулей, создаваемых тайлингом.
Пятый паттерн (малая константа) использует жёсткую границу на максимальный остаток в диапазоне [vmin, vmax]. Если диапазон содержит менее d значений и прибавление c никогда не пересекает границу корзины, константа — мёртвый груз.
Шестой паттерн (разделение большого смещения) канонизирует смещения, превышающие делитель. Это экспонирует паттерн малой константы для следующей итерации перезаписи.
Паттерн (a*m + b) / n -> a + b/n требует 0 <= b < n. Без проверки диапазона отрицательные остатки дают некорректные частные из-за семантики округления к нулю. Реализация явно проверяет vmin(b) >= 0 && vmax(b) < n.
3. Алгоритм fold_divmod_general
Универсальный обработчик для Idiv и Mod над Index dtype. Реализует все 8 правил из divandmod.py:8-93 Tinygrad в порядке приоритета, включая рекурсивный nest_div_by_smallest_factor. Каждое правило пробуется последовательно; первое совпадение побеждает.
Точка входа: когда Idiv(x, y) или Mod(x, y) имеет dtype == Index, паттерн делегирует fold_divmod_general(op, x, y).
Правило 1 — cancel_divmod
Если весь диапазон [x_min, x_max] отображается в одно частное для всех угловых комбинаций (x, y), результат — эта константа.
Защита: y_min * y_max > 0 (знаменатель не пересекает ноль), и все четыре угловых частных x_min/y_min, x_min/y_max, x_max/y_min, x_max/y_max равны.
Что делает: Для Idiv возвращает константное частное. Для Mod возвращает x - q*y.
Пример: RANGE(3) // 3 -> 0. Значения 0, 1, 2 все делятся до 0.
Правило 2 — remove_nested_mod
(a%4 + b) % 2 -> (a + b) % 2 когда 2 | 4. Внешний модуль делит внутренний, поэтому внутренний модуль избыточен.
Защита: op == Mod, x_min >= 0, и для каждого терма, являющегося Mod(inner_x, inner_y), знаменатель y делит inner_y.
Что делает: Снимает внутренние операции Mod, чей модуль кратен внешнему, затем повторно применяет Mod.
Пример: (RANGE(8) % 4 + RANGE(2)) % 2 -> (RANGE(8) + RANGE(2)) % 2
Правило 3 — fold_binary_numerator
Когда единственный неконстантный терм имеет ровно 2 значения (vmax - vmin == 1), результат — линейная интерполяция: (y2 - y1) * (v - v_min) + y1.
Защита: Ровно один неконстантный терм после декомпозиции, и его диапазон охватывает ровно 2 значения.
Что делает: Вычисляет div/mod в обеих конечных точках и строит линейное отображение между ними. Это полностью избегает деления.
Пример: Для (v * 3 + 2) % 5 где v находится в {0, 1}:
v=0:(0 + 2) % 5 = 2v=1:(3 + 2) % 5 = 0- Результат:
(0 - 2) * (v - 0) + 2 = -2*v + 2
Правило 4 — fold_divmod_congruence
Для каждого терма f_i * v_i вычисляется ближайший остаток r_i = min(f_i % c, f_i % c - c) по абсолютному значению. Если сумма остатков остаётся в пределах одной корзины деления c, mod/div упрощается. Это оптимизация модулярной арифметики.
Защита: x_min >= 0, константный знаменатель c > 0, и rem_min // c == rem_max // c (все значения суммы остатков попадают в одну корзину).
Что делает: Заменяет каждый множитель его остатком по модулю c. Для Mod возвращает сумму остатков (с поправкой на смещение корзины). Для Idiv возвращает сумму коэффициентов частных.
Пример: (r*8 + v) % 7 -> (r + v) % 7, поскольку 8 = 1 (mod 7), и остаток 8 равен 1.
Правило 5 — gcd_with_remainder
Вычисляет символьный GCD всех аддитивных термов и знаменателя. Если GCD > 1, выносит его: (g*a + g*b) // (g*c) -> (a + b) // c (с масштабированием остатка для Mod).
Защита: x_min >= 0, константный знаменатель, GCD > 1, и приведённый числитель имеет vmin >= 0.
Что делает: Делит и термы числителя, и знаменатель на их GCD, рекурсивно позволяя более простым паттернам сработать.
Пример: (6*a + 4*b) // 8 с GCD(6, 4, 8) = 2 -> (3*a + 2*b) // 4
Правило 6 — divide_by_gcd
Версия Правила 5 для переменного знаменателя. Вычисляет GCD(все_термы..., y), включая и числитель, и знаменатель, затем делит обе стороны. В отличие от Правила 5, работает когда знаменатель не является константой.
Защита: GCD нетривиален (не 1), и оба x и y точно делятся на GCD.
Пример: (4*a) // (2*b) -> (2*a) // b
Правило 7 — factor_remainder
Крайний вариант. Разбивает термы на точно делимые (частное) и остаток.
Защита: x_min >= 0 и y_min >= 0, и хотя бы один терм делит y нацело.
Что делает: Для Idiv: quo_sum + rem // y. Для Mod: rem % y (с приведением коэффициентов для константного y).
Пример: (8*a + 3*b) // 8 -> a + (3*b) // 8
Правило 8 — nest_div_by_smallest_factor
Рекурсивная декомпозиция для константных делителей. Находит наименьший общий множитель между делителем и коэффициентом любого терма, делит оба на него, затем рекурсивно вызывается.
Защита: x_min >= 0, константный y > 1, и хотя бы один неконстантный терм имеет множитель f > 1, где y % f == 0.
Что делает: Выбирает div = min(|f|) среди подходящих множителей, перезаписывает x // y как (x // div) // (y / div). Каждый шаг уменьшает y, сходясь к правилам 1–7.
Пример: (6*a + 4*b) // 12 → ((6*a + 4*b) // 2) // 6 → (3*a + 2*b) // 6 → (3*a + 2*b) // 6 (затем правило 7 завершает).
Tinygrad: divandmod.py:62-67. Morok: nest_div_by_smallest_factor в fold_divmod_general.
Правила 5–8 требуют неотрицательных числителей (x_min >= 0). Деление с округлением вниз с отрицательными операндами имеет иную семантику округления (к отрицательной бесконечности в Python/Tinygrad, к нулю в аппаратуре). Реализация возвращает None для отрицательных диапазонов, позволяя последующим проходам обработать выражение.
4. Продвинутые паттерны деления
Самостоятельные паттерны вне fold_divmod_general, обрабатывающие дополнительные случаи:
| Паттерн | Защита | Источник |
|---|---|---|
(a // b) // c -> a // (b*c) | b != 0, c != 0 | advanced_division_dsl_patterns |
expr // divisor -> точное частное | expr точно делится | advanced_division_dsl_patterns |
(a + b) % c приведение коэффициентов | a или b имеет множитель, делимый на c | advanced_division_dsl_patterns |
(a + b) // c -> a//c + b//c | оба делятся нацело | advanced_division_dsl_patterns |
(a - b) // c -> a//c - b//c | оба делятся нацело | advanced_division_dsl_patterns |
c * (a + b) -> c*a + c*b | c — константа | advanced_division_dsl_patterns |
Сворачивание вложенного деления (a // b) // c -> a // (b*c) особенно важно после тайлинга, где разбиение диапазона на внешнюю и внутреннюю компоненты создаёт два уровня деления, которые должны сворачиваться в один.
Паттерн точного деления использует divides(), который проверяет, делится ли константный множитель каждого аддитивного терма на делитель. При успехе Idiv полностью устраняется — инструкция деления не генерируется.
Паттерн приведения коэффициентов преобразует (r*8 + v) % 7 -> (r*1 + v) % 7 = (r + v) % 7, приводя каждый множитель по модулю делителя. Срабатывает, когда ни один множитель не является точным кратным модуля, но остатки меньше.
5. Снижение типа Index (3-фазный каскад)
Tinygrad: ops.py:1291-1313. Morok: schedule/src/symbolic/index_lowering.rs.
Абстрактный тип Index не имеет фиксированной разрядности — он представляет «целое число той разрядности, которая необходима для данного индекса». Проход снижения преобразует Index в конкретный i32 или i64 на основе границ значений.
Фаза 1 — Создание обёрток (листовые узлы)
Листовые узлы с dtype Index заменяются конкретным эквивалентом, обёрнутым в приведение обратно к Index:
| Вход | Выход |
|---|---|
CONST(Index) | CONST(concrete).cast(Index) |
DEFINE_VAR(Index) | DEFINE_VAR(concrete).cast(Index) |
VCONST(Vector<Index, N>) | VCONST(Vector<concrete, N>).cast(Vector<Index, N>) |
Фаза 2 — Обработка обёрнутых значений вверх
Бинарные операции, управление потоком и структурные узлы распространяют конкретный тип через обёртки .cast(Index):
| Вход | Выход |
|---|---|
Binary(x.cast(Index), y.cast(Index)) | Binary(x.cast(dt), y.cast(dt)).cast(result_dtype) |
WHERE(cond, x.cast(Index), y.cast(Index)) | WHERE(cond, x.cast(dt), y.cast(dt)).cast(Index) |
RANGE(end.cast(Index)) | RANGE(end, end.dtype).cast(Index) |
SPECIAL(end.cast(Index)) | SPECIAL(end, i32).cast(Index) |
VECTORIZE(e0.cast(Index), ...) | VECTORIZE(e0.cast(dt), ...).cast(Vector<Index, N>) |
BIND(var.cast(Index), val.cast(Index)) | var.cast(dt).bind(val.cast(dt)).cast(Index) |
dt вычисляется как least_upper_dtype(select_dtype(result), x.dtype, y.dtype) — наибольший тип, необходимый для любого операнда или результата.
Фаза 3 — Снятие обёрток на терминалах
Терминальные узлы потребляют индекс и отбрасывают обёртку Index:
| Вход | Выход |
|---|---|
INDEX(buf, idx.cast(Index)) | INDEX(buf, idx) |
INDEX(buf, WHERE(cond, idx, Invalid)) | INDEX(buf, idx, gate=cond) |
SINK(sources с .cast(Index)) | SINK(развёрнутые sources) |
END(computation.cast(Index)) | END(развёрнутое computation) |
Преобразование WHERE(cond, idx, Invalid) -> gate=cond значительно: оно извлекает условия валидности из индексного выражения в поле gate узла INDEX, которое backend-ы кодогенерации используют для генерации предикатных загрузок.
select_dtype()
Возвращает i32, если границы значений UOp помещаются в [-2^31, 2^31 - 1], иначе i64. Большинство тензорных индексов помещаются в i32 — даже плоский индекс тензора с 2 миллиардами элементов помещается. Путь через i64 существует для очень больших тензоров или накопленных смещений.
6. Коммутативная канонизация
// For Index dtype ONLY:
op(a, b) -> op(b, a) when b.id < a.id
Это обеспечивает детерминированный порядок операндов для коммутативных операций на основе уникального ID узла UOp. Применяется к: Add, Mul, Max, Eq, Ne, And, Or, Xor.
Почему только Index: Без канонизации R1*8000 + R2*16 и R2*16 + R1*8000 — разные узлы после hash consing, что нарушает группировку в expand_vector_index. Экспандеру нужно идентифицировать идентичные индексные паттерны между полосами вектора, и неканоническое упорядочение это ломает.
Почему НЕ применяется к не-Index типам: Применение канонизации к float/int арифметике переупорядочит элементы VECTORIZE и нарушит слияние векторной математики в последующих проходах. Tinygrad делает тот же выбор (symbolic.py:178-182).
Канонизация взаимодействует с итерацией фиксированной точки движка перезаписи. Если два паттерна расходятся в порядке операндов (один канонизирует, другой порождает неканонический порядок), движок может осциллировать. Все паттерны, порождающие индексы, должны соблюдать канонический порядок, иначе сработает предел безопасности в 1000 итераций.
Рабочий пример
Рассмотрим tensor[i, j] с формой [4, 8], доступ через плоскую итерацию по 32 элементам.
Начальное состояние
Диапазон R0 итерирует 0..32 (плоский индекс). Паттерн доступа декомпозируется в:
row = R0 // 8 (which of the 4 rows)
col = R0 % 8 (which of the 8 columns)
addr = row * 8 + col = (R0 // 8) * 8 + (R0 % 8)
По тождеству div-mod (#1), (R0 // 8) * 8 + (R0 % 8) = R0. Адрес — просто плоский индекс, деление не нужно.
После тайлинга (UPCAST на 4)
Разбиение диапазона декомпозирует R0 в R1 * 4 + R2, где R1 в [0, 8) и R2 в [0, 4):
row = (R1*4 + R2) // 8
col = (R1*4 + R2) % 8
Упрощение row: Выражение (R1*4 + R2) // 8 поступает в fold_divmod_general.
Правило 4 (конгруэнтность) срабатывает: множитель 4 имеет остаток 4 % 8 = 4, а R2 имеет остаток 1 % 8 = 1. Сумма остатков 4*R1 + R2 с диапазоном [0, 31]. Поскольку 0 // 8 != 31 // 8, Правило 4 не сворачивает до константы. Вместо этого срабатывает Правило 7 (factor remainder): 4 не делит 8 нацело, но выражение может быть декомпозировано. Поскольку ни один терм не делит 8 нацело, мы проходим к паттерну на основе диапазонов (a*m + b) / n с m = 4, n = 8 — он не совпадает (m != n).
Выражение остаётся как (R1*4 + R2) // 8. В сгенерированном коде, если R2 векторизован (UPCAST), backend генерирует это как одно деление 4-широкого вектора.
Однако, если мы далее разобьём R1 на R3 * 2 + R4 (где R3 в [0, 4), R4 в [0, 2)):
row = (R3*2*4 + R4*4 + R2) // 8
= (R3*8 + R4*4 + R2) // 8
Теперь паттерн на основе диапазонов (a*m + b) / n срабатывает с m = n = 8:
a = R3,b = R4*4 + R2vmin(b) = 0,vmax(b) = 1*4 + 3 = 7 < 8- Результат:
R3 + (R4*4 + R2) // 8
И (R4*4 + R2) // 8: vmax = 1*4 + 3 = 7, vmin = 0, поэтому 0 // 8 = 7 // 8 = 0. Срабатывает правило cancel_divmod:
- Результат:
R3 + 0 = R3
Упрощение col: (R3*8 + R4*4 + R2) % 8
Паттерн на основе диапазонов (a*m + b) % n срабатывает с m = n = 8:
(R3*8 + R4*4 + R2) % 8->(R4*4 + R2) % 8
Затем vmin(R4*4 + R2) = 0, vmax(R4*4 + R2) = 7 < 8, поэтому x % n -> x:
- Результат:
R4*4 + R2
Итоговое дерево
Before (after tiling, before simplification):
STORE(
INDEX(buf, (R3*8 + R4*4 + R2) // 8 * 8 + (R3*8 + R4*4 + R2) % 8),
value)
After index arithmetic:
STORE(
INDEX(buf, R3*8 + R4*4 + R2),
value)
Всё вычисление адреса сворачивается обратно в линейное выражение — ноль делений, ноль модулей. Паттерны доказали, что тайлированный индекс эквивалентен плоскому индексу, исключительно через алгебраическую перезапись.