A sparse, number-theoretic attention mechanism library for Lux.
Requires Julia 1.11; Julia 1.12 or later is recommended. Julia 1.12.4 may not work because of known issues between that release and Enzyme.jl, which PrimeAttention depends on.
PrimeAttention.jl provides sparse attention layers that use number-theoretic sequences to define connectivity patterns. By using mathematical sequences instead of dense matrices or random sparsity, these layers capture long-range dependencies at significantly lower computational cost than standard dense attention, which scales as $O(N^2)$ in the sequence length $N$.
Inspired by architectures like BigBird, this package implements a hybrid mechanism that combines three strategies:
- Global Tokens: For sequence-wide context summaries.
- Sliding Window: For local syntax and immediate context.
- Theoretic Intervals: Sparse long-range connections based on Primes, Squares, or the Mian-Chowla sequence.
The package provides factory functions that return a `SparseIndexAttention` layer. Because the library is built on Lux.jl, parameters and state are managed explicitly.
Here is a simple example:
```julia
using Lux
using PrimeAttention
using Random

# Parameters
embed_dim = 64
n_heads = 4
global_tokens = 2
window_size = 3

# Initialize layer
attention_layer = PrimeSelfAttention(embed_dim; heads=n_heads, global_tokens=global_tokens, window=window_size)

rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, attention_layer)

# Dummy input: (features, sequence length, batch size) = (64, 128, 8)
x = rand(Float32, embed_dim, 128, 8)

# Forward pass
y, st = attention_layer(x, ps, st)

println("Input: ", size(x))
println("Output: ", size(y))

# Using the layer inside a Lux Chain
model = Chain(
    Dense(32 => 64),
    PrimeSelfAttention(64; heads=4, global_tokens=2, window=3),
    LayerNorm((64,)),
    Dense(64 => 10),
)
ps_model, st_model = Lux.setup(rng, model)
```

The cost of each layer is set by how many keys every query attends to, which depends on the chosen sequence. While not linear in $N$, this offers a significant speedup over standard $O(N^2)$ attention.
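| Layer | Sequence | Density (offsets per query) | Complexity | Best For |
|---|---|---|---|---|
| PrimeSelfAttention | Primes | $\pi(N) \sim N / \ln N$ | $O(N^2 / \log N)$ | Balanced long-range skip |
| SquareSelfAttention | Squares | $\lfloor \sqrt{N} \rfloor$ | $O(N^{3/2})$ | Extremely sparse / Efficiency |
| MianChowlaSelfAttention | Mian-Chowla | at most $O(\sqrt{N})$ | at most $O(N^{3/2})$ | Non-redundant "Sidon" patterns |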
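For PrimeSelfAttention, long-range connections sit at prime offsets. Based on the Prime Number Theorem, the number of such offsets up to $N$ follows $\pi(N) \sim N / \ln N$, so the connection density decays like $1 / \ln N$.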
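As a quick numerical check of this estimate, the sketch below counts the prime offsets that fit inside a sequence of length $N$ and compares the count with $N / \ln N$. It assumes prime offsets are drawn from $2, \dots, N-1$; `primes_upto` is a hypothetical helper, not part of PrimeAttention's API.

```julia
# Hypothetical helper (not part of PrimeAttention): Sieve of Eratosthenes
# returning every prime p with 2 <= p <= n.
function primes_upto(n::Integer)
    n < 2 && return Int[]
    is_prime = trues(n)
    is_prime[1] = false
    for p in 2:isqrt(n)
        is_prime[p] || continue
        for m in p*p:p:n          # cross out multiples of p, starting at p^2
            is_prime[m] = false
        end
    end
    return findall(is_prime)
end

# Compare the exact number of prime offsets below N with the N / ln N estimate.
for N in (128, 1024, 8192)
    k = length(primes_upto(N - 1))   # offsets that still fit inside the sequence
    println("N = ", N, ": ", k, " prime offsets vs N/ln N ≈ ",
            round(N / log(N); digits=1), " (", round(100 * k / N; digits=1), "% of N)")
end
```

The fraction of positions reached shrinks only logarithmically, which is why the prime pattern sits between the dense and the very sparse layouts.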
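For SquareSelfAttention, connections are restricted to offsets that are perfect squares ($1, 4, 9, 16, \dots$), so at most $\lfloor \sqrt{N} \rfloor$ long-range connections per query fit in a sequence of length $N$.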
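The same check for square offsets, assuming offsets run over $1, 4, 9, \dots$ up to $N - 1$; `square_offsets` is a hypothetical helper written for illustration.

```julia
# Hypothetical helper (not part of PrimeAttention): perfect-square offsets
# k^2 with 1 <= k^2 <= n.
square_offsets(n::Integer) = [k^2 for k in 1:isqrt(n)]

for N in (128, 1024, 8192)
    k = length(square_offsets(N - 1))
    println("N = ", N, ": ", k, " square offsets vs sqrt(N) ≈ ", round(sqrt(N); digits=1))
end
```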
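The Mian-Chowla sequence is the greedily constructed Sidon sequence: each new term is the smallest integer that keeps all pairwise sums distinct, which is equivalent to keeping all pairwise differences distinct. In an attention context, this ensures that the "relative distances" between attended tokens are unique, theoretically reducing redundant information capture across sparse heads.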
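A minimal sketch of the greedy construction, followed by a check of the unique-distance property. `mian_chowla_upto` and `has_unique_differences` are hypothetical helpers for illustration; the library's own generator may be implemented differently.

```julia
# Hypothetical helper: greedy Mian-Chowla terms up to n. Start from 1 and accept
# each candidate only if every sum a[i] + a[j] (i <= j) stays distinct.
function mian_chowla_upto(n::Integer)
    seq = Int[]
    sums = Set{Int}()                  # all pairwise sums of accepted terms
    for candidate in 1:n
        new_sums = [candidate + a for a in seq]
        push!(new_sums, 2 * candidate)
        # new_sums are mutually distinct because every a in seq is < candidate
        if !any(in(sums), new_sums)
            push!(seq, candidate)
            union!(sums, new_sums)
        end
    end
    return seq
end

# Hypothetical helper: true if every positive difference b - a is unique.
function has_unique_differences(seq)
    diffs = [b - a for (i, a) in enumerate(seq) for b in seq[i+1:end]]
    return allunique(diffs)
end

mc = mian_chowla_upto(127)
println(mc)   # → [1, 2, 4, 8, 13, 21, 31, 45, 66, 81, 97, 123]
println("unique pairwise distances: ", has_unique_differences(mc))
```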
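For a query token at index $i$, the set of attended key positions is the union of the global tokens, the positions inside the sliding window around $i$, and the positions whose distance from $i$ is a term of the chosen sequence (primes, squares, or Mian-Chowla numbers).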
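This union can be sketched as a function that returns the key positions for one query. It is an illustration under stated assumptions, not the library's indexing code: `attended_indices` is hypothetical, positions are 1-based, the first `global_tokens` positions act as global tokens in both directions (as in BigBird), the window covers $|i - j| \le w$, and long-range offsets are applied on both sides of the query.

```julia
# Hypothetical helper (not part of PrimeAttention's API).
# Assumptions: 1-based positions; positions 1:global_tokens attend to, and are
# attended by, every position; the window covers |i - j| <= window; long-range
# offsets are applied in both directions.
function attended_indices(i::Int, N::Int, global_tokens::Int, window::Int,
                          offsets::AbstractVector{Int})
    i <= global_tokens && return collect(1:N)            # a global query sees everything
    ks = Set{Int}(1:min(global_tokens, N))               # every query sees the global tokens
    union!(ks, max(1, i - window):min(N, i + window))    # sliding window around i
    for d in offsets                                     # number-theoretic long-range links
        i - d >= 1 && push!(ks, i - d)
        i + d <= N && push!(ks, i + d)
    end
    return sort!(collect(ks))
end

# Query 20 in a 32-token sequence: 2 global tokens, window 1, square offsets.
squares = [k^2 for k in 1:isqrt(31)]                     # 1, 4, 9, 16, 25
println(attended_indices(20, 32, 2, 1, squares))         # → [1, 2, 4, 11, 16, 19, 20, 21, 24, 29]
```

Only 10 of the 32 positions are scored for this query, and the gap widens as $N$ grows.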
- Architecture: Now built on Lux.jl, with explicit parameter (`ps`) and state (`st`) management, which keeps the layers modular.
- Refactored Kernel: All layers share a single in-place kernel, `sparse_index_kernel!`. Memory is pre-allocated to keep the kernel type-stable and avoid garbage-collection bottlenecks during the forward pass.
- Differentiation: Fully compatible with Enzyme.jl for fast, LLVM-level reverse-mode automatic differentiation.
- Performance: Although the asymptotic complexity is lower, the practical speedup is currently limited by CPU-bound scalar iteration. Running the layers on a GPU today triggers scalar-indexing fallbacks, so this implementation serves primarily as a research reference.
- Future Work: The next update aims to rewrite the kernel with KernelAbstractions.jl or CUDA.jl to enable native GPU acceleration, followed by empirical benchmarks against standard dense attention.
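To illustrate what an in-place sparse-index kernel with pre-allocated memory can look like, here is a single-head sketch. `sparse_attention!` is a hypothetical function written for illustration, not PrimeAttention's `sparse_index_kernel!`; it assumes `Q`, `K`, `V` are `d × N` matrices with one column per token and that `idx[i]` lists the key positions query `i` attends to. Its scalar loops are also exactly the kind of iteration that falls back to slow scalar indexing on GPU arrays.

```julia
using Random

# Hypothetical single-head kernel (NOT PrimeAttention's sparse_index_kernel!).
# out    : d × N output buffer, overwritten in place
# scores : scratch buffer, at least as long as the longest idx[i]
function sparse_attention!(out::AbstractMatrix{T}, scores::Vector{T},
                           Q::AbstractMatrix{T}, K::AbstractMatrix{T},
                           V::AbstractMatrix{T}, idx::Vector{Vector{Int}}) where {T<:AbstractFloat}
    d, N = size(Q)
    scale = inv(sqrt(T(d)))
    fill!(out, zero(T))
    @inbounds for i in 1:N
        ks = idx[i]
        m = typemin(T)
        # 1. scaled dot-product scores for the sparse key set only
        for (n, j) in enumerate(ks)
            s = zero(T)
            for f in 1:d
                s += Q[f, i] * K[f, j]
            end
            scores[n] = s * scale
            m = max(m, scores[n])
        end
        # 2. numerically stable softmax over those scores
        total = zero(T)
        for n in eachindex(ks)
            scores[n] = exp(scores[n] - m)
            total += scores[n]
        end
        # 3. weighted sum of the selected value columns
        for (n, j) in enumerate(ks)
            w = scores[n] / total
            for f in 1:d
                out[f, i] += w * V[f, j]
            end
        end
    end
    return out
end

# Usage: a window-only pattern (|i - j| <= 1) on random data.
rng = Random.MersenneTwister(0)
d, N = 8, 16
Q = randn(rng, Float32, d, N)
K = randn(rng, Float32, d, N)
V = randn(rng, Float32, d, N)
idx = [collect(max(1, i - 1):min(N, i + 1)) for i in 1:N]
out = similar(Q)
scores = Vector{Float32}(undef, maximum(length, idx))
sparse_attention!(out, scores, Q, K, V, idx)
println(size(out))
```

The caller allocates `out` and `scores` once and can reuse them across calls, so the loop itself does not allocate new arrays.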
