From 12c669d27d0677e2d1456a049a211af1697090cc Mon Sep 17 00:00:00 2001 From: Chen-Yu Yang Date: Wed, 13 May 2026 16:46:48 -0700 Subject: [PATCH] minor hash cleanups same kernels --- tinygrad/tensor.py | 34 +++++++++++++--------------------- 1 file changed, 13 insertions(+), 21 deletions(-) diff --git a/tinygrad/tensor.py b/tinygrad/tensor.py index a25c4afc88b01..c5f2bcfb6346c 100644 --- a/tinygrad/tensor.py +++ b/tinygrad/tensor.py @@ -1146,12 +1146,13 @@ def ctensor(l: Sequence[PyConst], dtype: DType = dtypes.uint64): 0x8000000000008002, 0x8000000000000080, 0x800a, 0x800000008000000a, 0x8000000080008081, 0x8000000000008080, 0x80000001, 0x8000000080008008)] rate, dsbyte = {"sha3_224": (144, 6), "sha3_256": (136, 6), "shake_128": (168, 31)}[cfg] if isinstance(cfg, str) else cfg - data, data_pad = self.bitcast(dtypes.uint8).reshape(prod(self.shape[:-1]), self.shape[-1]), rate - (self.shape[-1] * self.dtype.itemsize % rate) + data = self.bitcast(dtypes.uint8).reshape(prod(self.shape[:-1]), self.shape[-1]) + data_pad = rate - data.shape[-1] % rate # pad batches then pad blocks - data = data.pad((None, (0, data_pad))).reshape(bs := data.shape[0], -1, rate).pad((None, None, (0, 200 - rate))) + data = data.pad((None, (0, data_pad))).reshape(bs := data.shape[0], -1, rate).pad_to(None, None, 200) # create pad mask - lbe = prod(data.shape[1:]) + rate - data_pad - 200 + lbe = (data.shape[1] - 1) * 200 + rate - data_pad if data_pad == 1: mb = [(lbe, 0), (1, dsbyte ^ 0x80), (200 - rate, 0)] else: mb = [(lbe, 0), (1, dsbyte), (data_pad - 2, 0), (1, 0x80), (200 - rate, 0)] pad_mask = Tensor.cat(*(Tensor(v, dtype=dtypes.uint8, device=data.device).expand(l) for l, v in mb if l > 0)).unsqueeze(0) @@ -1160,7 +1161,7 @@ def ctensor(l: Sequence[PyConst], dtype: DType = dtypes.uint64): state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64) for k in range(int(data.shape[1])): - state = state ^ data.shrink((None, (k, k+1), None)).squeeze(1) + state = state ^ data[:, k] for i in range(24): # f1600 # θ step p = state.reshape(bs, 5, 5).transpose(2, 1) @@ -1179,11 +1180,7 @@ def _hash_1mb(self) -> Tensor: assert self.dtype == dtypes.uint8, "only support uint8 tensors for hashing" assert self.ndim == 2, "only support batched 1d tensors" assert self.shape[1] == 1024 * 1024, "only support messages of 1mb" - - blocks = self.shape[0] * self.shape[1] // 4096 - data = self.reshape(blocks, 4096) - block_hashes = data.keccak("shake_128").reshape(self.shape[0], 4096) - return block_hashes.keccak("shake_128").reshape(self.shape[0], 16) + return self.reshape(-1, 4096).keccak("shake_128").reshape(self.shape[0], -1).keccak("shake_128") def hash(self) -> Tensor: """ @@ -1193,19 +1190,14 @@ def hash(self) -> Tensor: print(t.data().hex()) ``` """ - data = self.flatten().bitcast(dtypes.uint8) - if (tsize := data.shape[0]) % 2**20 != 0: data = data.pad((0, 2**20 - tsize % 2**20)) - base_chunks = ceildiv(data.shape[0], 2**20) - tree_depth = math.ceil(math.log(base_chunks, 65536)) if base_chunks > 1 else 0 - - level_chunks = base_chunks - for _ in range(tree_depth + 1): - data = data.reshape(level_chunks, 2**20)._hash_1mb().flatten() - if (tsize := data.shape[0]) % 2**20 != 0: data = data.pad((0, 2**20 - tsize % 2**20)) - level_chunks = ceildiv(data.shape[0], 2**20) - - return data[:16] + n = data.shape[0] + assert isinstance(n, int), "hash requires concrete shape" + chunks = ceildiv(n, 2**20) + while chunks > 1: + data = data.pad_to(chunks * 2**20).reshape(chunks, 2**20)._hash_1mb().flatten() + chunks = ceildiv(chunks, 65536) + return data.pad_to(2**20).unsqueeze(0)._hash_1mb().flatten()[:16] # ***** processing ops *****