LoveLow-Global/PrimeAttention.jl

PrimeAttention.jl

Prime Attention Heatmap

A sparse, number-theoretic attention mechanism library for Lux.

Requires Julia 1.11 or later; 1.12 or later is recommended. It may not work on Julia 1.12.4, which has known issues with Enzyme, a dependency of PrimeAttention.

Overview

PrimeAttention.jl provides sparse attention layers whose connectivity patterns are defined by number-theoretic sequences. Using these sequences instead of dense attention matrices or random sparsity lets the layers capture long-range dependencies at lower asymptotic cost than standard $O(N^2)$ Transformer attention.

Inspired by architectures like BigBird, this package implements a hybrid mechanism that combines three strategies:

  1. Global Tokens: For sequence-wide context summaries.
  2. Sliding Window: For local syntax and immediate context.
  3. Theoretic Intervals: Sparse long-range connections based on Primes, Squares, or the Mian-Chowla sequence.

Usage

The package provides factory functions that return a SparseIndexAttention layer. Because this library is built on Lux.jl, it utilizes explicit parameter and state management.

Here is a simple example:

using Lux
using PrimeAttention
using Random

# Parameters
embed_dim = 64
n_heads = 4
global_tokens = 2
window_size = 3

# Initialize layer
attention_layer = PrimeSelfAttention(embed_dim; heads=n_heads, global_tokens=global_tokens, window=window_size)

rng = Random.default_rng()
Random.seed!(rng, 0)
ps, st = Lux.setup(rng, attention_layer)

# Dummy input
x = rand(Float32, embed_dim, 128, 8)

# Forward pass
y, st = attention_layer(x, ps, st)

println("Input:  ", size(x))
println("Output: ", size(y))

# Lux Chain
model = Chain(
    Dense(32 => 64),
    PrimeSelfAttention(64; heads=4, global_tokens=2, window=3), 
    LayerNorm((64,)),
    Dense(64 => 10)
)

ps_model, st_model = Lux.setup(rng, model)

Layers

None of these layers is linear in sequence length, but each has lower asymptotic complexity than standard $O(N^2)$ attention.

| Layer | Sequence | Density ($i$-th token) | Complexity | Best For |
| --- | --- | --- | --- | --- |
| PrimeSelfAttention | Primes $\{2, 3, 5, \dots\}$ | $\approx 1/\ln i$ | $O(N^2 / \ln N)$ | Balanced long-range skip |
| SquareSelfAttention | Squares $\{1, 4, 9, \dots\}$ | $\approx 1/\sqrt{i}$ | $O(N \sqrt{N})$ | Extremely sparse / Efficiency |
| MianChowlaSelfAttention | Mian-Chowla $\{1, 2, 4, 8, \dots\}$ | $< 1/\sqrt{i}$ | $O(N \sqrt{N})$ | Non-redundant "Sidon" patterns |

Number-Theoretic Background

1. Prime Intervals

By the Prime Number Theorem, the number of primes up to $x$ is $\pi(x) \approx x/\ln x$, so the density of connections at distance $x$ is roughly $1/\ln x$. This gives a fading resolution: recent history is densely connected and distant history is sparse.
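
To get a feel for how the $x/\ln x$ estimate compares with the actual prime count, here is a minimal sketch in plain Julia. The helper `primes_upto` is a hypothetical name introduced for illustration and is not part of the PrimeAttention.jl API; it is an ordinary sieve of Eratosthenes.

```julia
# Sieve of Eratosthenes: returns all primes p with p <= n.
# `primes_upto` is an illustrative helper, not a PrimeAttention.jl function.
function primes_upto(n::Integer)
    n < 2 && return Int[]
    is_prime = trues(n)
    is_prime[1] = false
    for p in 2:isqrt(n)
        if is_prime[p]
            # Cross out multiples of p, starting at p^2.
            for m in (p * p):p:n
                is_prime[m] = false
            end
        end
    end
    return findall(is_prime)
end

# Compare the exact prime count with the Prime Number Theorem estimate.
for n in (100, 1_000, 10_000)
    exact = length(primes_upto(n))
    estimate = round(n / log(n); digits=1)
    println("n = $n: pi(n) = $exact, n / ln(n) ≈ $estimate")
end
```

The estimate undercounts at these sizes, but the ratio between the two columns moves toward 1 as $n$ grows.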

2. Square Intervals

Connections are restricted to $j = i - n^2$ for integers $n \geq 1$, which is significantly sparser than the prime pattern. Because there are exactly $\lfloor\sqrt{N}\rfloor$ squares up to $N$, each token attends to at most $\lfloor\sqrt{N}\rfloor$ square offsets, and the total complexity drops from quadratic to sub-quadratic $O(N\sqrt{N})$.
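
As a rough illustration of the counting argument, the sketch below lists the square offsets reachable from one query position and sums them over a whole sequence. It assumes 1-based positions and keys restricted to $j \geq 1$; `square_offsets`, `square_keys`, and `total_square_links` are hypothetical names, not functions exported by the package.

```julia
# Square offsets n^2 that keep the key position j = i - n^2 at 1 or above.
# Illustrative helpers only; these are not PrimeAttention.jl functions.
square_offsets(i::Integer) = [n^2 for n in 1:isqrt(i - 1)]

# Key positions reached from query i through square offsets alone.
square_keys(i::Integer) = i .- square_offsets(i)

# Total number of square-offset links in a sequence of length N.
total_square_links(N::Integer) = sum(isqrt(i - 1) for i in 1:N)

println(square_keys(20))                      # keys at distances 1, 4, 9, 16
N = 1024
println(total_square_links(N), " links vs. bound ", N * isqrt(N))
```

The global and window terms add at most $G + W$ further links per query, so they do not change the $O(N\sqrt{N})$ bound for fixed $G$ and $W$.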

3. Mian-Chowla (Greedy Sidon Set)

The Mian-Chowla sequence ($1, 2, 4, 8, 13, \dots$) is the greedy Sidon set: each term is the smallest integer that keeps all pairwise sums, and therefore all pairwise differences, distinct. In an attention context, this means the relative distances between attended tokens are unique, which in theory reduces redundant information capture across sparse heads.
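
The greedy construction can be sketched in a few lines of plain Julia. This is an illustration of the textbook definition, not the package's internal generator, and `mian_chowla` is a hypothetical name.

```julia
# First `n_terms` terms of the Mian-Chowla sequence, built greedily:
# start at 1 and keep the smallest candidate whose pairwise sums
# (including the candidate with itself) collide with no earlier sum.
# Illustrative helper only; not a PrimeAttention.jl function.
function mian_chowla(n_terms::Integer)
    terms = Int[]
    sums = Set{Int}()
    candidate = 1
    while length(terms) < n_terms
        new_sums = [candidate + t for t in terms]
        push!(new_sums, 2 * candidate)
        if all(s -> !(s in sums), new_sums)
            push!(terms, candidate)
            union!(sums, new_sums)
        end
        candidate += 1
    end
    return terms
end

terms = mian_chowla(8)
println(terms)

# Sidon property: every positive pairwise difference occurs exactly once.
diffs = [b - a for a in terms for b in terms if b > a]
println(length(diffs) == length(unique(diffs)))
```

Distinct pairwise differences are what make the "unique relative distance" claim concrete: no two pairs of offsets in the set are the same distance apart.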

Architecture

For a query token at index $i$ and a key token at index $j$, the causal mask is defined as

$$A_{i,j} = \begin{cases} 1 & \text{if } j \leq G \text{ (Global)} \\ 1 & \text{if } 0 \leq i - j \leq W \text{ (Window)} \\ 1 & \text{if } (i - j) \in \mathcal{S} \text{ (Sequence } \mathcal{S}\text{)} \\ 0 & \text{otherwise} \end{cases}$$

Where $G$ is the number of global tokens and $W$ is the window size.
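
For intuition, the case definition above can be materialised as a dense Boolean matrix. The sketch below is only a reference illustration: `sparse_mask` is a hypothetical name, the package itself works through its index-based kernel rather than a dense mask, and restricting every case (including global tokens) to $j \leq i$ is an assumption made here to keep the mask causal.

```julia
# Dense Boolean mask following the global / window / sequence cases.
# Rows are queries i, columns are keys j, both 1-based.
# `offsets` plays the role of the set S. Illustrative helper only.
function sparse_mask(N::Integer, G::Integer, W::Integer, offsets)
    S = Set(offsets)
    mask = falses(N, N)
    for i in 1:N, j in 1:i          # assumption: j <= i for every case
        d = i - j
        mask[i, j] = (j <= G) || (d <= W) || (d in S)
    end
    return mask
end

# Example: 16 tokens, 2 global tokens, window of 3, square offsets.
mask = sparse_mask(16, 2, 3, [1, 4, 9])
println("fraction of attended pairs: ", count(mask) / length(mask))
```

Swapping the `offsets` argument for primes or Mian-Chowla terms gives the other two patterns without changing the rest of the function.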

Implementation Notes

  • Architecture: Built on Lux.jl, with explicit parameter (ps) and state (st) management, which keeps the layers modular.

  • Refactored Kernel: All layers share a universal, in-place sparse_index_kernel!. Memory is strictly pre-allocated to guarantee type stability and prevent garbage collection bottlenecks during the forward pass.

  • Differentiation: Fully compatible with Enzyme.jl for fast, LLVM-level reverse-mode automatic differentiation.

  • Performance: While the asymptotic complexity is reduced mathematically, the current speedup is limited by CPU-bound scalar iteration. Using this on a GPU right now will trigger scalar indexing fallbacks. This implementation serves primarily as a research reference.

  • Future Work: The next update aims to rewrite the kernel using KernelAbstractions.jl or CUDA.jl to unlock true hardware-level GPU acceleration, followed by empirical performance benchmarks against standard dense attention algorithms.

About

Sparse Attention using Number Theory
