Add optional sparse covariance backend for matrix memory updates#117
Add optional sparse covariance backend for matrix memory updates#117simeon-kepp wants to merge 2 commits into
Conversation
… updates Introduce Sparse Ternary Covariance (STC) backend with optional write-skip acceleration. Includes ternary quantizer with STE, ternary gate module, and benchmarking suite.
|
All contributors have signed the CLA ✍️ ✅ |
|
I have read the CLA Document and I hereby sign the CLA |
|
Did you train any models with it? |
|
Not with xLSTM specifically, but we've validated ternary-thresholded sparse updates in our own training runs, a 17L MoE we're actively training uses the same core idea (adaptive {-1, 0, +1} quantization to skip low-magnitude outer-product writes). The skip pattern and fill rates we observe there are what motivated this PR. Happy to share training curves or benchmark numbers if useful for review. |
|
Following up on the offer to share numbers. The closest analogue in our training data is albert. — an 18L ternary MoE (12 experts, Top-3 routing, 256H, 32k vocab) currently at epoch 2572. The STE-trained ternary weights do produce real sparsity that scales with training depth: Layer-wise zero-weight fraction (TELE telemetry, ep2572):
This is weight-level sparsity after ~2500 epochs. The gradient flow through the STE zero region stays active (we track per-layer gradient norms and the zero-region pass-through is essential — without it, early layers freeze entirely). Zero-skip speedup (hardware-verified, x86 AVX2, i7-4800MQ):
Crossover point: empirically ~10% — below that, branch misprediction overhead on the skip check costs more than it saves (12–18% branch miss rates measured). The STC backend would want the EMA threshold tuned to stay above that floor in practice. For MoE specifically: at 75% routing sparsity (9/12 experts skipped per step), we measure 3.97× end-to-end MLP throughput with output divergence < 1e-4. The mLSTM outer-product skip is a different primitive but the same hardware phenomenon. Full benchmark suite (reproducible, §8–§11): https://github.com/eriirfos-eng/ternary-intelligence-stack/blob/main/ternlang-root/BENCHMARKS.md Happy to answer questions about the STE behaviour in the zero region specifically — that's where most of the practical tuning lives. |
What this adds
An optional sparse memory update path for the mLSTM cell, gated behind a config flag. Default behavior is unchanged.
The mLSTM cell's matrix memory update is compute-heavy at inference time: every step writes
C_new = fg * C_prev + ig * (k ⊗ v), regardless of how informative the current token is. In practice, a large fraction of key and value projections are near-zero — especially in later layers — meaning most of those outer-product writes are wasted.This PR adds a Sparse Ternary Covariance (STC) backend that skips low-magnitude writes by adaptively quantizing
kandvto{−1, 0, +1}before the outer product. Zero entries in the quantized vectors produce zero-contribution columns/rows in the update, which are skipped entirely. A ternary forget gate replaces the sigmoid gate when the STC backend is active, turning exact preserve (gate=0) and active inhibition (gate=−1) into first-class operations.Changes
New modules
xlstm/modules/ternary_quantizer.py— adaptive quantizer with Straight-Through Estimator for gradient flow through the zero regionxlstm/modules/ternary_gate.py— ternary forget gate logic ({−1, 0, +1})New kernels
xlstm/kernels/stc_sparse_update.py— PyTorch reference implementationxlstm/kernels/stc_sparse_update.cpp/.cu— C++/CUDA stubs ready for a hardware-accelerated write-skip kernelmLSTM integration
mLSTMCellConfig/mLSTMLayerConfig— two new opt-in fields:memory_backend("dense"|"stc_sparse") andgate_mode("sigmoid"|"ternary")mLSTMCell— quantizers and ternary gate wired in; dense path is untouchedbackends.py—recurrent_step_stabilized_simpleupdated to dispatch to the STC pathBenchmarks
bench/dense_vs_stc.py— wall-clock latency comparison (dense vs STC, warmup-corrected)bench/flops_saved.py— FLOPs analysis across sparsity levelsOpt-in usage
Default (
memory_backend="dense",gate_mode="sigmoid") is identical to current behavior. No existing tests should be affected.FLOPs at different sparsity levels
Sparsity levels depend on input distribution and the EMA threshold in the quantizer. The STE ensures gradients flow through the quantizer during training.
What's not included
Hardware-accelerated write-skip for the CUDA path — the
.custub is there but the sparse dispatch logic needs a proper CSC/COO kernel. Happy to implement that in a follow-up if the approach looks good.