tinygrad/tensor.py (42 changes: 22 additions, 20 deletions)
@@ -1411,20 +1411,19 @@ def scaled_dot_product_attention(self, key:Tensor, value:Tensor, attn_mask:Tenso
   def qr(self) -> tuple[Tensor, Tensor]:
     assert self.ndim > 1, f"expected two or more dimensions, got {self.ndim}"
     b_shape, m, n = self.shape[:-2], int(self.shape[-2]), int(self.shape[-1])
-    R = self.clone()
-    Q = Tensor.eye(m, dtype=self.dtype, device=self.device).expand(b_shape + (m, m))
+    R, Q = self, Tensor.eye(m, dtype=self.dtype, device=self.device).expand(b_shape + (m, m))
+    idx = Tensor.arange(m, device=self.device)
     for i in range(min(m, n)):
-      x = R[..., i:m, i]
-      norm = x.square().sum(-1).sqrt()
-      mask = norm != 0
-      s = (x[..., 0] != 0).where(-x[..., 0].sign(), -1)
-      u1 = x[..., 0] - s * norm
-      w = x.unsqueeze(-1) / mask.where(u1, 1)[..., None, None]
-      w[..., 0, 0] = 1
-      tau = (-s * u1 / mask.where(norm, 1))[..., None, None]
-      tau = mask[..., None, None].where(tau, 0)
-      R[..., i:m, :] = R[..., i:m, :] - (w * tau) @ (w.transpose(-2, -1) @ R[..., i:m, :])
-      Q[..., :, i:m] = Q[..., :, i:m] - (Q[..., :, i:m] @ w) @ (tau * w).transpose(-2, -1)
+      # full-length Householder reflector v with zeros above row i; w = tau*v is the rank-1 update factor
+      at_i, x = idx == i, (idx >= i).where(R[..., :, i], 0)
+      norm = x.square().sum(-1, keepdim=True).sqrt()
+      x0 = at_i.where(x, 0).sum(-1, keepdim=True)
+      sgn, active = (x0 != 0).where(x0.sign(), 1), norm != 0
+      u0 = x0 + sgn * norm
+      v = (at_i.where(u0, x) / active.where(u0, 1)).unsqueeze(-1)
+      w = active.where(sgn * u0 / active.where(norm, 1), 0).unsqueeze(-1) * v
+      R = R - w @ (v.transpose(-2, -1) @ R)
+      Q = Q - (Q @ v) @ w.transpose(-2, -1)
     return Q, R
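
The rewrite turns each Householder step into full-matrix rank-1 updates gated by the `idx` masks, eliminating the in-place writes (`w[..., 0, 0] = 1` and the sliced assignments to `R` and `Q`) of the old version. A minimal NumPy sketch of the same masked reflector, assuming a single non-batched matrix (`householder_qr` is an illustrative name, not tinygrad API):

import numpy as np

def householder_qr(A):
  # single-matrix version of the masked Householder loop above
  m, n = A.shape
  R, Q = A.astype(float), np.eye(m)
  idx = np.arange(m)
  for i in range(min(m, n)):
    x = np.where(idx >= i, R[:, i], 0.0)          # column i, zeroed above the diagonal
    norm = np.sqrt((x * x).sum())
    if norm == 0: continue                        # plays the role of the `active` mask
    sgn = np.sign(x[i]) if x[i] != 0 else 1.0
    u0 = x[i] + sgn * norm
    v = np.where(idx == i, 1.0, x / u0)[:, None]  # reflector scaled so v[i] == 1
    w = (sgn * u0 / norm) * v                     # w = tau * v
    R = R - w @ (v.T @ R)                         # H @ R with H = I - tau v v^T
    Q = Q - (Q @ v) @ w.T                         # accumulate Q @ H
  return Q, R

A = np.random.randn(5, 3)
Q, R = householder_qr(A)
assert np.allclose(Q @ R, A) and np.allclose(np.tril(R, -1), 0)

The `norm == 0` early-out mirrors the `active` mask: a zero column leaves both `R` and `Q` untouched.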

@@ -1437,10 +1436,11 @@ def svd(self, full_matrices = True) -> tuple[Tensor, Tensor, Tensor]:
     # TODO: codegen infinite loop without contiguous
     U = R[..., :num, :num].contiguous()
     V = Tensor.eye(num, dtype=self.dtype, device=self.device).expand(b_shape + (num, num)).contiguous()
-    #prepare round robin pairing
-    permute, inverse_permute = Tensor.arange(0, num, dtype=dtypes.int, device=self.device), Tensor.zeros(num, dtype=dtypes.int, device=self.device)
-    permute[num//2:num] = permute[num//2:num].flip(0)
-    inverse_permute[permute] = Tensor.arange(num, dtype=dtypes.int, device=self.device)
+    #prepare round robin pairing: identity on first half, reversed on second half
+    permute = Tensor.arange(num//2, dtype=dtypes.int, device=self.device).cat(
+      Tensor.arange(num//2, num, dtype=dtypes.int, device=self.device).flip(0))
+    inverse_permute = Tensor.zeros(num, dtype=dtypes.int, device=self.device).scatter(
+      0, permute, Tensor.arange(num, dtype=dtypes.int, device=self.device))
     def one_round_jacobi(U, V, permute, inverse_permute):
       #pair all the columns
       V_permuted, runoff_V = (V[..., permute].split(num - 1, -1)) if num % 2 == 1 else (V[..., permute], None)
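
Building `permute` by concatenating two `arange`s (second half flipped) and recovering `inverse_permute` with `scatter` replaces the two sliced/fancy-index assignments. For concreteness, the result for num = 8, with plain Python lists standing in for the Tensor ops (illustrative sketch only):

num = 8
permute = list(range(num // 2)) + list(range(num // 2, num))[::-1]  # [0, 1, 2, 3, 7, 6, 5, 4]
inverse_permute = [0] * num
for pos, col in enumerate(permute): inverse_permute[col] = pos      # what the scatter computes
assert all(permute[inverse_permute[c]] == c for c in range(num))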
@@ -1473,9 +1473,11 @@ def one_round_jacobi(U, V, permute, inverse_permute):
     new_indices = indices.unsqueeze(-2).expand(b_shape + (num, num))
     U = U.gather(-1, new_indices) / (S != 0).where(S, 1).unsqueeze(-2)
     V = V.gather(-1, new_indices)
-    padded_u = Tensor.eye(q_num, dtype=U.dtype, device=U.device).expand(b_shape + (q_num, q_num))
-    padded_u[..., 0:num, 0:num] = U
-    U = Q @ padded_u
+    # place U into the top-left num×num block of a q_num×q_num identity matrix
+    pad_arg = (None,) * len(b_shape) + ((0, q_num - num), (0, q_num - num))
+    eye_q = Tensor.eye(q_num, dtype=U.dtype, device=U.device).expand(b_shape + (q_num, q_num))
+    eye_n = Tensor.eye(num, dtype=U.dtype, device=U.device).expand(b_shape + (num, num)).pad(pad_arg)
+    U = Q @ (U.pad(pad_arg) + eye_q - eye_n)
     if not full_matrices: U = U[..., 0:num]
     return (U, S, V.transpose(-2, -1)) if m >= n else (V, S, U.transpose(-2, -1))
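
The pad trick avoids assigning `U` into a block of an identity matrix: zero-padding `U` to q_num×q_num, adding the q_num identity, and subtracting the zero-padded num identity (which cancels the overlap in the top-left block) reproduces the old `padded_u` exactly. A quick NumPy check, ignoring batch dimensions (illustrative, not tinygrad code):

import numpy as np
num, q_num = 3, 5
U = np.random.randn(num, num)
pad = ((0, q_num - num), (0, q_num - num))          # zero-pad bottom and right
built = np.pad(U, pad) + np.eye(q_num) - np.pad(np.eye(num), pad)
expected = np.eye(q_num); expected[:num, :num] = U  # the old __setitem__ construction
assert np.allclose(built, expected)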
