-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathturboquant.py
More file actions
422 lines (342 loc) · 15.5 KB
/
turboquant.py
File metadata and controls
422 lines (342 loc) · 15.5 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
"""
TurboQuant: Online Vector Quantization with Near-optimal Distortion Rate
=========================================================================
A Python implementation based on the ICLR 2026 paper by Zandieh et al.
(arXiv: 2504.19874)
TurboQuant compresses high-dimensional vectors (e.g., KV cache entries) using:
Stage 1 (PolarQuant): Random rotation + Lloyd-Max scalar quantization
Stage 2 (QJL): 1-bit residual correction for unbiased inner products
Usage:
python turboquant.py # run self-test + demo
python turboquant.py --bits 3 # specify bit-width
"""
import numpy as np
from scipy.stats import beta as beta_dist
from scipy.optimize import minimize_scalar
from dataclasses import dataclass
from typing import Optional
import time
# =============================================================================
# 1. Lloyd-Max Quantizer for Beta Distribution
# =============================================================================
def build_lloyd_max_codebook(
dimension: int,
bits: int,
max_iters: int = 200,
tol: float = 1e-10,
) -> tuple[np.ndarray, np.ndarray]:
"""
Build a Lloyd-Max scalar quantizer for the Beta(d/2, d/2) distribution
that arises after random orthogonal rotation of unit-norm vectors.
After rotation, each coordinate of a unit vector follows Beta(d/2, d/2)
shifted to [-1, 1]. We quantize on [-1, 1].
Returns:
centroids: array of shape (2^bits,) — reconstruction levels
boundaries: array of shape (2^bits + 1,) — decision boundaries
"""
n_levels = 2 ** bits
alpha = dimension / 2.0
# Beta(a, a) on [0,1] → shifted to [-1, 1] via x = 2*u - 1
# PDF on [-1,1]: f(x) = beta_pdf((x+1)/2, a, a) / 2
def pdf(x):
u = (x + 1.0) / 2.0
u = np.clip(u, 1e-15, 1 - 1e-15)
return beta_dist.pdf(u, alpha, alpha) / 2.0
def cdf(x):
u = (x + 1.0) / 2.0
u = np.clip(u, 0.0, 1.0)
return beta_dist.cdf(u, alpha, alpha)
# Initialize centroids via quantiles (uniform in CDF space)
centroids = np.array([
2.0 * beta_dist.ppf((i + 0.5) / n_levels, alpha, alpha) - 1.0
for i in range(n_levels)
])
# Lloyd-Max iteration
for iteration in range(max_iters):
# Compute boundaries as midpoints between centroids
boundaries = np.empty(n_levels + 1)
boundaries[0] = -1.0
boundaries[-1] = 1.0
for i in range(1, n_levels):
boundaries[i] = (centroids[i - 1] + centroids[i]) / 2.0
# Update centroids: centroid[i] = E[X | boundaries[i] <= X < boundaries[i+1]]
new_centroids = np.empty(n_levels)
for i in range(n_levels):
lo, hi = boundaries[i], boundaries[i + 1]
# Numerical integration via quadrature on the shifted beta
n_quad = 500
xs = np.linspace(lo, hi, n_quad)
ps = pdf(xs)
mass = np.trapezoid(ps, xs)
if mass > 1e-15:
new_centroids[i] = np.trapezoid(xs * ps, xs) / mass
else:
new_centroids[i] = (lo + hi) / 2.0
shift = np.max(np.abs(new_centroids - centroids))
centroids = new_centroids
if shift < tol:
break
return centroids, boundaries
# =============================================================================
# 2. Random Orthogonal Rotation Matrix
# =============================================================================
def generate_rotation_matrix(dimension: int, seed: int = 42) -> np.ndarray:
"""Generate a random orthogonal matrix via QR decomposition."""
rng = np.random.default_rng(seed)
G = rng.standard_normal((dimension, dimension))
Q, R = np.linalg.qr(G)
# Ensure a proper rotation (det = +1 convention, though sign doesn't matter
# for quantization purposes since we apply both Pi and Pi^T)
diag_signs = np.sign(np.diag(R))
diag_signs[diag_signs == 0] = 1.0
Q = Q * diag_signs[np.newaxis, :]
return Q
# =============================================================================
# 3. TurboQuant Core
# =============================================================================
@dataclass
class TurboQuantConfig:
"""Configuration for TurboQuant compression."""
dimension: int # Vector dimension (d)
bits: int = 4 # Bit-width for PolarQuant stage (typically 2-4)
qjl_enabled: bool = True # Whether to apply QJL residual correction
qjl_dim: int = 0 # QJL projection dimension (0 = auto = dimension)
seed: int = 42 # Random seed for reproducibility
@dataclass
class CompressedVector:
"""A TurboQuant-compressed vector."""
indices: np.ndarray # Quantization indices, shape (d,), dtype uint8/uint16
qjl_signs: Optional[np.ndarray] = None # 1-bit QJL signs of residual, shape (qjl_dim,)
class TurboQuant:
"""
TurboQuant compressor/decompressor.
Two variants:
- TurboQuant_mse: Stage 1 only (PolarQuant). Minimizes MSE.
- TurboQuant_prod: Stage 1 + Stage 2 (QJL). Unbiased inner products.
"""
def __init__(self, config: TurboQuantConfig):
self.config = config
d = config.dimension
# Build Lloyd-Max codebook (one-time, offline)
print(f" Building Lloyd-Max codebook (d={d}, bits={config.bits})...", end="", flush=True)
self.centroids, self.boundaries = build_lloyd_max_codebook(d, config.bits)
print(" done.")
# Pre-compute rotation matrix
self.Pi = generate_rotation_matrix(d, seed=config.seed)
# QJL random projection matrix (if enabled)
if config.qjl_enabled:
qjl_dim = config.qjl_dim if config.qjl_dim > 0 else d
rng = np.random.default_rng(config.seed + 1)
# Rademacher random matrix (entries ±1, unscaled)
self.S = rng.choice([-1.0, 1.0], size=(qjl_dim, d))
self.qjl_dim = qjl_dim
else:
self.S = None
self.qjl_dim = 0
def _quantize_scalar(self, values: np.ndarray) -> np.ndarray:
"""Quantize each scalar to the nearest centroid using boundaries."""
indices = np.searchsorted(self.boundaries[1:-1], values).astype(np.int32)
return indices
def _dequantize_scalar(self, indices: np.ndarray) -> np.ndarray:
"""Map indices back to centroid values."""
return self.centroids[indices]
def compress(self, x: np.ndarray) -> CompressedVector:
"""
Compress a vector x ∈ R^d.
Stage 1 (PolarQuant):
- Rotate: y = Pi @ x
- Quantize each coordinate of y using Lloyd-Max quantizer
Stage 2 (QJL, optional):
- Compute residual: r = y - dequant(quant(y))
- Store sign(S @ r) as 1-bit QJL correction
"""
# Normalize (TurboQuant operates on unit vectors; store norm separately if needed)
# For KV cache, vectors are typically not unit-norm, so we'd store the norm.
# Here we keep it simple and work with the raw vector.
# Stage 1: Rotate and quantize
y = self.Pi @ x
indices = self._quantize_scalar(y)
# Stage 2: QJL on residual
qjl_signs = None
if self.config.qjl_enabled and self.S is not None:
y_hat = self._dequantize_scalar(indices)
residual = y - y_hat
projected = self.S @ residual
qjl_signs = (projected >= 0).astype(np.int8) # 1-bit signs
return CompressedVector(indices=indices, qjl_signs=qjl_signs)
def decompress(self, compressed: CompressedVector) -> np.ndarray:
"""
Decompress (Stage 1 only — returns PolarQuant reconstruction).
For inner product estimation, use `inner_product` instead.
"""
y_hat = self._dequantize_scalar(compressed.indices)
x_hat = self.Pi.T @ y_hat
return x_hat
def inner_product(
self,
query: np.ndarray,
compressed_key: CompressedVector,
) -> float:
"""
Estimate <query, key> using TurboQuant_prod (unbiased).
If QJL is enabled, uses the two-stage estimator:
<q, k> ≈ <q, k_hat> + QJL_correction
where k_hat is the PolarQuant reconstruction and the correction
uses the 1-bit QJL signs of the residual.
"""
# PolarQuant reconstruction
y_hat = self._dequantize_scalar(compressed_key.indices)
k_hat = self.Pi.T @ y_hat
base_ip = np.dot(query, k_hat)
if not self.config.qjl_enabled or compressed_key.qjl_signs is None:
return base_ip
# QJL correction: estimate <q_rotated, residual> from signs
# From QJL theory (AAAI 2025):
# For iid Rademacher vectors r_1,...,r_m ∈ {±1}^d:
# <a, b> ≈ (sqrt(π/2) / m) * Σ_i sign(<r_i, a>) * |<r_i, b>|
# Here a = residual (compressed to signs), b = q_rotated (full precision)
q_rotated = self.Pi @ query
signs_pm1 = 2.0 * compressed_key.qjl_signs.astype(np.float64) - 1.0
m = self.qjl_dim
sq = self.S @ q_rotated # shape (m,): r_i^T q for each i
correction = (np.sqrt(np.pi / 2.0) / m) * np.dot(np.abs(sq), signs_pm1)
return base_ip + correction
def compression_ratio(self) -> float:
"""Compute the compression ratio (original bits / compressed bits)."""
original_bits = self.config.dimension * 32 # float32
compressed_bits = self.config.dimension * self.config.bits
if self.config.qjl_enabled:
compressed_bits += self.qjl_dim * 1 # 1 bit per QJL dimension
return original_bits / compressed_bits
def compressed_size_bytes(self) -> int:
"""Size of one compressed vector in bytes."""
# PolarQuant indices
index_bits = self.config.dimension * self.config.bits
# QJL signs
qjl_bits = self.qjl_dim if self.config.qjl_enabled else 0
return (index_bits + qjl_bits + 7) // 8 # round up to bytes
# =============================================================================
# 4. Batch Operations (for KV cache simulation)
# =============================================================================
class TurboQuantKVCache:
"""
Simulates a KV cache compressed with TurboQuant.
Stores compressed key and value vectors and supports attention computation.
"""
def __init__(self, config: TurboQuantConfig):
self.tq = TurboQuant(config)
self.compressed_keys: list[CompressedVector] = []
self.compressed_values: list[CompressedVector] = []
def append(self, key: np.ndarray, value: np.ndarray):
"""Add a new KV pair to the cache."""
self.compressed_keys.append(self.tq.compress(key))
self.compressed_values.append(self.tq.compress(value))
def attention_scores(self, query: np.ndarray) -> np.ndarray:
"""Compute attention logits <query, key_i> for all cached keys."""
scores = np.array([
self.tq.inner_product(query, ck) for ck in self.compressed_keys
])
return scores
def __len__(self):
return len(self.compressed_keys)
# =============================================================================
# 5. Self-Test & Demo
# =============================================================================
def run_self_test():
print("=" * 70)
print("TurboQuant — Self-Test & Demo")
print("=" * 70)
d = 256
n_vectors = 64
rng = np.random.default_rng(0)
# Generate random unit vectors
X = rng.standard_normal((n_vectors, d))
X = X / np.linalg.norm(X, axis=1, keepdims=True)
for bits in [2, 3, 4]:
print(f"\n{'─' * 60}")
print(f" TurboQuant d={d} bits={bits} n={n_vectors}")
print(f"{'─' * 60}")
# --- TurboQuant_mse (Stage 1 only) ---
config_mse = TurboQuantConfig(dimension=d, bits=bits, qjl_enabled=False)
tq_mse = TurboQuant(config_mse)
mse_list = []
cos_list = []
t0 = time.perf_counter()
for i in range(n_vectors):
c = tq_mse.compress(X[i])
x_hat = tq_mse.decompress(c)
mse_list.append(np.mean((X[i] - x_hat) ** 2))
cos_sim = np.dot(X[i], x_hat) / (np.linalg.norm(X[i]) * np.linalg.norm(x_hat) + 1e-15)
cos_list.append(cos_sim)
elapsed_mse = time.perf_counter() - t0
orig_bytes = n_vectors * d * 4 # float32
comp_bytes = n_vectors * tq_mse.compressed_size_bytes()
print(f" [MSE variant]")
print(f" MSE: {np.mean(mse_list):.6f}")
print(f" Mean cosine sim: {np.mean(cos_list):.4f}")
print(f" Size: {orig_bytes:,} → {comp_bytes:,} bytes "
f"({orig_bytes / comp_bytes:.1f}x)")
print(f" Time: {elapsed_mse * 1000:.1f} ms")
# --- TurboQuant_prod (Stage 1 + QJL) ---
config_prod = TurboQuantConfig(dimension=d, bits=bits, qjl_enabled=True)
tq_prod = TurboQuant(config_prod)
compressed_vecs = []
t0 = time.perf_counter()
for i in range(n_vectors):
compressed_vecs.append(tq_prod.compress(X[i]))
elapsed_comp = time.perf_counter() - t0
# Measure inner product accuracy on random query-key pairs
n_pairs = 500
true_ips = []
est_ips = []
for _ in range(n_pairs):
qi = rng.integers(0, n_vectors)
ki = rng.integers(0, n_vectors)
true_ip = np.dot(X[qi], X[ki])
est_ip = tq_prod.inner_product(X[qi], compressed_vecs[ki])
true_ips.append(true_ip)
est_ips.append(est_ip)
true_ips = np.array(true_ips)
est_ips = np.array(est_ips)
ip_rmse = np.sqrt(np.mean((est_ips - true_ips) ** 2))
ip_corr = np.corrcoef(true_ips, est_ips)[0, 1]
comp_bytes_prod = n_vectors * tq_prod.compressed_size_bytes()
print(f" [Prod variant — with QJL]")
print(f" IP RMSE: {ip_rmse:.6f}")
print(f" IP correlation: {ip_corr:.4f}")
print(f" Size: {orig_bytes:,} → {comp_bytes_prod:,} bytes "
f"({orig_bytes / comp_bytes_prod:.1f}x)")
print(f" Compress time: {elapsed_comp * 1000:.1f} ms")
# --- KV Cache simulation ---
print(f"\n{'=' * 70}")
print("KV Cache Simulation")
print(f"{'=' * 70}")
d_kv = 128
seq_len = 32
config_kv = TurboQuantConfig(dimension=d_kv, bits=3, qjl_enabled=True)
cache = TurboQuantKVCache(config_kv)
keys = rng.standard_normal((seq_len, d_kv))
keys = keys / np.linalg.norm(keys, axis=1, keepdims=True)
values = rng.standard_normal((seq_len, d_kv))
values = values / np.linalg.norm(values, axis=1, keepdims=True)
for t in range(seq_len):
cache.append(keys[t], values[t])
query = rng.standard_normal(d_kv)
query = query / np.linalg.norm(query)
scores = cache.attention_scores(query)
true_scores = keys @ query
score_corr = np.corrcoef(scores, true_scores)[0, 1]
print(f" Sequence length: {seq_len}")
print(f" Head dim: {d_kv}")
print(f" Bits per coordinate: 3 + 1 (QJL)")
print(f" Attention score corr: {score_corr:.4f}")
print(f" Compression ratio: {cache.tq.compression_ratio():.1f}x")
uncompressed_kv_bytes = 2 * seq_len * d_kv * 4 # keys + values, float32
compressed_kv_bytes = 2 * seq_len * cache.tq.compressed_size_bytes()
print(f" KV memory: {uncompressed_kv_bytes:,} → {compressed_kv_bytes:,} bytes "
f"({uncompressed_kv_bytes / compressed_kv_bytes:.1f}x)")
print(f"\n{'=' * 70}")
print("All tests passed ✓")
print(f"{'=' * 70}")
if __name__ == "__main__":
run_self_test()