From b6c424951d7cd443478968e94c9f7643f4555301 Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Wed, 13 May 2026 10:20:28 -0700 Subject: [PATCH] minor image_conv2d cleanup remove some no-op slices --- tinygrad/tensor.py | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index d0807939ab348..a25c4afc88b01 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1477,7 +1477,7 @@ def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor: def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor: dtsz = 2 if FLOAT16 else 4 - (bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape + (bs,_,_,_), (cout,cin,H,W) = self.shape, weight.shape assert isinstance(cin, int) and isinstance(cout, int) x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W) @@ -1513,7 +1513,6 @@ def is_pow2(v): return v > 0 and v & (v - 1) == 0 def ipad(t, i, amt): shape = (None,)*i + (amt,) + (None,)*(t.ndim-i-1) return Tensor(True, device=t.device).expand(t.shape).pad_to(shape).where(t.pad_to(shape), Invalid) if amt != t.shape[i] else t - # align a dimension, use at to specify the dimension to pad in, defaults to first def pad_align(t, dim, at=None, force=False): # align to 64 pixels when height is real, otherwise 64 bytes is sufficient @@ -1531,7 +1530,7 @@ def pad_align(t, dim, at=None, force=False): else: x, w = x.contiguous(), w.contiguous() # undo alignment hacks - if bank_conflict: x, w = x[:, :, :ix, :, :cin // 4, :], w[:, :H, :cin // 4, ...] + if bank_conflict: x, w = x[:, :, :, :, :cin // 4, :], w[:, :, :cin // 4, ...] else: x, w = x[:, :, :ix, :], w[:, :H, ...] # expand out @@ -1554,13 +1553,9 @@ def pad_align(t, dim, at=None, force=False): # the conv! ret = (x*w).cast(dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype) - if added_ox: - ret = ret.reshape(bs, oy, ox + added_ox, groups, rcout)[:, :, :ox, ...] - + ret = ret.reshape(bs, oy, ox + added_ox, groups, rcout)[:, :, :ox, :, :] # undo hack for non multiples of 4 on C.rcout - if added_output_channels: - ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels] - + if added_output_channels: ret = ret[:, :, :, :, :-added_output_channels] # NCHW output ret = ret.reshape(bs, oy, ox, groups * (rcout - added_output_channels)).permute(0,3,1,2) return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))