Make KVTC compression/decompression run entirely on GPU using PyTorch tensor ops and Triton kernels. Target: <500ms total for 1024 tokens on RTX 5090 (currently 8-10 seconds).
- Pure Python nested loops: O(components × budget × max_bits)
- For dim=64, budget=358, max_bits=16: 64 × 358 × 17 = 389K iterations in Python
- FIX: Replace with a greedy allocation. The DP is solving: minimize sum(λᵢ/4^bᵢ) subject to sum(bᵢ) ≤ B. This has a well-known greedy solution:
- Sort components by λᵢ (eigenvalue) descending
- Greedily assign bits to the component where adding 1 bit gives the biggest MSE reduction: Δ = λᵢ × (1/4^b - 1/4^(b+1)) = λᵢ × 3/(4^(b+1))
- Use a priority queue or just vectorized torch ops
- This is O(B × log(d)) instead of O(d × B × 16)
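
The greedy rule above can be sketched with a priority queue. This is a minimal plain-Python illustration, not the vectorized torch version planned for gpu_ops; it assumes eigenvalues is a plain sequence of floats and reuses the planned greedy_bit_allocation signature. The explicit sort is unnecessary here because the heap always surfaces the largest remaining gain.

```python
import heapq

def greedy_bit_allocation(eigenvalues, bit_budget, max_bits=16):
    """Assign bits one at a time to the component with the largest MSE gain."""
    bits = [0] * len(eigenvalues)
    # Max-heap via negated gains. The first bit on component i is worth
    # lam_i * 3 / 4, matching delta = lam_i * 3 / 4^(b+1) at b = 0.
    heap = [(-lam * 3.0 / 4.0, i) for i, lam in enumerate(eigenvalues) if lam > 0]
    heapq.heapify(heap)
    remaining = bit_budget
    while remaining > 0 and heap:
        neg_gain, i = heapq.heappop(heap)
        bits[i] += 1
        remaining -= 1
        if bits[i] < max_bits:
            # Each further bit on the same component is worth 4x less.
            heapq.heappush(heap, (neg_gain / 4.0, i))
    return bits
```

Each of the at most B iterations costs one heap pop and one push, which gives the O(B × log(d)) bound. Components that reach max_bits are simply not pushed back, so the budget may go partly unused when every component is capped.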
indices_list = [
uniform_quantize(pca_values[:, component], int(bit_widths[component].item()), ...)
for component in range(dim)
]
- Iterates over each dimension in Python
- FIX: Batch quantize all components at once using vectorized torch ops:
- scales and zero_points are already tensors
indices = torch.round(pca_values / scales.unsqueeze(0) + zero_points.unsqueeze(0))
- Then clamp per-component using bit_widths:
qmax = (1 << bit_widths) - 1
- Single tensor op instead of dim=64 separate calls
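
Putting the round and the per-component clamp together, a batched quantizer might look like the sketch below. It assumes pca_values has shape [N, dim] and that bit_widths, scales and zero_points are 1-D tensors of length dim; the int64 output dtype is an arbitrary choice for the example.

```python
import torch

def batch_quantize(pca_values, bit_widths, scales, zero_points):
    """Quantize every component of an [N, dim] tensor in one pass."""
    # qmax is 2^b - 1 per component; a 0-bit component gets qmax = 0.
    qmax = ((1 << bit_widths.long()) - 1).to(pca_values.dtype)
    indices = torch.round(pca_values / scales.unsqueeze(0) + zero_points.unsqueeze(0))
    # Per-component clamp: the [dim] upper bound broadcasts over the N rows.
    indices = torch.minimum(indices.clamp(min=0), qmax.unsqueeze(0))
    return indices.to(torch.int64)
```

Because every tensor stays on the device of pca_values, the same function runs on CPU or GPU without changes.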
for idx, bits in enumerate(bit_widths.tolist()):
- FIX: Vectorize entirely:
qmax = (1 << bit_widths.long()).float() - 1  # shift on an integer tensor, then cast; handle 0-bit with mask
scales = (maxs - mins) / qmax.clamp(min=1)
zero_points = -mins / scales
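
A loop-free version of the parameter computation could be sketched as follows. How 0-bit and constant components should be treated is not specified above, so the handling here is an assumption: both get a scale of 1 to avoid division by zero, and 0-bit components get a zero point of 0 so they always quantize to index 0.

```python
import torch

def vectorized_quant_params(pca_values, bit_widths):
    """Per-component scale and zero point for an [N, dim] tensor, no Python loop."""
    mins = pca_values.min(dim=0).values
    maxs = pca_values.max(dim=0).values
    # Shift on an integer tensor, then cast to the value dtype.
    qmax = ((1 << bit_widths.long()) - 1).to(pca_values.dtype)
    scales = (maxs - mins) / qmax.clamp(min=1)
    # Assumption: constant components (max == min) fall back to scale 1.
    scales = torch.where(scales > 0, scales, torch.ones_like(scales))
    zero_points = -mins / scales
    # Assumption: 0-bit components are masked to a zero point of 0.
    active = bit_widths > 0
    zero_points = torch.where(active, zero_points, torch.zeros_like(zero_points))
    return scales, zero_points
```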
- Python loop over every value
- FIX: For fixed bit-widths per component, use torch bit-shift ops to pack. For variable widths, a Triton kernel.
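
For the fixed-width case, packing can be done with shift and mask ops alone. The sketch below is a simplified illustration, not the planned fast_pack_bits: pack_fixed_width and unpack_fixed_width are hypothetical helper names, and they assume a 1-D tensor of non-negative integers that all share one bit width.

```python
import torch

def pack_fixed_width(indices, bits):
    """Pack 1-D non-negative ints, `bits` bits each, into uint8 bytes (MSB first)."""
    shifts = torch.arange(bits - 1, -1, -1, device=indices.device)
    # [n, bits] matrix of 0/1 values, most significant bit first.
    bit_matrix = (indices.long().unsqueeze(1) >> shifts) & 1
    flat = bit_matrix.reshape(-1)
    # Pad the bit stream with zeros up to a whole number of bytes.
    pad = (-flat.numel()) % 8
    flat = torch.nn.functional.pad(flat, (0, pad))
    weights = 2 ** torch.arange(7, -1, -1, device=indices.device)
    return (flat.reshape(-1, 8) * weights).sum(dim=1).to(torch.uint8)

def unpack_fixed_width(packed, bits, count):
    """Inverse of pack_fixed_width for `count` values of `bits` bits each."""
    shifts = torch.arange(7, -1, -1, device=packed.device)
    flat = ((packed.long().unsqueeze(1) >> shifts) & 1).reshape(-1)[: count * bits]
    weights = 2 ** torch.arange(bits - 1, -1, -1, device=packed.device)
    return (flat.reshape(count, bits) * weights).sum(dim=1)
```

Expanding to a bit matrix costs extra memory (one int64 per bit) but keeps the bit stream dense for any width; components that share a bit width could be grouped and packed with one call per group.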
- Already a matrix multiply:
pca_values = centered @ eigenvectors
- Currently forces .cpu(); keep on GPU through the pipeline
- Only move to CPU for bit packing (final step)
Same pattern as compress: vectorize the dequantize step.
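
A batched dequantize is the mirror image of the quantize step. The sketch below assumes the affine mapping used for quantization (index = round(x / scale + zero_point)) and reuses the planned batch_dequantize signature; zeroing 0-bit components is an assumption about how they should be reconstructed.

```python
import torch

def batch_dequantize(indices, bit_widths, scales, zero_points):
    """Map [N, dim] integer indices back to approximate PCA values in one op."""
    values = (indices.to(scales.dtype) - zero_points.unsqueeze(0)) * scales.unsqueeze(0)
    # Assumption: components that received 0 bits reconstruct to exactly 0.
    active = (bit_widths > 0).unsqueeze(0)
    return torch.where(active, values, torch.zeros_like(values))
```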
Contains all GPU-optimized operations:
- greedy_bit_allocation(eigenvalues, bit_budget, max_bits=16): vectorized greedy
- batch_quantize(pca_values, bit_widths, scales, zero_points): single tensor op
- batch_dequantize(indices, bit_widths, scales, zero_points): single tensor op
- vectorized_quant_params(pca_values, bit_widths): no loops
- fast_pack_bits(indices, bit_widths): torch-based bit packing (CPU fallback ok)
- Import gpu_ops, use them instead of loop-based functions
- Keep tensors on GPU through PCA transform → quantize → only CPU for bit packing
- Add a device parameter to KVTCCompressor
- Don't delete — keep as CPU reference implementation
- Add per-stage timing breakdown
- Print speedup vs old pipeline
- Must pass all 38 existing tests (src/test_kvtc.py)
- Must produce identical compression ratios and cosine similarities (within float tolerance)
- Pure PyTorch + Triton — no custom CUDA C++ kernels
- Triton is optional (graceful fallback to torch ops if triton not available)
- Python 3.11, PyTorch 2.10, CUDA 12.8, RTX 5090
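
The optional-Triton constraint is commonly handled with a guarded import. In this minimal sketch, HAS_TRITON and select_backend are hypothetical names chosen for illustration, not existing identifiers in the codebase.

```python
try:
    import triton  # noqa: F401  (only probing for availability)
    HAS_TRITON = True
except ImportError:
    HAS_TRITON = False

def select_backend(use_triton=True):
    """Return "triton" when requested and installed, otherwise "torch"."""
    return "triton" if (use_triton and HAS_TRITON) else "torch"
```

Call sites can branch once on select_backend() and fall through to the pure torch path, so the package still imports and runs without Triton installed.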
- Compression time < 500ms for 512 tokens on RTX 5090
- Same quality metrics (cosine sim within 0.001 of current)
- All 38 unit tests still pass
- Per-stage timing breakdown in benchmark output
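
One way to produce the per-stage breakdown is a small timer built on a context manager. StageTimer is a hypothetical helper, sketched under the assumption that the benchmark passes torch.cuda.synchronize as the sync callable when running on GPU, since CUDA kernels launch asynchronously and wall-clock timing is otherwise misleading.

```python
import time
from contextlib import contextmanager

class StageTimer:
    """Accumulates wall-clock milliseconds per named pipeline stage."""

    def __init__(self, sync=None):
        self.sync = sync  # optional callable, e.g. torch.cuda.synchronize
        self.stages = {}

    @contextmanager
    def stage(self, name):
        if self.sync:
            self.sync()  # drain pending GPU work before starting the clock
        start = time.perf_counter()
        try:
            yield
        finally:
            if self.sync:
                self.sync()  # wait for this stage's kernels to finish
            elapsed_ms = (time.perf_counter() - start) * 1000.0
            self.stages[name] = self.stages.get(name, 0.0) + elapsed_ms

    def report(self, baseline_ms=None):
        total = sum(self.stages.values())
        lines = [f"{name:<12s} {ms:9.3f} ms" for name, ms in self.stages.items()]
        lines.append(f"{'total':<12s} {total:9.3f} ms")
        if baseline_ms is not None and total > 0:
            lines.append(f"speedup vs old pipeline: {baseline_ms / total:.1f}x")
        return "\n".join(lines)
```

Usage would look like `with timer.stage("quantize"): ...` around each step, followed by `print(timer.report(baseline_ms=old_total_ms))`, where old_total_ms is the measured time of the loop-based pipeline.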