Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 13 additions & 21 deletions tinygrad/tensor.py
Original file line number Diff line number Diff line change
Expand Up @@ -1146,12 +1146,13 @@ def ctensor(l: Sequence[PyConst], dtype: DType = dtypes.uint64):
0x8000000000008002, 0x8000000000000080, 0x800a, 0x800000008000000a, 0x8000000080008081, 0x8000000000008080, 0x80000001, 0x8000000080008008)]

rate, dsbyte = {"sha3_224": (144, 6), "sha3_256": (136, 6), "shake_128": (168, 31)}[cfg] if isinstance(cfg, str) else cfg
data, data_pad = self.bitcast(dtypes.uint8).reshape(prod(self.shape[:-1]), self.shape[-1]), rate - (self.shape[-1] * self.dtype.itemsize % rate)
data = self.bitcast(dtypes.uint8).reshape(prod(self.shape[:-1]), self.shape[-1])
data_pad = rate - data.shape[-1] % rate
# pad batches then pad blocks
data = data.pad((None, (0, data_pad))).reshape(bs := data.shape[0], -1, rate).pad((None, None, (0, 200 - rate)))
data = data.pad((None, (0, data_pad))).reshape(bs := data.shape[0], -1, rate).pad_to(None, None, 200)

# create pad mask
lbe = prod(data.shape[1:]) + rate - data_pad - 200
lbe = (data.shape[1] - 1) * 200 + rate - data_pad
if data_pad == 1: mb = [(lbe, 0), (1, dsbyte ^ 0x80), (200 - rate, 0)]
else: mb = [(lbe, 0), (1, dsbyte), (data_pad - 2, 0), (1, 0x80), (200 - rate, 0)]
pad_mask = Tensor.cat(*(Tensor(v, dtype=dtypes.uint8, device=data.device).expand(l) for l, v in mb if l > 0)).unsqueeze(0)
Expand All @@ -1160,7 +1161,7 @@ def ctensor(l: Sequence[PyConst], dtype: DType = dtypes.uint64):

state = Tensor.zeros(bs, 25, device=self.device, dtype=dtypes.uint64)
for k in range(int(data.shape[1])):
state = state ^ data.shrink((None, (k, k+1), None)).squeeze(1)
state = state ^ data[:, k]
for i in range(24): # f1600
# θ step
p = state.reshape(bs, 5, 5).transpose(2, 1)
Expand All @@ -1179,11 +1180,7 @@ def _hash_1mb(self) -> Tensor:
assert self.dtype == dtypes.uint8, "only support uint8 tensors for hashing"
assert self.ndim == 2, "only support batched 1d tensors"
assert self.shape[1] == 1024 * 1024, "only support messages of 1mb"

blocks = self.shape[0] * self.shape[1] // 4096
data = self.reshape(blocks, 4096)
block_hashes = data.keccak("shake_128").reshape(self.shape[0], 4096)
return block_hashes.keccak("shake_128").reshape(self.shape[0], 16)
return self.reshape(-1, 4096).keccak("shake_128").reshape(self.shape[0], -1).keccak("shake_128")

def hash(self) -> Tensor:
"""
Expand All @@ -1193,19 +1190,14 @@ def hash(self) -> Tensor:
print(t.data().hex())
```
"""

data = self.flatten().bitcast(dtypes.uint8)
if (tsize := data.shape[0]) % 2**20 != 0: data = data.pad((0, 2**20 - tsize % 2**20))
base_chunks = ceildiv(data.shape[0], 2**20)
tree_depth = math.ceil(math.log(base_chunks, 65536)) if base_chunks > 1 else 0

level_chunks = base_chunks
for _ in range(tree_depth + 1):
data = data.reshape(level_chunks, 2**20)._hash_1mb().flatten()
if (tsize := data.shape[0]) % 2**20 != 0: data = data.pad((0, 2**20 - tsize % 2**20))
level_chunks = ceildiv(data.shape[0], 2**20)

return data[:16]
n = data.shape[0]
assert isinstance(n, int), "hash requires concrete shape"
chunks = ceildiv(n, 2**20)
while chunks > 1:
data = data.pad_to(chunks * 2**20).reshape(chunks, 2**20)._hash_1mb().flatten()
chunks = ceildiv(chunks, 65536)
return data.pad_to(2**20).unsqueeze(0)._hash_1mb().flatten()[:16]

# ***** processing ops *****

Expand Down
Loading