Перейти к основному содержимому

Индексная арифметика

Тензорные компиляторы тратят основную часть оптимизационного бюджета на индексную арифметику. Обращение tensor[i, j] с формой [H, W] превращается в i * W + j. После тайлинга, векторизации и преобразований циклов эти выражения накапливают вложенные деления и модули. Их упрощение критически важно — одна лишняя операция idiv стоит 20–40 тактов против 1 такта за эквивалентный сдвиг (приблизительно, на современном x86-64).

Эта страница документирует паттерны, упрощающие индексные выражения. Это НЕ оптимизации в традиционном смысле — это алгебра, которая обеспечивает эффективную индексацию тензоров.

Ключевая концепция — анализ диапазонов значений: Каждый UOp отслеживает минимальное (vmin) и максимальное (vmax) значения, которые он может принять во время выполнения. Эти границы вычисляются жадно при создании узла на основе границ его входов. Многие индексные паттерны используют эти границы для доказательства упрощений на этапе компиляции (например, «x всегда в [0, N)» позволяет x % Nx).

Эти паттерны выполняются на нескольких стадиях пайплайна кодогенерации:

  • Стадия 4 (начальное символьное, в ходе rangeify)
  • Стадия 8 (пост-оптимизационное символьное)
  • Стадия 15 (снижение типа индексов через pm_lower_index_dtype)
  • Стадия 16 (пост-индексное символьное)

Исходники Morok: schedule/src/symbolic/patterns.rs, schedule/src/symbolic/index_lowering.rs

Исходники Tinygrad: tinygrad/uop/divandmod.py, tinygrad/uop/symbolic.py


1. Тождество Div-Mod

Фундаментальная теорема целочисленного деления:

$$ x = \lfloor x / n \rfloor \cdot n + (x \bmod n) $$

Пять вариантов эксплуатируют это тождество в наборе паттернов:

#ПаттернУсловиеНазвание
1x%n + (x//n)*n -> x--Базовое тождество
2((x//a) % c) + (x//b)*c -> x//aa*c == bСоставной делитель
3(x%c1)*c2 + (x//c1)*c3 -> x*c2c1*c2 == c3Масштабированное
4y + (x%n) + (x//n)*n -> y + x--Трёхтермовое
5(a//c1 + c2) // c3 -> (a + c1*c2) // (c1*c3)c1>0, c3>0Вложенное деление

Доказательство #1. По алгоритму деления для целых x и n > 0 существуют единственные целые q и r такие, что x = q*n + r, где 0 <= r < n. По определению q = x // n и r = x % n. Подставляя: (x % n) + (x // n) * n = r + q*n = x. QED.

Почему #2–#5 — следствия.

Вариант #2 компонует два уровня деления. Поскольку b = a*c, имеем x // b = (x // a) // c. Применяя базовое тождество на внутреннем уровне: ((x//a) % c) + ((x//a) // c) * c = x // a. Но (x//a) // c = x // (a*c) = x // b, что даёт паттерн.

Вариант #3 масштабирует обе стороны базового тождества на c2. Из x = (x % c1) + (x // c1) * c1, умножая на c2: x * c2 = (x % c1) * c2 + (x // c1) * c1 * c2. Поскольку c1 * c2 = c3, получаем (x % c1) * c2 + (x // c1) * c3 = x * c2.

Вариант #4 прибавляет независимый терм y к обеим сторонам #1.

Вариант #5 выравнивает вложенное деление. Для (a // c1 + c2) // c3 умножаем c2 на внутренний делитель для получения эквивалентного одноуровневого деления: (a + c1*c2) // (c1*c3). Это выполняется когда a >= 0 и c2 >= 0 (или оба неположительные), что гарантирует сохранение семантики деления с округлением вниз.

Все пять паттернов используют проверки Arc::ptr_eq на дублирующиеся имена переменных (например, x дважды означает, что оба должны быть одним и тем же узлом hash consing).

Реализация

// From schedule/src/symbolic/patterns.rs — div_mod_recombine_dsl_patterns()

// #1: x%n + (x//n)*n -> x
Add[Mod(x, n), Mul[Idiv(x, n), n]] ~> |x| Arc::clone(x),

// #2: ((x//a) % c) + (x // b) * c -> x // a when a*c == b
Add[Mod(Idiv(x, a), c), Mul[Idiv(x, _b), c]]
=> |x, a, a_val, c_val, b_val| { /* guard: a_int * c_int == b_int */ },

// #5: (a//c1 + c2) // c3 -> (a + c1*c2) // (c1*c3)
Idiv(Add[Idiv(a, c1), _c2], _c3)
=> |a, c1, c1_val, c2_val, c3_val| { /* guard: c1>0, c3>0, same-sign */ },

2. Mod/Div на основе диапазонов

Анализ диапазонов значений (vmin/vmax) позволяет упрощения, невидимые для чисто синтаксического сопоставления паттернов. Каждый UOp хранит кэшированные границы, вычисленные при создании.

ПаттернЗащитаПример
x % n -> x0 <= vmin(x) и vmax(x) < nRANGE(3) % 3 -> RANGE(3)
(a*m + b) % n -> b % nm == n(row*512 + col) % 512 -> col % 512
(a*m + b) / n -> a + b/nm == n и 0 <= b < n(row*512 + col) / 512 -> row
x / n -> kвсе значения попадают в корзину [k*n, (k+1)*n)RANGE(3) / 3 -> 0
(x + c) // d -> x // dmax_remainder + c < d(R*4 + 1) // 8 -> R*4 // 8
(x + c) // d -> (x + c%d) // d + c//dc >= d(x + 70) // 8 -> (x + 6) // 8 + 8

Первый паттерн — основная рабочая лошадка. После разбиения диапазонов RANGE(n) порождает значения в [0, n), поэтому RANGE(n) % n тривиально упрощается до RANGE(n). Одно это правило устраняет большинство модулей, создаваемых тайлингом.

Пятый паттерн (малая константа) использует жёсткую границу на максимальный остаток в диапазоне [vmin, vmax]. Если диапазон содержит менее d значений и прибавление c никогда не пересекает границу корзины, константа — мёртвый груз.

Шестой паттерн (разделение большого смещения) канонизирует смещения, превышающие делитель. Это экспонирует паттерн малой константы для следующей итерации перезаписи.

предупреждение

Паттерн (a*m + b) / n -> a + b/n требует 0 <= b < n. Без проверки диапазона отрицательные остатки дают некорректные частные из-за семантики округления к нулю. Реализация явно проверяет vmin(b) >= 0 && vmax(b) < n.


3. Алгоритм fold_divmod_general

Универсальный обработчик для Idiv и Mod над Index dtype. Реализует все 8 правил из divandmod.py:8-93 Tinygrad в порядке приоритета, включая рекурсивный nest_div_by_smallest_factor. Каждое правило пробуется последовательно; первое совпадение побеждает.

Точка входа: когда Idiv(x, y) или Mod(x, y) имеет dtype == Index, паттерн делегирует fold_divmod_general(op, x, y).

Правило 1 — cancel_divmod

Если весь диапазон [x_min, x_max] отображается в одно частное для всех угловых комбинаций (x, y), результат — эта константа.

Защита: y_min * y_max > 0 (знаменатель не пересекает ноль), и все четыре угловых частных x_min/y_min, x_min/y_max, x_max/y_min, x_max/y_max равны.

Что делает: Для Idiv возвращает константное частное. Для Mod возвращает x - q*y.

Пример: RANGE(3) // 3 -> 0. Значения 0, 1, 2 все делятся до 0.

Правило 2 — remove_nested_mod

(a%4 + b) % 2 -> (a + b) % 2 когда 2 | 4. Внешний модуль делит внутренний, поэтому внутренний модуль избыточен.

Защита: op == Mod, x_min >= 0, и для каждого терма, являющегося Mod(inner_x, inner_y), знаменатель y делит inner_y.

Что делает: Снимает внутренние операции Mod, чей модуль кратен внешнему, затем повторно применяет Mod.

Пример: (RANGE(8) % 4 + RANGE(2)) % 2 -> (RANGE(8) + RANGE(2)) % 2

Правило 3 — fold_binary_numerator

Когда единственный неконстантный терм имеет ровно 2 значения (vmax - vmin == 1), результат — линейная интерполяция: (y2 - y1) * (v - v_min) + y1.

Защита: Ровно один неконстантный терм после декомпозиции, и его диапазон охватывает ровно 2 значения.

Что делает: Вычисляет div/mod в обеих конечных точках и строит линейное отображение между ними. Это полностью избегает деления.

Пример: Для (v * 3 + 2) % 5 где v находится в {0, 1}:

  • v=0: (0 + 2) % 5 = 2
  • v=1: (3 + 2) % 5 = 0
  • Результат: (0 - 2) * (v - 0) + 2 = -2*v + 2

Правило 4 — fold_divmod_congruence

Для каждого терма f_i * v_i вычисляется ближайший остаток r_i = min(f_i % c, f_i % c - c) по абсолютному значению. Если сумма остатков остаётся в пределах одной корзины деления c, mod/div упрощается. Это оптимизация модулярной арифметики.

Защита: x_min >= 0, константный знаменатель c > 0, и rem_min // c == rem_max // c (все значения суммы остатков попадают в одну корзину).

Что делает: Заменяет каждый множитель его остатком по модулю c. Для Mod возвращает сумму остатков (с поправкой на смещение корзины). Для Idiv возвращает сумму коэффициентов частных.

Пример: (r*8 + v) % 7 -> (r + v) % 7, поскольку 8 = 1 (mod 7), и остаток 8 равен 1.

Правило 5 — gcd_with_remainder

Вычисляет символьный GCD всех аддитивных термов и знаменателя. Если GCD > 1, выносит его: (g*a + g*b) // (g*c) -> (a + b) // c (с масштабированием остатка для Mod).

Защита: x_min >= 0, константный знаменатель, GCD > 1, и приведённый числитель имеет vmin >= 0.

Что делает: Делит и термы числителя, и знаменатель на их GCD, рекурсивно позволяя более простым паттернам сработать.

Пример: (6*a + 4*b) // 8 с GCD(6, 4, 8) = 2 -> (3*a + 2*b) // 4

Правило 6 — divide_by_gcd

Версия Правила 5 для переменного знаменателя. Вычисляет GCD(все_термы..., y), включая и числитель, и знаменатель, затем делит обе стороны. В отличие от Правила 5, работает когда знаменатель не является константой.

Защита: GCD нетривиален (не 1), и оба x и y точно делятся на GCD.

Пример: (4*a) // (2*b) -> (2*a) // b

Правило 7 — factor_remainder

Крайний вариант. Разбивает термы на точно делимые (частное) и остаток.

Защита: x_min >= 0 и y_min >= 0, и хотя бы один терм делит y нацело.

Что делает: Для Idiv: quo_sum + rem // y. Для Mod: rem % y (с приведением коэффициентов для константного y).

Пример: (8*a + 3*b) // 8 -> a + (3*b) // 8

Правило 8 — nest_div_by_smallest_factor

Рекурсивная декомпозиция для константных делителей. Находит наименьший общий множитель между делителем и коэффициентом любого терма, делит оба на него, затем рекурсивно вызывается.

Защита: x_min >= 0, константный y > 1, и хотя бы один неконстантный терм имеет множитель f > 1, где y % f == 0.

Что делает: Выбирает div = min(|f|) среди подходящих множителей, перезаписывает x // y как (x // div) // (y / div). Каждый шаг уменьшает y, сходясь к правилам 1–7.

Пример: (6*a + 4*b) // 12((6*a + 4*b) // 2) // 6(3*a + 2*b) // 6(3*a + 2*b) // 6 (затем правило 7 завершает).

Tinygrad: divandmod.py:62-67. Morok: nest_div_by_smallest_factor в fold_divmod_general.

предупреждение

Правила 5–8 требуют неотрицательных числителей (x_min >= 0). Деление с округлением вниз с отрицательными операндами имеет иную семантику округления (к отрицательной бесконечности в Python/Tinygrad, к нулю в аппаратуре). Реализация возвращает None для отрицательных диапазонов, позволяя последующим проходам обработать выражение.


4. Продвинутые паттерны деления

Самостоятельные паттерны вне fold_divmod_general, обрабатывающие дополнительные случаи:

ПаттернЗащитаИсточник
(a // b) // c -> a // (b*c)b != 0, c != 0advanced_division_dsl_patterns
expr // divisor -> точное частноеexpr точно делитсяadvanced_division_dsl_patterns
(a + b) % c приведение коэффициентовa или b имеет множитель, делимый на cadvanced_division_dsl_patterns
(a + b) // c -> a//c + b//cоба делятся нацелоadvanced_division_dsl_patterns
(a - b) // c -> a//c - b//cоба делятся нацелоadvanced_division_dsl_patterns
c * (a + b) -> c*a + c*bc — константаadvanced_division_dsl_patterns

Сворачивание вложенного деления (a // b) // c -> a // (b*c) особенно важно после тайлинга, где разбиение диапазона на внешнюю и внутреннюю компоненты создаёт два уровня деления, которые должны сворачиваться в один.

Паттерн точного деления использует divides(), который проверяет, делится ли константный множитель каждого аддитивного терма на делитель. При успехе Idiv полностью устраняется — инструкция деления не генерируется.

Паттерн приведения коэффициентов преобразует (r*8 + v) % 7 -> (r*1 + v) % 7 = (r + v) % 7, приводя каждый множитель по модулю делителя. Срабатывает, когда ни один множитель не является точным кратным модуля, но остатки меньше.


5. Снижение типа Index (3-фазный каскад)

Tinygrad: ops.py:1291-1313. Morok: schedule/src/symbolic/index_lowering.rs.

Абстрактный тип Index не имеет фиксированной разрядности — он представляет «целое число той разрядности, которая необходима для данного индекса». Проход снижения преобразует Index в конкретный i32 или i64 на основе границ значений.

Фаза 1 — Создание обёрток (листовые узлы)

Листовые узлы с dtype Index заменяются конкретным эквивалентом, обёрнутым в приведение обратно к Index:

ВходВыход
CONST(Index)CONST(concrete).cast(Index)
DEFINE_VAR(Index)DEFINE_VAR(concrete).cast(Index)
VCONST(Vector<Index, N>)VCONST(Vector<concrete, N>).cast(Vector<Index, N>)

Фаза 2 — Обработка обёрнутых значений вверх

Бинарные операции, управление потоком и структурные узлы распространяют конкретный тип через обёртки .cast(Index):

ВходВыход
Binary(x.cast(Index), y.cast(Index))Binary(x.cast(dt), y.cast(dt)).cast(result_dtype)
WHERE(cond, x.cast(Index), y.cast(Index))WHERE(cond, x.cast(dt), y.cast(dt)).cast(Index)
RANGE(end.cast(Index))RANGE(end, end.dtype).cast(Index)
SPECIAL(end.cast(Index))SPECIAL(end, i32).cast(Index)
VECTORIZE(e0.cast(Index), ...)VECTORIZE(e0.cast(dt), ...).cast(Vector<Index, N>)
BIND(var.cast(Index), val.cast(Index))var.cast(dt).bind(val.cast(dt)).cast(Index)

dt вычисляется как least_upper_dtype(select_dtype(result), x.dtype, y.dtype) — наибольший тип, необходимый для любого операнда или результата.

Фаза 3 — Снятие обёрток на терминалах

Терминальные узлы потребляют индекс и отбрасывают обёртку Index:

ВходВыход
INDEX(buf, idx.cast(Index))INDEX(buf, idx)
INDEX(buf, WHERE(cond, idx, Invalid))INDEX(buf, idx, gate=cond)
SINK(sources с .cast(Index))SINK(развёрнутые sources)
END(computation.cast(Index))END(развёрнутое computation)

Преобразование WHERE(cond, idx, Invalid) -> gate=cond значительно: оно извлекает условия валидности из индексного выражения в поле gate узла INDEX, которое backend-ы кодогенерации используют для генерации предикатных загрузок.

select_dtype()

Возвращает i32, если границы значений UOp помещаются в [-2^31, 2^31 - 1], иначе i64. Большинство тензорных индексов помещаются в i32 — даже плоский индекс тензора с 2 миллиардами элементов помещается. Путь через i64 существует для очень больших тензоров или накопленных смещений.


6. Коммутативная канонизация

// For Index dtype ONLY:
op(a, b) -> op(b, a) when b.id < a.id

Это обеспечивает детерминированный порядок операндов для коммутативных операций на основе уникального ID узла UOp. Применяется к: Add, Mul, Max, Eq, Ne, And, Or, Xor.

Почему только Index: Без канонизации R1*8000 + R2*16 и R2*16 + R1*8000 — разные узлы после hash consing, что нарушает группировку в expand_vector_index. Экспандеру нужно идентифицировать идентичные индексные паттерны между полосами вектора, и неканоническое упорядочение это ломает.

Почему НЕ применяется к не-Index типам: Применение канонизации к float/int арифметике переупорядочит элементы VECTORIZE и нарушит слияние векторной математики в последующих проходах. Tinygrad делает тот же выбор (symbolic.py:178-182).

предупреждение

Канонизация взаимодействует с итерацией фиксированной точки движка перезаписи. Если два паттерна расходятся в порядке операндов (один канонизирует, другой порождает неканонический порядок), движок может осциллировать. Все паттерны, порождающие индексы, должны соблюдать канонический порядок, иначе сработает предел безопасности в 1000 итераций.


Рабочий пример

Рассмотрим tensor[i, j] с формой [4, 8], доступ через плоскую итерацию по 32 элементам.

Начальное состояние

Диапазон R0 итерирует 0..32 (плоский индекс). Паттерн доступа декомпозируется в:

row = R0 // 8 (which of the 4 rows)
col = R0 % 8 (which of the 8 columns)
addr = row * 8 + col = (R0 // 8) * 8 + (R0 % 8)

По тождеству div-mod (#1), (R0 // 8) * 8 + (R0 % 8) = R0. Адрес — просто плоский индекс, деление не нужно.

После тайлинга (UPCAST на 4)

Разбиение диапазона декомпозирует R0 в R1 * 4 + R2, где R1 в [0, 8) и R2 в [0, 4):

row = (R1*4 + R2) // 8
col = (R1*4 + R2) % 8

Упрощение row: Выражение (R1*4 + R2) // 8 поступает в fold_divmod_general.

Правило 4 (конгруэнтность) срабатывает: множитель 4 имеет остаток 4 % 8 = 4, а R2 имеет остаток 1 % 8 = 1. Сумма остатков 4*R1 + R2 с диапазоном [0, 31]. Поскольку 0 // 8 != 31 // 8, Правило 4 не сворачивает до константы. Вместо этого срабатывает Правило 7 (factor remainder): 4 не делит 8 нацело, но выражение может быть декомпозировано. Поскольку ни один терм не делит 8 нацело, мы проходим к паттерну на основе диапазонов (a*m + b) / n с m = 4, n = 8 — он не совпадает (m != n).

Выражение остаётся как (R1*4 + R2) // 8. В сгенерированном коде, если R2 векторизован (UPCAST), backend генерирует это как одно деление 4-широкого вектора.

Однако, если мы далее разобьём R1 на R3 * 2 + R4 (где R3 в [0, 4), R4 в [0, 2)):

row = (R3*2*4 + R4*4 + R2) // 8
= (R3*8 + R4*4 + R2) // 8

Теперь паттерн на основе диапазонов (a*m + b) / n срабатывает с m = n = 8:

  • a = R3, b = R4*4 + R2
  • vmin(b) = 0, vmax(b) = 1*4 + 3 = 7 < 8
  • Результат: R3 + (R4*4 + R2) // 8

И (R4*4 + R2) // 8: vmax = 1*4 + 3 = 7, vmin = 0, поэтому 0 // 8 = 7 // 8 = 0. Срабатывает правило cancel_divmod:

  • Результат: R3 + 0 = R3

Упрощение col: (R3*8 + R4*4 + R2) % 8

Паттерн на основе диапазонов (a*m + b) % n срабатывает с m = n = 8:

  • (R3*8 + R4*4 + R2) % 8 -> (R4*4 + R2) % 8

Затем vmin(R4*4 + R2) = 0, vmax(R4*4 + R2) = 7 < 8, поэтому x % n -> x:

  • Результат: R4*4 + R2

Итоговое дерево

Before (after tiling, before simplification):
STORE(
INDEX(buf, (R3*8 + R4*4 + R2) // 8 * 8 + (R3*8 + R4*4 + R2) % 8),
value)

After index arithmetic:
STORE(
INDEX(buf, R3*8 + R4*4 + R2),
value)

Всё вычисление адреса сворачивается обратно в линейное выражение — ноль делений, ноль модулей. Паттерны доказали, что тайлированный индекс эквивалентен плоскому индексу, исключительно через алгебраическую перезапись.