Status. Architectural spec, first pass. Hand-off target: Claude app, downstream implementation planning.

Scope. This document describes what nitrix is, what belongs in it, what does not, and the contract it offers upstream libraries (thrux, bitsjax, nimox) and downstream consumers (ilex, entense). Implementation pseudo-code is intentionally absent.

nitrix is the lowest-level numerical substrate of the diffprog neuroimaging ecosystem. It is a pure-numeric, all-JAX library: every public symbol takes JAX arrays and returns JAX arrays. nitrix has no knowledge of image containers, sidecar metadata, BIDS, filesystems, training loops, or PyTree modules. Those concerns belong to other libraries that depend on nitrix.
The long-term vision is a substrate of differentiable numerical primitives sufficient to build software in the class of FSL / FreeSurfer / ANTs / AFNI on GPU-accelerated hardware. The shorter-term vision is the union of:

- Marquee items (§3): KeOps-like semiring ops; sparsity for brain geometries; multi-criterion smoothing; LME (stretch); essentials for the rest of the ecosystem.
- Foundational primitives (§4) needed by ilex models and entense / bitsjax transforms.

Out of scope:

- No NIfTI / GIfTI / CIfTI I/O; that is thrux.
- No transform / pipeline / dataset abstractions; those are bitsjax and entense.
- No Equinox modules; those are nimox.
- No template / atlas registration as user-facing API (low-level primitives only; atlas data structures and template-aware operations live in thrux).

Design principles:

- Pure functional. Every public symbol is a function (Array, …) -> Array. No state, no PyTrees in the public API. (PyTree-shaped configs are acceptable as keyword args.)
- All differentiable. Subgradients are explicit where appropriate; custom VJPs are registered where numerical stability or efficiency requires it (precedent: the linalg/matrix.py sym2vec_* rules in the existing codebase).
- JAX + Pallas, with fallbacks. Hardware-aware Pallas kernels (NVIDIA / TPU) for the marquee ops; pure JAX fallbacks always present. Backend selection is deterministic and user-overridable.
- Typed at boundaries. jaxtyping.Array / Float[Array, "..."] annotations on all public functions. No bare Array | NDArray unions.
- No transitive heavyweight deps. nitrix may import only jax, jaxtyping, and numpy; scipy, nibabel, numpyro, equinox, etc. are NOT permitted at runtime import. Test deps are scoped to tests/.
- Stable kernels, breakable APIs. Until 1.0, the API is mutable, but kernel output must be numerically reproducible across releases (test against pinned references).

Each marquee item has its own subpackage with a clear surface. Status flags below describe intended state at first GA; current legacy / partial state lives in §6 (Migration).
Arbitrary-algebra reductions over matmul, convolution, and ELL-sparse adjacency contractions.
The full target surface includes the strict semirings (sum-product, tropical max-plus /
min-plus, log-sum-exp, Boolean) and "semiring-analogous" algebras whose (*) is not
strictly associative — most importantly the Euclidean algebra
(sqrt ∘ Σ ∘ (a − b)²) and its relatives. These power both linear-algebraic reductions
and distance-driven operations (k-NN, neighbourhood smoothing, geodesic propagation).
This generality is a deliberate design bet: a single performant kernel substrate covers
matmul, convolution, distance, graph algebra, and morphology.
The algebra is decomposed into a small pair of Protocols (informed by the
semiring_gemm.py brainstorm; see §11):

- Semigroup: a single combine(a, b) broadcasting binary op (the (*) step).
- Monoid[S]: (init, update, merge, finalize) over a pytree state S. The pytree state is critical: it lets numerically-stable online reductions (e.g. log-sum-exp's (running_max, sum_exp) carry; Welford-style variance; running-norm Euclidean) thread state through the K loop without materialising intermediates.
- Semiring: the frozen (monoid, semigroup, name) triple.

Pre-built algebras: real, tropical_max_plus, tropical_min_plus, log, euclidean,
boolean. User-defined Monoids and Semigroups compose freely.
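
To make the Protocol shape concrete, here is a minimal sketch of the three pieces and a log-sum-exp algebra built from them. It is an illustration under assumptions, not the nitrix implementation: the containers are plain NamedTuples rather than Protocols, and the helper names (`_lse_init`, `_lse_merge`, `reduce_last_axis`, `LOG`) and the exact callable signatures are hypothetical.

```python
from typing import Any, Callable, NamedTuple

import jax
import jax.numpy as jnp


class Semigroup(NamedTuple):
    combine: Callable[[jax.Array, jax.Array], jax.Array]  # the (*) step


class Monoid(NamedTuple):
    init: Callable[[tuple, Any], Any]        # (shape, dtype) -> state S
    update: Callable[[Any, jax.Array], Any]  # (S, new values) -> S
    merge: Callable[[Any, Any], Any]         # (S, S) -> S
    finalize: Callable[[Any], jax.Array]     # S -> Array


class Semiring(NamedTuple):
    monoid: Monoid
    semigroup: Semigroup
    name: str


def _lse_init(shape, dtype):
    # State is the pytree (running_max, sum_exp); -inf / 0 is the identity.
    return jnp.full(shape, -jnp.inf, dtype), jnp.zeros(shape, dtype)


def _lse_merge(a, b):
    (m_a, s_a), (m_b, s_b) = a, b
    m = jnp.maximum(m_a, m_b)
    # Guard the all -inf case so exp() never sees (-inf) - (-inf).
    m_safe = jnp.where(jnp.isfinite(m), m, 0.0)
    return m, s_a * jnp.exp(m_a - m_safe) + s_b * jnp.exp(m_b - m_safe)


def _lse_update(state, x):
    # A single value x is the singleton state (x, 1).
    return _lse_merge(state, (x, jnp.ones_like(x)))


def _lse_finalize(state):
    m, s = state
    return m + jnp.log(s)


LOG = Semiring(
    Monoid(_lse_init, _lse_update, _lse_merge, _lse_finalize),
    Semigroup(jnp.add),
    "log",
)


def reduce_last_axis(monoid, x):
    """Thread the pytree state through a loop over the last axis."""
    state = monoid.init(x.shape[:-1], x.dtype)
    state = jax.lax.fori_loop(
        0, x.shape[-1], lambda k, st: monoid.update(st, x[..., k]), state
    )
    return monoid.finalize(state)


x = jnp.array([[1000.0, 1001.0, 999.0], [-jnp.inf, -jnp.inf, -jnp.inf]])
out = reduce_last_axis(LOG.monoid, x)
# Shifted naive reference for the first row only.
ref = 1001.0 + jnp.log(jnp.sum(jnp.exp(x[0] - 1001.0)))
```

The first row exercises large magnitudes (a naive `log(sum(exp(x)))` would overflow in float32), and the second row exercises identity propagation: an all `-inf` input should reduce to `-inf` without producing NaN.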

```python
from jaxtyping import Array, Int, Num

# Semiring, REAL, and Backend are nitrix-defined names; signatures only, bodies elided.


def semiring_matmul(
    A: Num[Array, "... m k"],
    B: Num[Array, "... k n"],
    *, semiring: Semiring = REAL,
    backend: Backend = "auto",
) -> Num[Array, "... m n"]: ...


def semiring_conv(
    x: Num[Array, "... *spatial c_in"],
    k: Num[Array, "... *kspatial c_in c_out"],
    *, semiring: Semiring = REAL,
    stride: ..., padding: ..., dilation: ...,
    backend: Backend = "auto",
) -> Num[Array, "... *spatial c_out"]: ...


def semiring_ell_matmul(  # the central op for brain-geometry workloads
    values: Num[Array, "m k_max"],   # ELL values (m rows × at most k_max neighbours)
    indices: Int[Array, "m k_max"],
    B: Num[Array, "n n_cols"],
    *, semiring: Semiring = REAL,
    n_rows: int,                     # outer dim of the implicit M×N sparse matrix
    backend: Backend = "auto",
) -> Num[Array, "m n_cols"]: ...
```

- KeOps-style streaming. The K-block iteration folds rank-1 outer combines directly into the (BM, BN) accumulator; the (BM, BK, BN) value tensor is never materialised. This keeps peak on-chip memory at O(BM·BN + BM·BK + BK·BN) and makes algebras whose (*) is non-multiplicative (Euclidean, tropical, log) practical at scale.
- No tensor-core / dot primitive. Tensor cores assume (*) = ×. We issue plain CUDA-core / TPU SIMD ops so the same kernel codegens across all algebras. For the real semiring on hardware that can use tensor cores, an optional backend="tensor_core" specialisation dispatches to jnp.matmul; this is a thin fast path, not the primary surface.
- Pytree accumulator. Monoid state is a pytree; lax.fori_loop over K threads the state through. finalize applies once at the end (e.g. Euclidean's sqrt, log's m + log(s)).
- No BCOO. ELL is the primary sparse format (see §3.2). The ELL kernel walks the per-row neighbour list via gather + the same Monoid/Semigroup glue: Pallas-friendly, with no jaxlib-sparse adversarial path.

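
The streaming idea can be sketched in pure JAX without any Pallas tiling: loop over K, combine one column of A with one row of B, and fold the (m, n) result into the accumulator, so no (m, k, n) tensor is ever built. The function name `streaming_semiring_matmul` and its callable-based interface are assumptions made for this illustration; the real surface takes a Semiring object.

```python
import jax
import jax.numpy as jnp


def streaming_semiring_matmul(A, B, *, combine, init, update, finalize):
    """Reference K-loop: the accumulator is (m, n); nothing of size (m, k, n) exists."""
    m, k = A.shape
    k_b, n = B.shape
    assert k == k_b, "inner dimensions must match"
    state = init((m, n), jnp.result_type(A, B))

    def body(i, st):
        a_col = jax.lax.dynamic_slice_in_dim(A, i, 1, axis=1)  # (m, 1)
        b_row = jax.lax.dynamic_slice_in_dim(B, i, 1, axis=0)  # (1, n)
        return update(st, combine(a_col, b_row))               # rank-1 outer combine

    return finalize(jax.lax.fori_loop(0, k, body, state))


A = jnp.arange(6.0).reshape(2, 3)
B = jnp.arange(12.0).reshape(3, 4) / 2.0

# Euclidean algebra: (*) = squared difference, reduction = sum, finalize = sqrt.
euclid = streaming_semiring_matmul(
    A, B,
    combine=lambda a, b: (a - b) ** 2,
    init=lambda shape, dtype: jnp.zeros(shape, dtype),
    update=lambda s, x: s + x,
    finalize=jnp.sqrt,
)

# Tropical min-plus: (*) = +, reduction = min, identity = +inf.
min_plus = streaming_semiring_matmul(
    A, B,
    combine=jnp.add,
    init=lambda shape, dtype: jnp.full(shape, jnp.inf, dtype),
    update=jnp.minimum,
    finalize=lambda s: s,
)

# Naive broadcast formulations that DO materialise the (m, k, n) tensor.
euclid_naive = jnp.sqrt(jnp.sum((A[:, :, None] - B[None]) ** 2, axis=1))
min_plus_naive = jnp.min(A[:, :, None] + B[None], axis=1)
```

A Pallas kernel would process K in blocks rather than one index at a time, but the accumulator pattern and the single `finalize` at the end are the same.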
Backends: pallas-cuda is the default on NVIDIA and pallas-tpu the default on TPU. A JAX fallback (built on lax.fori_loop plus the same reference_semiring_gemm-style algebra plumbing) covers CPU and any shapes / algebras the Pallas builder cannot tile cleanly. Backend selection follows §7.2.
The same machinery underpins: graph path algebras (tropical), softmax / attention-style reductions (log), k-NN and bilateral search (Euclidean as the inner step of a nearest-neighbour scan), morphological opening / closing (tropical conv), spherical convolution on mesh sparsity (real or weighted on ELL adjacency), and binary connectivity analysis (Boolean). Subsequent §3 items (sparse, smoothing, morphology) are largely specialisations of this surface, not parallel implementations.
Sparsity primitives for the structures we actually care about in neuroimaging. ELL is the primary format — brain-geometry adjacencies (mesh k-rings, deformed icospheres, distance-thresholded neighbourhoods, atlas parcel members) are naturally fixed-degree-per-row, so ELL captures them losslessly with zero padding overhead in the common case and a single padded-row dimension in the worst.
Submodules:

- sparse.ell: ELL format primitives (construction, reshape, gather/scatter, padding with semiring identity, batch broadcasting). The format is a thin pair of arrays (values: [m, k_max], indices: [m, k_max]) plus a row-count and an algebra-identity for pad positions. No jax-sparse BCOO. The historical BCOO-based path in hypercoil/functional/sparse.py has been a persistent friction surface against the XLA / Pallas boundary; we implement on plain dense arrays + gather primitives so the kernels integrate cleanly.
- sparse.grid: regular-grid sparsity (low-bandwidth band matrices, stencil ops). The thin specialisation of ELL where every row has the same neighbour offsets.
- sparse.mesh: icosphere / deformed-icosphere mesh sparsity built atop sparse.ell: k-ring adjacency, sparse Laplacians, geodesic neighbourhoods.

The semiring kernels in §3.1 operate directly on ELL representations
(semiring_ell_matmul, semiring_ell_conv). Treat ELL + semiring as a single
co-designed pair: they are the substrate for both linear ops on graph adjacency and
distance-driven ops over k-NN graphs.
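
As a small illustration of ELL plus gather, the sketch below contracts a 3-node path graph against a dense right-hand side under two algebras. The function names are hypothetical, and this reference version materialises the gathered (m, k_max, n_cols) tensor for clarity, which a streaming kernel would avoid. Note how pad slots carry the algebra's identity (0 for the real semiring, +inf for min-plus) so they never affect the result.

```python
import jax.numpy as jnp


def ell_matmul_real(values, indices, B):
    """Real semiring: sum over neighbours of value * B[neighbour]."""
    gathered = B[indices]                                  # (m, k_max, n_cols)
    return jnp.sum(values[..., None] * gathered, axis=1)


def ell_matmul_min_plus(values, indices, B):
    """Tropical min-plus: min over neighbours of value + B[neighbour]."""
    gathered = B[indices]
    return jnp.min(values[..., None] + gathered, axis=1)


# Path graph 0 - 1 - 2, k_max = 2. End nodes have one neighbour, so one pad slot each.
indices = jnp.array([[1, 0], [0, 2], [1, 0]])      # pad slots point at any valid row
adj_real = jnp.array([[1.0, 0.0], [1.0, 1.0], [1.0, 0.0]])            # pad = 0
edge_len = jnp.array([[1.0, jnp.inf], [1.0, 1.0], [1.0, jnp.inf]])    # pad = +inf

signal = jnp.array([[1.0], [10.0], [100.0]])
neighbour_sum = ell_matmul_real(adj_real, indices, signal)

# Dense reference for the same adjacency.
dense = jnp.array([[0.0, 1.0, 0.0], [1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])
neighbour_sum_dense = dense @ signal

# One relaxation step of shortest-path distances from node 0.
dist0 = jnp.array([[0.0], [jnp.inf], [jnp.inf]])
dist1 = ell_matmul_min_plus(edge_len, indices, dist0)
```

Here `neighbour_sum` matches the dense product, and `dist1` is finite only at node 1, the single node one edge away from the source.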
Edge-preserving, multi-channel, multi-domain smoothing centred on the permutohedral
lattice. Multi-domain means the feature space can mix space, intensity, gradient
direction, time, etc. — permutohedral handles arbitrary d_f in expected linear time and
subsumes the bilateral / trilateral / cross-bilateral special cases that SUSAN, Gaussian,
and friends individually cover.

```python
from jaxtyping import Array, Float


def permutohedral_lattice(
    values: Float[Array, "n d_v"],
    features: Float[Array, "n d_f"],
    *, sigma_features: Float[Array, "d_f"],
) -> Float[Array, "n d_v"]: ...
```

Plus a baseline gaussian (for cases where edge preservation is not wanted) and a bilateral thin wrapper (the canonical d_f = d_space + d_intensity configuration).

SUSAN is intentionally not part of the public surface: its USAN/edge-preservation behaviour is recovered by feeding intensity (and, optionally, intensity gradient) into the permutohedral feature space. Skipping SUSAN cuts implementation scope without giving up capability.
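
The feature-space view can be illustrated with a brute-force O(n²) filter, which is the kind of computation a permutohedral lattice approximates in expected linear time. This is only a hedged reference sketch: the function name is hypothetical, and the normalised Gaussian-weighted average is an assumed convention, not a statement about what `permutohedral_lattice` returns.

```python
import jax.numpy as jnp


def naive_feature_space_filter(values, features, sigma_features):
    """Normalised Gaussian-weighted average over all pairs in feature space."""
    f = features / sigma_features                                  # (n, d_f)
    d2 = jnp.sum((f[:, None, :] - f[None, :, :]) ** 2, axis=-1)    # (n, n)
    w = jnp.exp(-0.5 * d2)
    return (w @ values) / jnp.sum(w, axis=1, keepdims=True)


# A 1-D step edge: eight samples, intensity jumps from 0 to 1 at position 4.
pos = jnp.arange(8.0)
intensity = jnp.array([0.0, 0.0, 0.0, 0.0, 1.0, 1.0, 1.0, 1.0])
values = intensity[:, None]

# Space-only features: a plain Gaussian blur that smears the edge.
gaussian = naive_feature_space_filter(values, pos[:, None], jnp.array([2.0]))

# Space + intensity features: the bilateral configuration, which keeps the edge.
bilateral = naive_feature_space_filter(
    values, jnp.stack([pos, intensity], axis=-1), jnp.array([2.0, 0.1])
)
```

With intensity in the feature vector, samples on opposite sides of the step are far apart in feature space and barely mix; adding further feature columns (gradient, time) follows the same pattern.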
Binary and grayscale erode / dilate / open / close, distance transforms. Implemented as
specialisations of semiring.conv with TROPICAL_MIN / TROPICAL_MAX. Listed separately
because it is a major user-facing surface, not because it has independent implementation.
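
A 1-D sketch of the correspondence, under the assumption that grayscale dilation is a max-plus convolution and erosion a min-plus correlation with the structuring element. The helper names are hypothetical and the explicit window stack is for readability only; it stands in for a tropical `semiring.conv` call.

```python
import jax.numpy as jnp


def _windows(x, size, pad_value):
    """Stack the `size` shifted copies of x; pad with the reduction's identity."""
    r = size // 2
    xp = jnp.pad(x, r, constant_values=pad_value)
    return jnp.stack([xp[i:i + x.shape[0]] for i in range(size)], axis=-1)


def dilate_1d(x, k):
    # Max-plus: max over offsets of x[t - y] + k[y]; -inf pads are ignored by max.
    return jnp.max(_windows(x, k.shape[0], -jnp.inf) + k[::-1], axis=-1)


def erode_1d(x, k):
    # Min-plus: min over offsets of x[t + y] - k[y]; +inf pads are ignored by min.
    return jnp.min(_windows(x, k.shape[0], jnp.inf) - k, axis=-1)


x = jnp.array([0.0, 0.0, 1.0, 0.0, 0.0, 3.0, 0.0])
flat = jnp.zeros(3)          # flat structuring element of width 3

dilated = dilate_1d(x, flat)
eroded = erode_1d(x, flat)
opened = dilate_1d(erode_1d(x, flat), flat)   # opening removes the narrow peaks
```

With a flat structuring element these reduce to running max and running min; a non-flat element simply changes the additive weights inside the same reduction.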
Efficient voxelwise linear mixed-effects fits. Out of scope for first GA; the spec reserves the namespace and documents the API shape so downstream consumers can plan around it.

```python
from typing import Literal

from jaxtyping import Array, Float, Int


def voxelwise_lme(
    Y: Float[Array, "n_obs *voxels"],
    X: Float[Array, "n_obs p"],      # fixed effects design
    Z: Float[Array, "n_obs q"],      # random effects design
    groups: Int[Array, "n_obs"],
    *, method: Literal["reml", "ml"] = "reml",
) -> LMEResult:  # NamedTuple of arrays, NOT a PyTree module
    ...
```

Open: solver choice (closed-form on small q vs iterative); whether to expose Henderson's mixed-model equations as a separate primitive.

These are the lower-glamour but high-traffic operations that ilex models, entense transforms, nimox modules, and bitsjax resolvers actually call. Most exist (in some form) in the current nitrix or in hypercoil; the migration map (§6) details origins.

nitrix.linalg:

- matrix.py: symmetric / sym2vec / vec2sym / squareform / toeplitz / toeplitz_2d / delete_diagonal / fill_diagonal / diag_embed / recondition_eigenspaces. (Existing nitrix is healthy here.)
- spd.py: SPD manifold (symexp, symlog, symmap, symsqrt); tangent-space project / unproject (BatchTangentProject numerics, not the module). Stability rewrite required: the current hypercoil implementation is flagged numerically unstable.
- kernel.py: parameterised kernels (linear, polynomial, Gaussian, RBF, Laplace) with single-dispatch over input matrix type. Includes initialisation helpers (Laplace, Toeplitz) folded in from legacy hypercoil/init/.
- residual.py: residualise (L2-regularised least squares). Existing; keep, and fix the off-diagonal-weight gap from covariance.py while we are here.

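
For orientation, the SPD maps can be written as functions of the eigenvalues of a symmetric matrix. The sketch below is a textbook eigendecomposition formulation with an illustrative eigenvalue floor; it is not the stability rewrite called for above, and the signatures and the `eps` default are assumptions.

```python
import jax.numpy as jnp


def _apply_to_spectrum(X, fn):
    """Apply fn to the eigenvalues of a symmetric matrix and rebuild it."""
    w, V = jnp.linalg.eigh(X)
    return (V * fn(w)[..., None, :]) @ jnp.swapaxes(V, -1, -2)


def symlog(X, eps=1e-6):
    # Floor the spectrum so log() never sees a non-positive eigenvalue.
    return _apply_to_spectrum(X, lambda w: jnp.log(jnp.maximum(w, eps)))


def symexp(X):
    return _apply_to_spectrum(X, jnp.exp)


X = jnp.array([[2.0, 1.0], [1.0, 2.0]])
roundtrip = symexp(symlog(X))          # should recover X
log_identity = symlog(jnp.eye(3))      # log of the identity is the zero matrix
```

Gradients through `eigh` are known to be ill-conditioned when eigenvalues are repeated or nearly so, which is one reason a dedicated custom VJP may be preferable in the real implementation.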

nitrix.stats:

- covariance.py: cov, corr, partialcov, partialcorr, pairedcov, pairedcorr, conditionalcov, conditionalcorr, precision, corrnorm. Bug fix required: non-diagonal weight matrices currently silently produce wrong answers post-JIT (covariance.py:719–726). Either implement properly or raise unambiguously.
- fourier.py: product_filter, product_filtfilt, analytic_signal, hilbert_transform, envelope, instantaneous_phase, instantaneous_frequency. (Existing; keep.)
- lme.py: STRETCH (§3.5).


nitrix.signal:

- window.py: sample_windows (existing). Drop the numpyro dependency; use jax.random directly for multinomial sampling.
- filter.py: FIR / IIR / frequency-domain filters. Pure-numeric implementations, separated from the module-shaped nn/freqfilter.py and nn/iirfilter.py in hypercoil.
- tsconv.py: basis / polynomial / time-series convolutions (port of hypercoil/functional/tsconv.py).
- interpolate.py: spectral / linear / hybrid interpolation for missing data (extract numerics from hypercoil/functional/interpolate.py).
- normalize.py: intensity_normalize (min / p99 / clip) and friends. Migrates out of ilex/models/synthstrip/preprocessing.py.

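
Dropping numpyro for window sampling can be as simple as drawing start indices with `jax.random.categorical` and slicing. The sketch below is hypothetical: the name `sample_windows_sketch`, its arguments, and the returned layout are assumptions, not the existing `sample_windows` signature.

```python
import jax
import jax.numpy as jnp


def sample_windows_sketch(key, x, window_size, num_windows, logits=None):
    """Draw `num_windows` contiguous windows along the last axis of x."""
    n_starts = x.shape[-1] - window_size + 1
    if logits is None:
        logits = jnp.zeros(n_starts)      # uniform over valid start positions
    starts = jax.random.categorical(key, logits, shape=(num_windows,))

    def take(start):
        return jax.lax.dynamic_slice_in_dim(x, start, window_size, axis=-1)

    return jax.vmap(take)(starts)         # (num_windows, ..., window_size)


x = jnp.arange(20.0)[None, :]             # one channel, 20 time points
windows = sample_windows_sketch(jax.random.PRNGKey(0), x, window_size=5, num_windows=4)
```

Passing non-uniform `logits` gives weighted (multinomial-style) sampling of start positions with only `jax.random` as a dependency.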

nitrix.geometry:

- grid.py: cmass_regular_grid, identity_grid, spatial_transform, vec_int, rescale (migrates ilex/models/voxelmorph/_numerical.py; folds in existing nitrix geom.py grid bits).
- sphere.py: icosphere generation, spherical geodesics, sphere-to-normals / latlong, spherical convolution. Sources: existing geom.py + hypercoil/functional/sphere.py. Spherical conv is re-backed by semiring.conv over mesh sparsity (§3.1–3.2); drop the legacy O(N²) inner loop.
- coords.py: coordinate utilities (cmass_coor, cmass_reference_displacement_*, spherical ↔ Cartesian).
- metrictensor.py: metric-tensor primitives (port from hypercoil).


nitrix.graph:

- laplacian.py: graph / modularity Laplacian, modularity matrix, Girvan–Newman null, coaffiliation.
- connectopy.py: eigenmaps, diffusion maps (extract from hypercoil/functional/, decouple from brainspace).
- community.py: community / relaxed-modularity numerics.


nitrix.numerics:

- tensor_ops.py: transpose, reshape_to, transpose_tf_conv_kernel, broadcast_bias, etc. The pure-array half of ilex/core/adapters.py (the adapter registry stays in ilex).

Internal utilities (axis manipulation, mask helpers, complex-number decompose / recompose, docstring formatters). Existing; healthy.

Permitted runtime imports:

- jax, jax.numpy, jax.experimental.pallas
- jaxtyping
- numpy (for type aliases only)

Not permitted at runtime import:

- equinox, quax: modules are upstream concerns
- numpyro: currently violated by window.py; fix on migration
- scipy, sklearn, pingouin: test-only
- nibabel, templateflow, lytemaps: container-level, lives in thrux
- hypercoil, ilex, entense, thrux, bitsjax, nimox, conveyant, gramform, paranox
- the standard library beyond what is needed for typing

How other libraries consume nitrix:

- thrux imports nitrix and wraps its kernels in container-aware raise / lower pairs.
- bitsjax imports nitrix (and thrux) and packages ops as tensorbids operators / resolver steps.
- nimox imports nitrix and wraps primitives in Equinox PyTree modules.
- ilex and entense import the above, not nitrix directly (except where a model's internal numerics are pure-tensor; this is allowed but discouraged, so prefer going through nimox / bitsjax).

The detailed source-by-source action list lives in MIGRATION.md. Summary by destination:
| Destination subpkg | Sources |
|---|---|
| nitrix.semiring | NEW: design from the semiring_gemm.py brainstorm (§11); no legacy port |
| nitrix.sparse | NEW. Do not port hypercoil/functional/sparse.py (historical BCOO friction). Re-implement on plain JAX gather + dense arrays. ELL primary; grid / mesh as ELL specialisations |
| nitrix.smoothing | NEW. Gaussian baseline can fold in the kernel from existing geom.py; permutohedral is clean-room. SUSAN dropped |
| nitrix.morphology | NEW: built atop semiring |
| nitrix.linalg | existing nitrix matrix.py, residual.py + hypercoil functional/{matrix, kernel, symmap, semidefinite, metrictensor} + init/{laplace, toeplitz, semidefinite} |
| nitrix.stats | existing covariance.py, fourier.py + hypercoil functional/cov consolidation; LME is NEW (stretch) |
| nitrix.signal | existing window.py (de-numpyro'd) + hypercoil functional/{tsconv, interpolate, fourier-bits} + ilex models/synthstrip/preprocessing.py intensity_normalize |
| nitrix.geometry | existing geom.py (split) + hypercoil functional/{sphere, cmass, metrictensor} + ilex models/voxelmorph/_numerical.py |
| nitrix.graph | hypercoil functional/{graph, connectopy, cmass} |
| nitrix.numerics | ilex core/adapters.py pure-array half (≈ lines 150–250) |
| nitrix.signal.filter / nitrix.linalg.residual | entense instance.py impls (polynomial_detrend_p, confound_regression_p), merged in |

Pallas kernels are an implementation detail behind the public API. The user-facing
function (semiring_matmul, permutohedral_lattice, etc.) chooses a backend; the kernel
file is private (_kernels/).
Three-level resolution: explicit backend= keyword → env var (NITRIX_BACKEND) →
auto-detect from jax.default_backend(). Auto-detect prefers pallas-cuda on NVIDIA,
pallas-tpu on TPU, jax fallback otherwise.
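
The three-level resolution order can be sketched as a small pure-Python function. The name `resolve_backend`, the `platform` / `environ` injection parameters (there so the logic is testable without hardware), and the set of accepted strings are assumptions for illustration; only `NITRIX_BACKEND`, the backend names, and the precedence order come from the text above.

```python
import os

VALID_BACKENDS = ("pallas-cuda", "pallas-tpu", "jax")


def resolve_backend(backend="auto", *, platform=None, environ=None):
    """Explicit keyword, then NITRIX_BACKEND, then auto-detect."""
    environ = os.environ if environ is None else environ
    if backend != "auto":
        choice = backend                              # 1. explicit keyword wins
    elif environ.get("NITRIX_BACKEND"):
        choice = environ["NITRIX_BACKEND"]            # 2. environment variable
    else:
        if platform is None:                          # 3. auto-detect
            import jax
            platform = jax.default_backend()
        # Assumption: a "gpu"/"cuda" platform means NVIDIA; other GPU vendors
        # would need a finer check before choosing pallas-cuda.
        choice = {"gpu": "pallas-cuda", "cuda": "pallas-cuda",
                  "tpu": "pallas-tpu"}.get(platform, "jax")
    if choice not in VALID_BACKENDS:
        raise ValueError(f"unknown backend {choice!r}; expected one of {VALID_BACKENDS}")
    return choice


picked = resolve_backend(platform="gpu", environ={"NITRIX_BACKEND": "pallas-tpu"})
```

Raising on an unrecognised name, rather than silently falling back, keeps selection deterministic: a typo in the environment variable fails loudly instead of changing which kernel runs.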
Every kernel has a JAX-side gradient. Pallas kernels register a jax.custom_vjp whose
backward is either (a) a paired Pallas kernel, or (b) a JAX fallback. Tests assert
forward / backward numerical agreement across backends (tolerance pinned per dtype).
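
The registration pattern looks roughly like the sketch below, which uses a plain-JAX forward as a stand-in for a Pallas kernel and a JAX backward (option b). The function `safe_norm` is a hypothetical example chosen because it also shows an explicit subgradient: the backward returns 0 at the origin, where differentiating through `sqrt` directly would not give a finite value.

```python
import jax
import jax.numpy as jnp


@jax.custom_vjp
def safe_norm(x):
    """Euclidean norm over the last axis (stand-in for a kernel forward)."""
    return jnp.sqrt(jnp.sum(x * x, axis=-1))


def _safe_norm_fwd(x):
    y = jnp.sqrt(jnp.sum(x * x, axis=-1))
    return y, (x, y)                      # residuals saved for the backward pass


def _safe_norm_bwd(residuals, g):
    x, y = residuals
    nonzero = y > 0
    safe_y = jnp.where(nonzero, y, 1.0)   # avoid dividing by zero
    # Explicit subgradient: x / |x| away from the origin, 0 at the origin.
    direction = jnp.where(nonzero[..., None], x / safe_y[..., None], 0.0)
    return (g[..., None] * direction,)


safe_norm.defvjp(_safe_norm_fwd, _safe_norm_bwd)

grad_regular = jax.grad(safe_norm)(jnp.array([3.0, 4.0]))
grad_origin = jax.grad(safe_norm)(jnp.zeros(3))
```

A backend-parity test would then compare both the forward value and this gradient between the Pallas path and the JAX fallback at the pinned tolerance.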

- pytest and the pinned numerical references (pingouin, scipy.ndimage, sklearn, communities) live in tests/ and are not runtime deps.
- Add hypothesis-based property tests for the marquee ops (associativity / identity for semirings; idempotence for morphological close-after-open at large kernels; etc.).
- Add backend-parity tests: the same op via pallas-cuda and the jax fallback must agree to pinned tolerance.
- Add the JIT-trap regression test for covariance with non-diagonal weights; this is a known bug, not a feature.

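
The algebraic-law checks can be prototyped before hypothesis is wired in. The sketch below tests the tropical min-plus laws on fixed random draws; the function name and the returned dictionary are assumptions, and a hypothesis strategy would replace the fixed `jax.random` inputs in the real suite.

```python
import jax
import jax.numpy as jnp


def check_min_plus_laws(a, b, c):
    """Return which semiring laws hold for (min, +) on the given arrays."""
    plus, times = jnp.minimum, jnp.add    # reduction and (*) step
    zero, one = jnp.inf, 0.0              # identities of min and of +
    return {
        "plus_associative": bool(jnp.array_equal(plus(plus(a, b), c), plus(a, plus(b, c)))),
        "plus_identity": bool(jnp.array_equal(plus(a, zero), a)),
        "times_identity": bool(jnp.array_equal(times(a, one), a)),
        "zero_annihilates": bool(jnp.all(jnp.isinf(times(a, zero)))),
        "distributive": bool(
            jnp.allclose(times(a, plus(b, c)), plus(times(a, b), times(a, c)))
        ),
    }


key_a, key_b, key_c = jax.random.split(jax.random.PRNGKey(0), 3)
laws = check_min_plus_laws(
    jax.random.normal(key_a, (64,)),
    jax.random.normal(key_b, (64,)),
    jax.random.normal(key_c, (64,)),
)
```

Floating-point addition is not exactly associative, so a check of associativity for the (*) step of other algebras would need a tolerance rather than exact equality.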
Resolved (during this drafting pass):

- Semiring representation: adopt the Monoid + Semigroup Protocol pair with pytree accumulator state (per the semiring_gemm.py brainstorm).
- Sparse format unification: ELL primary, dense JAX gather under the hood; no BCOO.
- Backwards compat: no legacy users; we break freely.

Deferred:

- Morphology placement. Independent subpackage (current spec) or buried inside semiring as thin convenience wrappers?
- LME scope. Voxelwise-independent (cheapest, plenty useful) vs voxelwise-with-spatial-regularisation (much harder, much more useful)?
- Kernel registry exposure. Should linalg.kernel expose raw kernels for thrux to wrap, or only high-level ops?
- numerics.reshape vs numerics.tensor_ops. Submodule split-out granularity inside the small "uncategorised" area.
- lytemaps. Does nitrix subsume lytemaps's JAX-compilable bits, or does lytemaps remain orthogonal (high-level wrappers around nitrix)? Recommended: orthogonal for now; revisit when nitrix.geometry.sphere matures.
- Tensor-core fast path. For the real semiring on hardware that can use tensor cores, does the backend="tensor_core" specialisation pay off enough to maintain? Or stay pure-Pallas everywhere for simplicity?

First GA criteria:

- All §4 foundational primitives implemented, tested, JAX + Pallas where applicable.
- semiring.{matmul, conv, ell_matmul} shipped with real, tropical_max_plus, tropical_min_plus, log, euclidean, boolean built-in algebras, and a documented user-extension path (custom Monoid / Semigroup).
- KeOps-style streaming kernel passes parity tests against reference_semiring_gemm and against naive broadcast formulations, with identity propagation (e.g. -inf in tropical / log) and numerical stability (log with large magnitudes) regressions covered.
- sparse.{ell, grid, mesh} shipped; geometry.sphere.spherical_conv re-backed by semiring_ell_conv. No jax.experimental.sparse BCOO dependency.
- smoothing.gaussian, smoothing.bilateral, smoothing.permutohedral_lattice shipped and tested against reference implementations. SUSAN intentionally absent.
- morphology.{binary, grayscale} shipped atop tropical-semiring conv.
- All known bugs (covariance non-diag weights; numpyro import; 2D-only spatial_conv) resolved before any of the above land.
- Downstream blockers for ilex, entense, thrux, bitsjax, nimox cleared.
- lme namespace reserved; no implementation required.

- Semiring brainstorm (prior session, untested stubs, does not match house style):
  - _refstubs/semiring_gemm.py: Monoid / Semigroup / Semiring Protocols, pre-built algebras (real, tropical_*, log, euclidean), KeOps-style Pallas kernel builder, pure-JAX reference implementation.
  - _refstubs/test_semiring_gemm.py: parity tests against naive broadcast and the pure-JAX reference; identity-propagation and numerical-stability regressions.

  Treat as design input, not as a port target. Reimplement in the diffprog house style; preserve the Protocol shape, pytree-state pattern, and KeOps streaming idea.