Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1477,7 +1477,7 @@ def image_dot(self, w:Tensor, dtype:DTypeLike|None=None) -> Tensor:
def image_conv2d(self, weight:Tensor, bias:Tensor|None=None, groups=1, stride=1, dilation=1, padding=0, dtype=None) -> Tensor:
dtsz = 2 if FLOAT16 else 4

(bs,_,iy,ix), (cout,cin,H,W) = self.shape, weight.shape
(bs,_,_,_), (cout,cin,H,W) = self.shape, weight.shape
assert isinstance(cin, int) and isinstance(cout, int)
x, w = self, weight.reshape(groups, (rcout := cout//groups), cin, H, W)

Expand Down Expand Up @@ -1513,7 +1513,6 @@ def is_pow2(v): return v > 0 and v & (v - 1) == 0
# NOTE(review): this is a fragment from a diff-view page scrape — leading
# indentation was stripped, so the lines below are not runnable as rendered.
# Code is left byte-identical; only comments are added.
#
# Pad dimension `i` of tensor `t` out to size `amt`; no-op when the dimension
# is already `amt`. A boolean mask (True over the original extent) is padded
# alongside `t`, and padded positions are filled with `Invalid` — a project
# sentinel whose semantics are not visible in this fragment; confirm against
# tinygrad's definition. Presumably `pad_to(shape)` pads each dim to the given
# size, with `None` meaning "leave unchanged" — TODO verify against pad_to.
def ipad(t, i, amt):
shape = (None,)*i + (amt,) + (None,)*(t.ndim-i-1)
return Tensor(True, device=t.device).expand(t.shape).pad_to(shape).where(t.pad_to(shape), Invalid) if amt != t.shape[i] else t

# align a dimension, use at to specify the dimension to pad in, defaults to first
def pad_align(t, dim, at=None, force=False):
# align to 64 pixels when height is real, otherwise 64 bytes is sufficient
Expand All @@ -1531,7 +1530,7 @@ def pad_align(t, dim, at=None, force=False):
else: x, w = x.contiguous(), w.contiguous()

# undo alignment hacks
if bank_conflict: x, w = x[:, :, :ix, :, :cin // 4, :], w[:, :H, :cin // 4, ...]
if bank_conflict: x, w = x[:, :, :, :, :cin // 4, :], w[:, :, :cin // 4, ...]
else: x, w = x[:, :, :ix, :], w[:, :H, ...]

# expand out
Expand All @@ -1554,13 +1553,9 @@ def pad_align(t, dim, at=None, force=False):
# the conv!
ret = (x*w).cast(dtypes.float32).sum((-4, -3, -2, -1), dtype=dtype)

if added_ox:
ret = ret.reshape(bs, oy, ox + added_ox, groups, rcout)[:, :, :ox, ...]

ret = ret.reshape(bs, oy, ox + added_ox, groups, rcout)[:, :, :ox, :, :]
# undo hack for non multiples of 4 on C.rcout
if added_output_channels:
ret = ret.reshape(bs, oy, ox, groups, rcout)[:, :, :, :, :-added_output_channels]

if added_output_channels: ret = ret[:, :, :, :, :-added_output_channels]
# NCHW output
ret = ret.reshape(bs, oy, ox, groups * (rcout - added_output_channels)).permute(0,3,1,2)
return ret if bias is None else ret.add(bias.reshape(1, -1, 1, 1))
Expand Down
Loading