From 2cead791fd4762ca4ada7e94a4eeae63bb9f364f Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Tue, 12 May 2026 15:57:05 -0700 Subject: [PATCH] functional qr and svd no clone and setitem, will move to mixin next. slightly faster but still quite slow --- tinygrad/tensor.py | 42 ++++++++++++++++++++++-------------------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index ff98baa2e10ee..fb455c7fb0dd2 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1411,20 +1411,19 @@ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tenso def qr(self) -> tuple[Tensor, Tensor]: assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}" b_shape, m, n = self.shape[:-2], int(self.shape[-2]), int(self.shape[-1]) - R = self.clone() - Q = Tensor.eye(m, dtype=self.dtype, device=self.device).expand(b_shape + (m, m)) + R, Q = self, Tensor.eye(m, dtype=self.dtype, device=self.device).expand(b_shape + (m, m)) + idx = Tensor.arange(m, device=self.device) for i in range(min(m, n)): - x = R[..., i:m, i] - norm = x.square().sum(-1).sqrt() - mask = norm != 0 - s = (x[..., 0] != 0).where(-x[..., 0].sign(), -1) - u1 = x[..., 0] - s * norm - w = x.unsqueeze(-1) / mask.where(u1, 1)[..., None, None] - w[..., 0, 0] = 1 - tau = (-s * u1 / mask.where(norm, 1))[..., None, None] - tau = mask[..., None, None].where(tau, 0) - R[..., i:m, :] = R[..., i:m, :] - (w * tau) @ (w.transpose(-2, -1) @ R[..., i:m, :]) - Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau * w).transpose(-2, -1) + # full-length Householder reflector v with zeros above row i; w = tau*v is the rank-1 update factor + at_i, x = idx == i, (idx >= i).where(R[..., :, i], 0) + norm = x.square().sum(-1, keepdim=True).sqrt() + x0 = at_i.where(x, 0).sum(-1, keepdim=True) + sgn, active = (x0 != 0).where(x0.sign(), 1), norm != 0 + u0 = x0 + sgn * norm + v = (at_i.where(u0, x) / active.where(u0, 1)).unsqueeze(-1) + w = active.where(sgn * u0 / active.where(norm, 1), 0).unsqueeze(-1) * v + R = R - w @ (v.transpose(-2, -1) @ R) + Q = Q - (Q @ v) @ w.transpose(-2, -1) return Q, R def svd(self, full_matrices = True) -> tuple[Tensor, Tensor, Tensor]: @@ -1437,10 +1436,11 @@ def svd(self, full_matrices = True) -> tuple[Tensor, Tensor, Tensor]: # TODO: codegen infinite loop without contiguous U = R[..., :num, :num].contiguous() V = Tensor.eye(num, dtype=self.dtype, device=self.device).expand(b_shape + (num, num)).contiguous() - #prepare round robin pairing - permute, inverse_permute = Tensor.arange(0, num, dtype=dtypes.int, device=self.device), Tensor.zeros(num, dtype=dtypes.int, device=self.device) - permute[num//2:num] = permute[num//2:num].flip(0) - inverse_permute[permute] = Tensor.arange(num, dtype=dtypes.int, device=self.device) + #prepare round robin pairing: identity on first half, reversed on second half + permute = Tensor.arange(num//2, dtype=dtypes.int, device=self.device).cat( + Tensor.arange(num//2, num, dtype=dtypes.int, device=self.device).flip(0)) + inverse_permute = Tensor.zeros(num, dtype=dtypes.int, device=self.device).scatter( + 0, permute, Tensor.arange(num, dtype=dtypes.int, device=self.device)) def one_round_jacobi(U, V, permute, inverse_permute): #pair all the columns V_permuted, runoff_V = (V[..., permute].split(num - 1, -1)) if num % 2 == 1 else (V[..., permute], None) @@ -1473,9 +1473,11 @@ def one_round_jacobi(U, V, permute, inverse_permute): new_indices = indices.unsqueeze(-2).expand(b_shape + (num, num)) U = U.gather(-1, new_indices) / (S != 0).where(S, 1).unsqueeze(-2) V = V.gather(-1, new_indices) - padded_u = Tensor.eye(q_num, dtype=U.dtype, device=U.device).expand(b_shape + (q_num, q_num)) - padded_u[..., 0:num, 0:num] = U - U = Q @ padded_u + # place U into the top-left num×num block of a q_num×q_num identity matrix + pad_arg = (None,) * len(b_shape) + ((0, q_num - num), (0, q_num - num)) + eye_q = Tensor.eye(q_num, dtype=U.dtype, device=U.device).expand(b_shape + (q_num, q_num)) + eye_n = Tensor.eye(num, dtype=U.dtype, device=U.device).expand(b_shape + (num, num)).pad(pad_arg) + U = Q @ (U.pad(pad_arg) + eye_q - eye_n) if not full_matrices: U = U[..., 0:num] return (U, S, V.transpose(-2, -1)) if m >= n else (V, S, U.transpose(-2, -1))