diff --git a/test/null/test_winograd.py b/test/null/test_winograd.py index 8e3402e1f2dd1..eb165ef9c4d85 100644 --- a/test/null/test_winograd.py +++ b/test/null/test_winograd.py @@ -20,10 +20,20 @@ def test_forward_kernels(self): out = Tensor.conv2d(x,w) self.assertEqual(len(out.schedule_linear().src), 2) + def test_cin1_and_depthwise_trigger_wino(self): + x1, w1 = Tensor.empty(1,1,9,9).realize(), Tensor.empty(4,1,3,3).realize() + self.assertEqual(len(Tensor.conv2d(x1, w1, padding=1).schedule_linear().src), 2) + xd, wd = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,1,3,3).realize() + self.assertEqual(len(Tensor.conv2d(xd, wd, padding=1, groups=4).schedule_linear().src), 2) + def test_backward_kernels(self): + # NOTE: out.mean() collapses the conv to a constant scalar, so its backward graph has no real + # forward-conv structure left to rewrite. Use a real loss so the forward conv survives in the + # backward graph and pm_wino fires on it without needing a dedicated gradient hook. x,w = Tensor.empty(1,4,9,9,requires_grad=True).realize(), Tensor.empty(4,4,3,3,requires_grad=True).realize() + y = Tensor.empty(1,4,9,9).realize() out = Tensor.conv2d(x,w, padding=1) - out.mean().backward() + ((out - y)**2).sum().backward() backward_schedule = x.grad.schedule_linear(w.grad) self.assertEqual(len(backward_schedule.src), 4) diff --git a/test/unit/test_winograd.py b/test/unit/test_winograd.py index 7869d7896dc12..451a62b110ee5 100644 --- a/test/unit/test_winograd.py +++ b/test/unit/test_winograd.py @@ -1,6 +1,6 @@ import unittest, sys import numpy as np -from tinygrad import Tensor, GlobalCounters, Context, nn +from tinygrad import Tensor, GlobalCounters, Context, nn, dtypes from tinygrad.helpers import WINO @unittest.skipIf(sys.platform.startswith("win"), "flaky on Windows") @@ -34,5 +34,110 @@ def test_padded_conv2d(self): with Context(WINO=1): result = Tensor.conv2d(x,w,padding=1).realize() np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=1e-4) + def test_handwritten_conv_triggers_wino(self): + # 3^n conv built directly from _pool + multiply + sum (i.e. not via Tensor.conv2d). + # Master gates wino inside Tensor.conv2d, so any of these handwritten variants would NOT + # fire there. The rewrite rule version fires on the produced UOp graph regardless of how + # it was built — exercising the generality of the affine-detection approach. + def manual_conv(x, w, swap_mul=False, downstream=lambda r: r): + bs, cin, cout, HW = x.shape[0], x.shape[1], w.shape[0], w.shape[2:] + pooled = x._pool(HW, 1, 1) + oyx = pooled.shape[2:-len(HW)] + pooled = pooled.reshape(bs, 1, cin, 1, *oyx, *HW).expand(bs, 1, cin, cout, *oyx, *HW)\ + .permute(0,1,3,*[4+i for i in range(len(oyx))],2,*[4+len(oyx)+i for i in range(len(HW))]) + reshaped_w = w.reshape(1, 1, cout, *[1]*len(oyx), cin, *HW) + mul = (reshaped_w * pooled) if swap_mul else (pooled * reshaped_w) + return downstream(mul.sum([-1-i for i in range(1+len(oyx))], keepdim=True).reshape(bs, cout, *oyx)) + cases = [ + ("1D", Tensor.rand(1,4,9).realize(), Tensor.rand(4,4,3).realize(), {}), + ("2D", Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize(), {}), + ("3D", Tensor.rand(1,4,9,9,9).realize(), Tensor.rand(4,4,3,3,3).realize(), {}), + ("swapped MUL operands",Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize(), {"swap_mul": True}), + ("downstream relu", Tensor.rand(1,4,9,9).realize(), Tensor.rand(4,4,3,3).realize(), {"downstream": lambda r: r.relu()}), + ] + for name, x, w, kw in cases: + with self.subTest(name=name): + with Context(WINO=0): base = len(manual_conv(x, w, **kw).schedule_linear().src) + with Context(WINO=1): wino = len(manual_conv(x, w, **kw).schedule_linear().src) + self.assertGreater(wino, base, f"{name}: wino did not fire (base={base}, wino={wino})") + + def test_mixed_dtype_accumulate_triggers_wino(self): + x = Tensor.rand(1,4,9,9).cast(dtypes.half).realize() + w = Tensor.rand(4,4,3,3).cast(dtypes.half).realize() + with Context(WINO=0): + expected = Tensor.conv2d(x, w, padding=1, dtype=dtypes.float32).realize() + base = len(Tensor.conv2d(x, w, padding=1, dtype=dtypes.float32).schedule_linear().src) + with Context(WINO=1): + result = Tensor.conv2d(x, w, padding=1, dtype=dtypes.float32).realize() + wino = len(Tensor.conv2d(x, w, padding=1, dtype=dtypes.float32).schedule_linear().src) + self.assertGreater(wino, base) + np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=1e-2, rtol=1e-2) + + def test_5x5_conv_triggers_wino(self): + x = Tensor.rand(1,4,16,16).realize() + w = Tensor.rand(8,4,5,5).realize() + with Context(WINO=0): + expected = Tensor.conv2d(x, w, padding=2).realize() + base = len(Tensor.conv2d(x, w, padding=2).schedule_linear().src) + with Context(WINO=1): + result = Tensor.conv2d(x, w, padding=2).realize() + wino = len(Tensor.conv2d(x, w, padding=2).schedule_linear().src) + self.assertGreater(wino, base) + np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=1e-4, rtol=1e-4) + + def test_dilation2_triggers_wino(self): + x = Tensor.rand(1,4,17,17).realize() + w = Tensor.rand(4,4,3,3).realize() + with Context(WINO=0): + expected = Tensor.conv2d(x, w, padding=2, dilation=2).realize() + base = len(Tensor.conv2d(x, w, padding=2, dilation=2).schedule_linear().src) + with Context(WINO=1): + result = Tensor.conv2d(x, w, padding=2, dilation=2).realize() + wino = len(Tensor.conv2d(x, w, padding=2, dilation=2).schedule_linear().src) + self.assertGreater(wino, base) + np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=1e-4, rtol=1e-4) + + def test_conv_transpose2d_triggers_wino(self): + x = Tensor.rand(1,4,9,9).realize() + w = Tensor.rand(4,4,3,3).realize() + with Context(WINO=0): + expected = Tensor.conv_transpose2d(x, w, stride=1, padding=0).realize() + base = len(Tensor.conv_transpose2d(x, w, stride=1, padding=0).schedule_linear().src) + with Context(WINO=1): + result = Tensor.conv_transpose2d(x, w, stride=1, padding=0).realize() + wino = len(Tensor.conv_transpose2d(x, w, stride=1, padding=0).schedule_linear().src) + self.assertGreater(wino, base) + np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=1e-4, rtol=1e-4) + + def test_conv_transpose2d_5x5_and_bias(self): + # 5x5 transposed conv (F(2x2,5x5)) plus bias + x = Tensor.rand(1,4,13,13).realize() + w = Tensor.rand(4,8,5,5).realize() + b = Tensor.rand(8).realize() + with Context(WINO=0): + expected = Tensor.conv_transpose2d(x, w, b, stride=1).realize() + base = len(Tensor.conv_transpose2d(x, w, b, stride=1).schedule_linear().src) + with Context(WINO=1): + result = Tensor.conv_transpose2d(x, w, b, stride=1).realize() + wino = len(Tensor.conv_transpose2d(x, w, b, stride=1).schedule_linear().src) + self.assertGreater(wino, base) + np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=1e-4, rtol=1e-4) + + def test_dilation2_with_bias(self): + x = Tensor.rand(1,4,17,17).realize() + w = Tensor.rand(8,4,3,3).realize() + b = Tensor.rand(8).realize() + with Context(WINO=0): expected = Tensor.conv2d(x, w, b, padding=2, dilation=2).realize() + with Context(WINO=1): result = Tensor.conv2d(x, w, b, padding=2, dilation=2).realize() + np.testing.assert_allclose(result.numpy(), expected.numpy(), atol=1e-4, rtol=1e-4) + + def test_stride2_does_not_misfire(self): + # Wino is provably not a net win for stride > 1: 36 muls/4 outputs (stride=2) vs 9 muls/output direct. + # Verify the matcher correctly declines stride=2 conv (kernel count stays at non-wino baseline). + x, w = Tensor.empty(1,4,9,9).realize(), Tensor.empty(4,4,3,3).realize() + with Context(WINO=0): base = len(Tensor.conv2d(x, w, stride=2).schedule_linear().src) + with Context(WINO=1): wino = len(Tensor.conv2d(x, w, stride=2).schedule_linear().src) + self.assertEqual(wino, base, "wino should not fire for stride=2 (provable loss vs direct)") + if __name__ == '__main__': unittest.main(verbosity=2) diff --git a/tinygrad/schedule/wino.py b/tinygrad/schedule/wino.py new file mode 100644 index 0000000000000..49f64429ef1dd --- /dev/null +++ b/tinygrad/schedule/wino.py @@ -0,0 +1,242 @@ +"""Winograd convolution as a UOp graph rewrite (http://arxiv.org/abs/1509.09308). + +Detection: walk back through movement ops with `apply_movement_op`, then verify the resulting +coordinate expressions describe a 3^n or 5^n convolution via affine analysis on RANGE UOps +(no fragile tree-shape matching). + +Supports kernel sizes 3 (F(4,3)) and 5 (F(2,5)) — both share the same alpha=6 `Bt`. Catches: +forward conv (with optional bias), grouped/depthwise/`cin==1`, mixed-dtype accumulate (top-level +CAST in reduce body), dilation=2 (rewritten as a sparse 5x5), stride-1 transposed conv (axes- +swapped + flipped weight), and the backward of any non-degenerate loss (the forward conv pattern +survives in `compute_gradient`'s output and gets caught at schedule time). + +Provably out of scope: stride > 1 (FLOP ratio `(m+K-1)^d / (s^d K^d)` < 1 only when `s < (m+K-1)/K`, +which fails for `s >= 2` with both F(4,3) and F(2,5)); kernel sizes other than 3 or 5 (need +different transform matrices); raw `dw` filter-gradient (the small object is the *output*, not +the input — needs a different Winograd algorithm). +""" +from __future__ import annotations +import itertools +from typing import TYPE_CHECKING +from tinygrad.dtype import DType, DTypeLike, Invalid +from tinygrad.helpers import WINO, flatten, flat_to_grouped, prod, resolve_pool_pads +from tinygrad.schedule.indexing import apply_movement_op +from tinygrad.uop.ops import GroupOp, Ops, PatternMatcher, UOp, UPat, graph_rewrite +from tinygrad.uop.symbolic import propagate_invalid + +if TYPE_CHECKING: + from tinygrad.tensor import Tensor + from tinygrad.uop.ops import sint + +# *** winograd math (reused tensor-level transform) *** + +winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]] +winograd_G_5 = [[1/4, 0, 0, 0, 0], [-1/6, -1/6, -1/6, -1/6, -1/6], [-1/6, 1/6, -1/6, 1/6, -1/6], + [1/24, 1/12, 1/6, 1/3, 2/3], [1/24, -1/12, 1/6, -1/3, 2/3], [0, 0, 0, 0, 1]] +winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]] +winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] +winograd_At_5 = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 1]] +WINO_MATRICES = {3:(winograd_G, winograd_Bt, winograd_At), 5:(winograd_G_5, winograd_Bt, winograd_At_5)} + +def _matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...], dtype:DType) -> list[list[Tensor]]: + from tinygrad.tensor import Tensor + return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim) + for k in range(len(mat[0]))] for dim in range(dims)] + +def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor: + from tinygrad.tensor import Tensor + t_ = t.reshape(t.shape[:dims] + (1,)*dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),)*dims + t.shape[dims:]) + cols = _matcols(mat, dims, t_.shape[dims:], t_.device, t_.dtype) + ret = sum(prod(col[idx] for col, idx in zip(cols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims)) + assert isinstance(ret, Tensor), "sum didn't return a Tensor" + return ret + +def wino_conv(x_uop:UOp, weight_uop:UOp, bias_uop:UOp|None, groups:int, padding, dtype:DTypeLike|None, kernel:int) -> UOp: + from tinygrad.tensor import Tensor + x, weight = Tensor(x_uop), Tensor(weight_uop) + bias = Tensor(bias_uop) if bias_uop is not None else None + (bs, _), (cout, cin), HW = x.shape[:2], weight.shape[:2], weight.shape[2:] + G, Bt, At = WINO_MATRICES[kernel] + padding_ = resolve_pool_pads(padding, len(HW)) + rcout, oyx = cout // groups, x.pad(padding_)._pool(HW, 1, 1).shape[2:-len(HW)] + HWI, HWO = (len(Bt),) * len(HW), (len(At),) * len(HW) + pads = [(pB, pA + (-(s + pB + pA - (kernel-1)) % len(At))) for (pB, pA), s in zip(flat_to_grouped(padding_), x.shape[-len(HW):])] + d = x.pad(flatten(reversed(pads)))._pool(HWI, HWO) + d = d.permute(*range(len(d.shape)-len(HW), len(d.shape)), *range(len(d.shape)-len(HW))) + tyx = d.shape[-len(HWI):] + g = weight.permute(*range(len(weight.shape)-len(HW), len(weight.shape)), *range(len(weight.shape)-len(HW))) + gfactors = _apply_winograd_matrix(G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) + dfactors = _apply_winograd_matrix(Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx) + # contiguous() is load-bearing: without it `cin==1` fuses the whole pipeline back into a single kernel. + prod_factors = (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype).contiguous() + ret = _apply_winograd_matrix(At, prod_factors, len(HW)) + ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW), 0]]]) + ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink_to(bs, cout, *oyx) + return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1]*len(HW)))).contiguous().contiguous_backward().uop + +# *** principled detection: affine analysis + walking through movement ops *** + +# An affine expression over RANGE UOps: ({range: coefficient}, intercept) +def _strip_invalid(x:UOp) -> UOp: + x = graph_rewrite(x, propagate_invalid, name="wino affine") + return x.src[1] if x.op is Ops.WHERE and x.src[2].op is Ops.CONST and x.src[2].arg is Invalid else x + +def _affine(x:UOp) -> tuple[dict[UOp, int], int]|None: + coeffs: dict[UOp, int] = {} + intercept = 0 + for term in _strip_invalid(x).split_uop(Ops.ADD): + term = _strip_invalid(term) + if term.op is Ops.CONST and isinstance(term.arg, int): + intercept += term.arg + continue + if term.op is Ops.RANGE: + coeffs[term] = coeffs.get(term, 0) + 1 + continue + if term.op is Ops.MUL: + if all(s.op is Ops.CONST and isinstance(s.arg, int) for s in term.src): + intercept += term.src[0].arg * term.src[1].arg + continue + for rng, c in (term.src, term.src[::-1]): + if rng.op is Ops.RANGE and c.op is Ops.CONST and isinstance(c.arg, int): + coeffs[rng] = coeffs.get(rng, 0) + c.arg + break + else: return None + continue + return None + return {k:v for k,v in coeffs.items() if v}, intercept + +def _axis_coeff(c:UOp, want:UOp|None=None) -> tuple[UOp, int, int]|None: + if (a:=_affine(c)) is None or len(a[0]) != 1: return None + ax, coeff = next(iter(a[0].items())) + return (ax, coeff, a[1]) if (want is None or ax is want) else None + +def _is_axis(c:UOp, want:UOp|None=None) -> UOp|None: + """If `c` is exactly one RANGE (coeff 1, intercept 0), return it (optionally checking equality with `want`).""" + return axc[0] if (axc:=_axis_coeff(c, want)) is not None and axc[1:] == (1, 0) else None + +def _walk(x:UOp, coords:tuple[UOp, ...], dims:int) -> tuple[UOp, tuple[UOp, ...]]|None: + """Walk back through movement ops, propagating coords. Return the deepest node of shape `dims`.""" + best: tuple[UOp, tuple[UOp, ...]]|None = None + while True: + if len(x.shape) == dims: best = (x, coords) + if x.op not in GroupOp.Movement: return best + coords, x = apply_movement_op(x.op, x.src[0].shape, x.marg, coords), x.src[0] + +def _dilate_weight2(weight_uop:UOp, dims:int) -> UOp: + from tinygrad.tensor import Tensor + t = Tensor(weight_uop) + for axis in range(t.ndim-dims, t.ndim): + shp = list(t.shape); shp.insert(axis+1, 1) + t = t.reshape(tuple(shp)) + t = t.pad(tuple((0, 1) if i == axis+1 else None for i in range(t.ndim))) + shp = list(t.shape); shp[axis] *= shp.pop(axis+1) + t = t.reshape(tuple(shp)) + t = t.shrink(tuple((0, t.shape[i]-1) if i == axis else None for i in range(t.ndim))) + return t.uop + +def _match_weight(side:UOp, coords:tuple[UOp, ...], dims:int, reduced:set[UOp]): + """Weight: shape (cout, cin, *(kernel,)*dims). Coords: (cout_axis, cin_axis, *kernel_axes).""" + if (w:=_walk(side, coords, dims+2)) is None or len(set(ks:=w[0].shape[-dims:])) != 1 or ks[0] not in WINO_MATRICES: return None + weight_uop, wc = w + kaxes = tuple(_is_axis(c) for c in wc[-dims:]) + if any(k is None or k not in reduced for k in kaxes) or len(set(kaxes)) != dims: return None + if (cin_ax:=_is_axis(wc[1])) is not None: + if cin_ax not in reduced or cin_ax in kaxes: return None + elif _affine(wc[1]) != ({}, 0): return None + return weight_uop, cin_ax, kaxes, ks[0] + +def _match_tconv_weight(side:UOp, coords:tuple[UOp, ...], dims:int, reduced:set[UOp]): + """Transposed-conv weight: shape (cin, cout, *(kernel,)*dims) with flipped kernel coords.""" + from tinygrad.tensor import Tensor + if (w:=_walk(side, coords, dims+2)) is None or len(set(ks:=w[0].shape[-dims:])) != 1 or ks[0] not in WINO_MATRICES: return None + weight_uop, wc = w + if (cin_ax:=_is_axis(wc[0])) is None or cin_ax not in reduced: return None + if (cout_ax:=_is_axis(wc[1])) is None or cout_ax in reduced or cout_ax is cin_ax: return None + if any((axc:=_axis_coeff(c)) is None or axc[0] not in reduced or axc[1:] != (-1, ks[0]-1) for c in wc[-dims:]): return None + return Tensor(weight_uop).permute(1, 0, *range(2, dims+2)).flip(*range(2, dims+2)).uop, cin_ax, ks[0] + +def _match_act(side:UOp, coords:tuple[UOp, ...], dims:int, cin_ax:UOp|None, kaxes:tuple[UOp, ...], reduced:set[UOp], kernel:int): + """Activation: shape (bs, cin_total, *spatial). Coords: (bs_axis, cin_combined, *(stride*oy + dilation*ky pairs)).""" + if (w:=_walk(side, coords, dims+2)) is None: return None + x_uop, ac = w + ch = _affine(ac[1]) + blocked = reduced if cin_ax is None else reduced - {cin_ax} + if ch is None or ch[1] != 0 or any(v in blocked for v in ch[0]): return None + if cin_ax is not None and ch[0].get(cin_ax) != 1: return None + pads, oy_axes, strides, dilations = [], [], [], [] + for in_size, c, k in zip(x_uop.shape[-dims:], ac[-dims:], kaxes): + if not isinstance(in_size, int) or (a:=_affine(c)) is None or (dilation:=a[0].get(k)) is None or dilation <= 0: return None + rest = [(v, n) for v, n in a[0].items() if v is not k] + if len(rest) != 1 or (stride:=rest[0][1]) <= 0 or rest[0][0] in reduced or rest[0][0].src[0].op is not Ops.CONST: return None + oy_axes.append(rest[0][0]) + strides.append(stride); dilations.append(dilation) + pads.append((-a[1], stride*rest[0][0].src[0].arg - in_size + dilation*(kernel-1) + a[1] - stride + 1)) + if len(set(oy_axes)) != dims or len(set(strides)) != 1 or len(set(dilations)) != 1: return None + return x_uop, tuple(flatten(reversed(pads))), strides[0], dilations[0] + +def _match_bias(side:UOp, coords:tuple[UOp, ...], ch_axis:UOp) -> UOp|None: + if (w:=_walk(side, coords, dims=1)) is None: return None + return w[0] if _is_axis(w[1][0], want=ch_axis) is not None else None + +def _detect_conv(reduce:UOp) -> tuple[UOp, UOp, int, tuple[int, ...], int]|None: + """Returns (x, weight, groups, padding, kernel) if `reduce` is a conv-style REDUCE-of-MUL, else None.""" + # MUL shape from OpMixin.conv2d is (bs, groups, rcout, *oyx, cin, *HW), so dims comes from the shape itself + if reduce.arg[0] is not Ops.ADD: return None + mul = reduce.src[0].src[0] if reduce.src[0].op is Ops.CAST and reduce.src[0].src[0].op is Ops.MUL else reduce.src[0] + if mul.op is not Ops.MUL or (dims:=(len(mul.shape)-4)//2) <= 0: return None + coords = tuple(UOp.range(s, i) for i, s in enumerate(mul.shape)) + reduced = {coords[i] for i in reduce.arg[1]} + for wt_side, act_side in (mul.src, mul.src[::-1]): + if (wt:=_match_weight(wt_side, coords, dims, reduced)) is None: continue + if (act:=_match_act(act_side, coords, dims, wt[1], wt[2], reduced, wt[3])) is None: continue + if act[2] != 1: return None + weight_uop, x_uop, padding, kernel = wt[0], act[0], act[1], wt[3] + if act[3] == 2: + if kernel != 3: return None + weight_uop, kernel = _dilate_weight2(weight_uop, dims), 5 + elif act[3] != 1: return None + cin, chans, cout = weight_uop.shape[1], x_uop.shape[1], weight_uop.shape[0] + if not all(isinstance(v, int) and v > 0 for v in (cin, chans, cout)) or chans % cin: return None + if cout % (groups:=chans // cin): return None + return x_uop, weight_uop, groups, padding, kernel + return None + +def _detect_tconv(reduce:UOp) -> tuple[UOp, UOp, int, tuple[int, ...], int]|None: + if reduce.arg[0] is not Ops.ADD or reduce.src[0].op is not Ops.MUL or (dims:=(len(reduce.src[0].shape)-4)//2) <= 0: return None + coords = tuple(UOp.range(s, i) for i, s in enumerate(reduce.src[0].shape)) + reduced = {coords[i] for i in reduce.arg[1]} + for wt_side, act_side in (reduce.src[0].src, reduce.src[0].src[::-1]): + if (wt:=_match_tconv_weight(wt_side, coords, dims, reduced)) is None: continue + if (act:=_match_act(act_side, coords, dims, wt[1], tuple(coords[i] for i in reduce.arg[1][1:]), reduced, wt[2])) is None or act[2:] != (1, 1): continue + cin, chans, cout = wt[0].shape[1], act[0].shape[1], wt[0].shape[0] + if not all(isinstance(v, int) and v > 0 for v in (cin, chans, cout)) or chans % cin: return None + if cout % (groups:=chans // cin): return None + return act[0], wt[0], groups, act[1], wt[2] + return None + +def _detect_wino(reduce:UOp) -> tuple[UOp, UOp, int, tuple[int, ...], int]|None: return _detect_conv(reduce) or _detect_tconv(reduce) + +def _try_wino_reduce(reduce:UOp) -> UOp|None: + if not WINO.value or (det:=_detect_wino(reduce)) is None: return None + return wino_conv(det[0], det[1], None, det[2], det[3], reduce.dtype, det[4]) + +def _find_reduce(x:UOp) -> UOp|None: + """Walk back through movement ops to find a REDUCE.""" + while x.op in GroupOp.Movement: x = x.src[0] + return x if x.op is Ops.REDUCE else None + +def _try_wino_add(add:UOp) -> UOp|None: + """ADD(conv-reduce-output via movement ops, bias-broadcast).""" + if not WINO.value or len(add.shape) < 2: return None + for reduce_side, bias_side in (add.src, add.src[::-1]): + if (reduce:=_find_reduce(reduce_side)) is None or (det:=_detect_wino(reduce)) is None: continue + # synthetic axis vars for symbolically walking the bias side; the high index avoids clashing with real range IDs + fresh = tuple(UOp.range(s, 9001+i) for i, s in enumerate(add.shape)) + if (bias:=_match_bias(bias_side, fresh, fresh[1])) is None: continue + return wino_conv(det[0], det[1], bias, det[2], det[3], reduce.dtype, det[4]) + return None + +pm_wino = PatternMatcher([ + (UPat(Ops.ADD, name="add"), _try_wino_add), + (UPat(Ops.REDUCE, name="reduce"), _try_wino_reduce), +]) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index 720a00487fcc8..1615bf8a12a3f 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1,12 +1,12 @@ # inspired by https://github.com/karpathy/micrograd/blob/master/micrograd/engine.py from __future__ import annotations -import time, math, itertools, functools, struct, sys, inspect, pathlib, hashlib, weakref +import time, math, functools, struct, sys, inspect, pathlib, hashlib, weakref from contextlib import ContextDecorator from typing import Any, Callable, ClassVar, Sequence, cast, get_args, ParamSpec, TypeVar, Generic, TYPE_CHECKING if TYPE_CHECKING: import numpy from tinygrad.dtype import DType, DTypeLike, dtypes, ConstType, least_upper_dtype, to_dtype, truncate from tinygrad.dtype import _from_np_dtype, _to_np_dtype, PyConst, Invalid -from tinygrad.helpers import argfix, flatten, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch, flat_to_grouped +from tinygrad.helpers import argfix, prod, all_int, round_up, getenv, all_same, fully_flatten, ceildiv, fetch from tinygrad.helpers import resolve_pool_pads, IMAGE, FLOAT16, WINO, Metadata, TRACEMETA, is_numpy_ndarray, TracingKey, cpu_profile from tinygrad.helpers import suppress_finalizing, disable_gc from tinygrad.gradient import compute_gradient @@ -36,6 +36,14 @@ def visitor(node: UOp) -> bool: return True if node in applied_map else any(in_s if s is ns: continue t.uop = ns +def _apply_wino(roots:list[Tensor]) -> None: + from tinygrad.schedule.wino import pm_wino + from tinygrad.uop.ops import graph_rewrite + sink = UOp.sink(*[t.uop for t in roots]) + new_sink = graph_rewrite(sink, pm_wino, bottom_up=True, name="pre schedule wino") + if (m:={old:new for old,new in zip(sink.src, new_sink.src) if old is not new}): + _apply_map_to_tensors(m, name="wino", walk=True) + # **** Tensor helper functions **** def _fromnp(x: 'numpy.ndarray') -> UOp: @@ -61,22 +69,6 @@ def _frompy(x:list|tuple|bytes, dtype:DType, device:str|tuple[str,...]) -> UOp: ret.buffer.allocate(memoryview(data if device != "PYTHON" else bytearray(data))) return ret -def _get_winograd_matcols(mat, dims:int, shp:tuple[sint, ...], device:str|tuple[str, ...], dtype:DType) -> list[list[Tensor]]: - return [[Tensor.cat(*[Tensor.full(shp[:dim] + (1,) + shp[dim+1:], float(m[k]), device=device, dtype=dtype) for m in mat], dim=dim) - for k in range(len(mat[0]))] for dim in range(dims)] - -# winograd conv 3 kernel f(4x4,3x3) see: http://arxiv.org/abs/1509.09308 -def _apply_winograd_matrix(mat, t:Tensor, dims:int) -> Tensor: - # multiply mat_1 @ mat_2 @ t with foldable constants, where mat_i acts on vector t along dimension i; roughly kron(mat, mat) @ t - # due to realize-before-expand rule in lazy.py, we must operate in this order: reshape -> expand -> arithmetic - t_ = t.reshape(t.shape[:dims] + (1,) * dims + t.shape[dims:]).expand(t.shape[:dims] + (len(mat),) * dims + t.shape[dims:]) # add output dims - # precalculate mat columns for each dim; prod(itertools.product(matcols)) gives the columns of kron(mat, mat, ...) - matcols = _get_winograd_matcols(mat, dims, t_.shape[dims:], t_.device, t_.dtype) - # multiply each element of t_ by the corresponding stacked column of kron(mat, mat), producing only one view for each element of t - ret = sum(prod(col[idx] for col, idx in zip(matcols, mat_is)) * t_[mat_is] for mat_is in itertools.product(range(len(mat[0])), repeat=dims)) - assert isinstance(ret, Tensor), "sum didn't return a Tensor" - return ret - class Tensor(OpMixin): """ A `Tensor` is a multi-dimensional matrix containing elements of a single data type. @@ -228,6 +220,7 @@ def callify(self, *lst:Tensor) -> Tensor: def linear_with_vars(self, *lst:Tensor) -> tuple[UOp, dict[str, int]]: """Creates the LINEAR UOp needed to realize these Tensor(s), with Variables.""" + if WINO.value: _apply_wino([self, *lst]) big_sink, becomes_map = transform_to_call(UOp.sink(*[x.uop for x in (self,)+lst])) _apply_map_to_tensors(becomes_map, name="buffers") return create_linear_with_vars(big_sink) @@ -1195,45 +1188,6 @@ def hash(self) -> Tensor: # ***** processing ops ***** - # TODO: winograd can be a rewrite rule like split_reduceop - def _conv2d_winograd(self, weight:Tensor, bias:Tensor|None, groups:int, padding:int|Sequence[int], dtype:DTypeLike|None) -> Tensor: - (bs,cin_), (cout,cin), HW = self.shape[:2], weight.shape[:2], weight.shape[2:] - padding_ = resolve_pool_pads(padding, len(HW)) - assert groups*cin == cin_ and len(self.shape) == len(weight.shape),\ - f"Input Tensor shape {self.shape} does not match the shape of the weights {weight.shape}. ({groups*cin} vs. {cin_})" - rcout, oyx = cout//groups, self.pad(padding_)._pool(HW, 1, 1).shape[2:-len(HW)] - HWI, HWO = (6,) * len(HW), (4,) * len(HW) # F(4x4,3x3) winograd tiles - winograd_G = [[1/4, 0, 0], [-1/6, -1/6, -1/6], [-1/6, 1/6, -1/6], [1/24, 1/12, 1/6], [1/24, -1/12, 1/6], [0, 0, 1]] - winograd_Bt = [[4, 0, -5, 0, 1, 0], [0, -4, -4, 1, 1, 0], [0, 4, -4, -1, 1, 0], [0, -2, -1, 2, 1, 0], [0, 2, -1, -2, 1, 0], [0, 4, 0, -5, 0, 1]] - winograd_At = [[1, 1, 1, 1, 1, 0], [0, 1, -1, 2, -2, 0], [0, 1, 1, 4, 4, 0], [0, 1, -1, 8, -8, 1]] # applying At in pre-order doubles compile time - - # TODO: stride == dilation - # use padding to round up to 4x4 output tiles - # (bs, cin_, tyx, HWI) - pads = [(pB, pA + (-(s + pB + pA - 2) % 4)) for (pB, pA), s in zip(flat_to_grouped(padding_), self.shape[-len(HW):])] - d = self.pad(flatten(reversed(pads)))._pool(HWI, HWO) - # move HW to the front: # (HWI, bs, cin_, tyx) - d = d.permute(*range(len(d.shape)-len(HW),len(d.shape)), *range(len(d.shape)-len(HW))) - tyx = d.shape[-len(HWI):] # dim of tiling - - g = weight.permute(*range(len(weight.shape)-len(HW),len(weight.shape)), *range(len(weight.shape)-len(HW))) # move HW to the front - - # compute 6x6 winograd tiles: GgGt, BtdB - # (HWI, groups * rcout, cin) -> (HWI, bs=1, groups, rcout, cin, tyx=(1,1)) - gfactors = _apply_winograd_matrix(winograd_G, g, len(HW)).reshape(*HWI, 1, groups, rcout, cin, *([1]*len(tyx))) - # (HWI, bs, cin_, tyx) -> (HWI, bs, groups, 1 ,cin, *tyx) - dfactors = _apply_winograd_matrix(winograd_Bt, d, len(HW)).reshape(*HWI, bs, groups, 1, cin, *tyx) - - # matmul; sum across cin: (HWI, bs, groups, rcout, *tyx); then HWI -> HWO: (HWO, bs, groups, rcout, *tyx) - ret = _apply_winograd_matrix(winograd_At, (gfactors * dfactors).sum(axis=-1-len(HW), dtype=dtype), len(HW)) - - # interleave tyx and HWO: (bs, groups, rcout, oy, HO, ox, WO) - ret = ret.permute([*range(len(HW), len(ret.shape)-len(HW)), *[i+o for i in range(len(HW)) for o in [len(ret.shape)-len(HW),0]]]) - # merge groups and rcout, tyx and HWO: (bs, groups, cout, *yx), shrink to final - ret = ret.reshape(bs, cout, *[c * HWO[i] for i, c in enumerate(tyx)]).shrink_to(bs, cout, *oyx) - - return (ret if bias is None else ret.add(bias.reshape(1, -1, *[1 for _ in range(len(HW))]))).contiguous().contiguous_backward() - def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding:int|Sequence[int]=0, dtype:DTypeLike|None=None) -> Tensor: """ @@ -1262,7 +1216,6 @@ def conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilat ``` """ if IMAGE: return self.image_conv2d(weight, bias, groups, stride, dilation, padding, dtype) - if WINO and all(x == 3 for x in weight.shape[2:]) and stride == dilation == 1: return self._conv2d_winograd(weight, bias, groups, padding, dtype) return super().conv2d(weight, bias, groups, stride, dilation, padding, dtype) def dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor: