Scope. Reads on top of SPEC.md (v0) + SPEC_UPDATE.md (v0.1) + SPEC_UPDATE_v0.2.md + MIGRATION.md. Where this plan conflicts with MIGRATION.md §6 (recommended order), this plan supersedes; see §1.2 below for the rationale.
Audience. Whoever picks up implementation: human engineer, AI agent, or a rotating mix. The plan is written to survive both modes.
Tone. Prescriptive about contracts (what "done" means), permissive about sequencing (deviation is expected).
The plan is organised as phases with explicit entry / exit criteria, tracks
that run in parallel within phases, and gates that block progression. Tasks within
a track are ordered for typical execution but are not strictly serial unless marked
SERIAL.
If you are an agent (human or AI) starting work, read in this order:
- §1 (architecture of the plan) and §2 (non-negotiables and deviation protocol) — always.
- §3 (Phase 0) — always.
- The current phase, both tracks. Skim the next phase to anticipate dependencies.
- §9 (per-subpackage detailed checklists) when you start work on a specific subpackage.
- §10 (deviation log) before deciding whether a deviation request meets the bar.
When in doubt, §2.2 governs. That is the contract; everything else is guidance.
Phases are defined by what downstream consumers can build against at the end of each phase, not by which files exist on disk. This matters because deviation pressure comes from downstream consumers needing capabilities, not files.
| Phase | Capability delivered | Approx. duration | Gating risk |
|---|---|---|---|
| 0. Preflight | Known bugs fixed; performance baseline established; CI scaffolding | weeks | Ampere ELL gate decides Phase 2 shape |
| 1. Foundation | linalg, stats, signal consolidated; downstream can replace hypercoil imports with nitrix imports | weeks | none (pure consolidation) |
| 2. Substrate | semiring + sparse.ell shipped; downstream can build distance ops, graph algebra, attention-like reductions | months | Triton stability; backward-kernel correctness |
| 3. Geometry & graph | geometry, graph shipped, spherical conv re-backed | weeks | depends on Phase 2 |
| 4. Marquee | smoothing, morphology shipped; permutohedral resolved per tripwire | weeks–months | permutohedral risk |
| 5. Polish | LME namespace reserved; docs; benchmark suite; first GA | weeks | none (pure cleanup) |
The phases are not equal-sized. Phase 2 is the long pole. Phases 1 and 3 can compress or expand based on what's needed.
Within each phase (except Phase 0 and Phase 5), work splits across two tracks:
- Track A — Substrate. New code: the semiring kernel substrate, sparse format, smoothing kernels, morphology. This is where the design bets live and where the risk concentrates.
- Track B — Consolidation. Migration of existing hypercoil / nitrix / ilex / entense numerics into the new layout. Mostly mechanical: namespace moves, deletion of upstream deps, test consolidation. Lower risk, lower interest, high downstream value (downstream libraries unblock by being able to import from nitrix.linalg instead of hypercoil.functional).
The two tracks share Phase 0 (preflight) and Phase 5 (polish). They converge at Phase 4 (smoothing depends on substrate from A and consolidated linalg from B). Otherwise they run independently.
Why this differs from MIGRATION.md §6. The migration doc puts substrate work first, then consolidation. That's wrong for two reasons: (a) the consolidated linalg / stats / signal modules are what downstream actually wants in the short term, and (b) substrate work has long uncertainty tails (Pallas / Triton risk, kernel correctness) that should not block low-risk consolidation. Running them in parallel halves the calendar time to "downstream can swap their imports."
A gate is a hard checkpoint: subsequent work depends on its outcome and proceeding without resolving it would invalidate downstream assumptions.
| Gate | Decides | Phase | §reference |
|---|---|---|---|
| G0. Ampere-ELL benchmark | Triton-default vs JAX-default for ELL kernels | end of Phase 0 | SPEC_UPDATE_v0.2 §4 |
| G1. Backward-kernel correctness | Whether each StrictSemiring backward passes finite-difference checks at pinned tolerance | mid Phase 2 | SPEC_UPDATE §3.1 |
| G2. Permutohedral tripwire | Whether permutohedral_lattice ships or raises NotImplementedError | mid Phase 4 | SPEC_UPDATE §3.3 |
| G3. Backend-fallback observability | Whether the warning infrastructure works under real fallback conditions | end of Phase 2 | SPEC_UPDATE_v0.2 §7.2 |
Gate outcomes are recorded in §10 deviation log even when they pass cleanly. A failed gate triggers replanning, not workarounds-in-place.
You told us downstream consumers occasionally need implementations urgently and the plan should accept some deviation. This section says what can flex and what can't. Without it, "deviation tolerance" becomes "the plan is decorative."
The following hold under all circumstances, including downstream pressure. An agent considering a deviation that violates any of these should refuse and escalate.
- Dependency contract (SPEC §5). nitrix does not import equinox, quax, numpyro, scipy, nibabel, or anything upstream. A "quick fix" that re-introduces any of these is the bug, not the fix. If you find yourself wanting to, the capability belongs in a different library.
- Pure-functional public API (SPEC §2.1). No PyTrees, no state, no module objects in the public surface. A downstream library can wrap nitrix functions in modules; nitrix does not ship modules.
- JAX fallback floor (SPEC_UPDATE_v0.2 §7.2). Every kernel has a working JAX fallback path exercised in CI. A Pallas-only kernel is not shippable.
- Golden corpus (SPEC_UPDATE §2.8). Every kernel × dtype × algebra combination that ships has a checked-in reference array. Adding a kernel without golden tests is not "done."
- Backward compatibility of kernel outputs across releases (SPEC §2.6). Once a kernel ships, changing its numerics within the pinned tolerance is a CHANGELOG entry. Changing it outside the tolerance is an API break.
- Loud fallbacks (SPEC_UPDATE §2.7). Backend fallback emits a warning. Silent fallback is a bug regardless of how it was introduced.
The following can flex when downstream urgency justifies it. Document the deviation in §10.
- Phase ordering. If a downstream library urgently needs a Phase 3 capability and Phase 2 is incomplete, build the Phase 3 capability on the JAX fallback path (skipping the semiring kernel optimisation) and revisit when Phase 2 catches up.
- Pallas vs JAX coverage on a specific kernel. Shipping a JAX-only kernel that meets correctness contracts is acceptable; we ship the perf later.
- Test depth on a specific path. Shipping a kernel with reduced coverage (e.g., one dtype instead of three, no Hypothesis tests yet) is acceptable if the golden corpus and backend-parity tests are present. Reduced coverage is logged as a known shortfall.
- Within-subpackage organisation. Whether intensity_normalize lives in nitrix.signal.normalize or nitrix.numerics.normalize is a cosmetic concern; pick the one that unblocks fastest and revisit at GA.
- Documentation completeness. Docstrings are required; tutorials and design notes can lag the implementation by a phase.
When a downstream library blocks on a capability not yet built:
- Identify the capability, not the implementation. "We need bilateral_gaussian for d_f=3" is a capability. "We need nitrix.smoothing.bilateral_gaussian on the Pallas backend with full test coverage" is an implementation.
- Check the non-negotiables (§2.2). If the urgent path violates any, refuse and route to the design discussion.
- Find the minimum shippable shape that meets the capability and the non-negotiables. Often this is the JAX-only path, one dtype, with the standard golden test.
- Log the deviation in §10 with: the consumer who needed it, the capability shipped, the shape it shipped in, what's deferred, and when the deferred work is expected.
- Proceed. Don't wait for sign-off on the deviation if §2.2 holds; the non-negotiables are the sign-off.
The deviation log (§10) is the source of truth for "things we shipped under pressure that need follow-up." It is reviewed at every phase transition.
- Bypassing the golden corpus for "this one is small."
- Adding a runtime dependency for "just this one function."
- Shipping a kernel without a JAX fallback because "Pallas was easier."
- Skipping the backward kernel because "it's not needed yet."
These are the failure modes I want named explicitly because they are the natural "quick wins" that erode the architecture over the year.
Entry criteria. Spec accepted; this plan accepted.
Exit criteria. All §5 (MIGRATION.md) known issues resolved in-place; CI infrastructure in place; G0 Ampere-ELL benchmark result recorded.
Why this phase exists. Two reasons. First, MIGRATION.md §5 lists known bugs in the existing code that will silently corrupt the migration if not addressed first (particularly the JIT-trap covariance bug, which will produce wrong gradients in downstream optimisation if migrated as-is). Second, the Ampere-ELL benchmark (G0) determines the default backend for half the substrate work in Phase 2; running it first means Phase 2 doesn't get rewritten halfway through.
Resolve before the migration starts, so the migration begins from a clean base. These are not new TODOs; they are existing red flags.
- covariance.py:719–726 JIT trap with non-diagonal weights. Either fix it or raise unambiguously at trace time. Add the regression test to the future golden corpus.
- covariance.py:683–686 denominator gap. Same treatment.
- window.py:12 undeclared numpyro import. Drop it; use jax.random.categorical or jax.random.multinomial instead. This is the only blocker for nitrix.signal.window being migration-ready.
- geom.py:632 spatial_conv 2D-only TODO. Leave in place; it becomes moot once spatial_conv is re-backed on semiring_ell_conv in Phase 3. Mark the file with a clear "will be replaced" comment to prevent further investment.
- functional/geom.py non-exported utilities (diffuse, cmass_reference_displacement_*). Export them under their eventual destinations now if cheap; otherwise let Phase 3 handle them.
- GPU runner provisioning. At least one Ampere runner (A100 or A40). Lovelace and Hopper coverage is a Phase 5 / 1.x target.
- Backend-parity test scaffolding. A pytest fixture parameterising over (backend, dtype) cells. Initially empty (no kernels yet); the scaffolding has to exist before Phase 2 starts writing kernels.
- Golden corpus scaffolding. Directory layout, naming convention, a tests/golden/ loader and a tests/tolerance.toml file. Empty at first.
- NitrixBackendFallback warning category. Defined in nitrix._internal, with a test that exercises the dedupe-by-(function, shape-signature, dtype, backend) logic.
- NITRIX_STRICT_BACKEND / NITRIX_SILENCE_FALLBACK env vars plumbed through.
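The following plain-Python sketch shows the warning category together with the dedupe and env-var behaviour described above. NitrixBackendFallback, the dedupe key, and the two env-var names come from this plan. Several details are assumptions:

- the helper name warn_fallback;
- the NitrixBackendError class;
- the convention that setting a variable to "1" enables it;
- the rule that NITRIX_SILENCE_FALLBACK suppresses the warning entirely.

```python
import os
import warnings


class NitrixBackendFallback(UserWarning):
    """Emitted when a kernel falls back from its requested backend to JAX."""


class NitrixBackendError(RuntimeError):
    """Hypothetical error raised instead of falling back in strict mode."""


# Dedupe keys already warned about: (function, shape-signature, dtype, backend).
_WARNED = set()


def warn_fallback(function, shapes, dtype, backend, reason):
    """Hypothetical helper: warn once per dedupe key, raise in strict mode."""
    if os.environ.get("NITRIX_STRICT_BACKEND") == "1":
        raise NitrixBackendError(f"{function}: backend {backend!r} unavailable ({reason})")
    if os.environ.get("NITRIX_SILENCE_FALLBACK") == "1":
        return
    key = (function, tuple(tuple(s) for s in shapes), str(dtype), backend)
    if key in _WARNED:
        return
    _WARNED.add(key)
    warnings.warn(
        f"{function} fell back from {backend!r} to 'jax' "
        f"(shapes={shapes}, dtype={dtype}): {reason}",
        NitrixBackendFallback,
        stacklevel=2,
    )
```

When testing this, record warnings under the "always" filter. That way the helper's own dedupe, rather than Python's per-call-site warning registry, is the behaviour under test.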
The benchmark gate. Implements two versions of semiring_ell_matmul on a representative
ELL workload (mesh adjacency, k_max ≈ 32, dimensions 1k–100k rows):
- Triton path. A skeletal Pallas kernel doing gather + accumulate with the real semiring. Doesn't need to be optimised, doesn't need full algebra coverage; needs to be representative of what the production kernel will do.
- JAX path. jnp.take_along_axis + jnp.einsum or equivalent. The fallback this benchmark gates against.
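Both paths compute the same gather-then-accumulate. For the REAL semiring the semantics fit in a few lines of plain Python. This is a reference sketch, not the benchmark code. In JAX, the gather step corresponds to jnp.take / jnp.take_along_axis and the accumulate step to jnp.einsum. The function name and the list-of-rows layout are illustrative.

```python
def ell_matvec_real(values, indices, x):
    """y[i] = sum_p values[i][p] * x[indices[i][p]] for an ELL matrix.

    values and indices are n_rows lists of length k_max. Padded slots carry
    value 0.0 (the REAL identity), so the index they point at never matters.
    """
    y = []
    for row_vals, row_idx in zip(values, indices):
        gathered = [x[j] for j in row_idx]                        # gather step
        y.append(sum(v * g for v, g in zip(row_vals, gathered)))  # accumulate step
    return y


# Path graph 0-1-2 stored with k_max = 2; rows 0 and 2 have one padded slot.
values = [[1.0, 0.0], [1.0, 1.0], [1.0, 0.0]]
indices = [[1, 0], [0, 2], [1, 0]]
print(ell_matvec_real(values, indices, [1.0, 2.0, 3.0]))  # [2.0, 4.0, 2.0]
```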
Compare wall-time and compile-time on the reference A100 80 GB across a small grid of shapes. Decision:
- Triton < 2× JAX wall-time: Triton is default in Phase 2. Standard plan.
- Triton 2–5× slower: JAX is default on Ampere; Triton is opt-in via backend="pallas-cuda". Update the SPEC_UPDATE_v0.2 §4 footnote to reflect the actual measurement; document the reasoning in the kernel docstrings.
- Triton ≥ 5× slower or unstable: a substantially harder conversation. Either (a) investigate Triton-side performance issues with the JAX team, (b) rewrite the kernel skeleton (different tiling, different gather pattern), or (c) accept that the streaming-kernel-substrate bet doesn't pay off on Ampere and reconsider the whole §3.1 design. This branch is unlikely, but the plan should be honest that it's possible.
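The branch logic is small enough to pin down as code. Doing so keeps the recorded decision reproducible when the benchmark is re-run on new hardware. This sketch assumes per-shape wall times and uses the worst-case ratio across the grid; the plan leaves the aggregation open. g0_decision and its return labels are illustrative.

```python
def g0_decision(timings, unstable=False):
    """Map G0 measurements onto the three branches above.

    timings: {shape: (triton_wall_s, jax_wall_s)} over the benchmark grid.
    Assumption: the worst (largest) Triton/JAX ratio across the grid decides.
    """
    worst = max(t / j for t, j in timings.values())
    if unstable or worst >= 5.0:
        return "reconsider-3.1"
    if worst < 2.0:
        return "triton-default"
    return "jax-default-triton-opt-in"


grid = {(1_000, 32): (1.2e-3, 1.0e-3), (100_000, 32): (9.0e-2, 3.0e-2)}
print(g0_decision(grid))  # worst ratio is about 3x -> jax-default-triton-opt-in
```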
Output of G0: a checked-in benchmark report (Markdown), a checked-in benchmark script that can be re-run on new hardware, and a decision recorded in §10.
- Top-level nitrix/ package layout matching SPEC §3 and §4 (empty modules with docstrings).
- _kernels/cuda/ directory for Pallas kernels (empty at this point).
- _refstubs/ directory checked in but explicitly excluded from the package build. The semiring brainstorm lives there for reference; it does not ship.
- pyproject.toml with the minimal dependency set (jax, jaxtyping, numpy). Test deps (pingouin, scipy, sklearn, hypothesis) scoped to the test extra.
- All MIGRATION.md §5 issues either resolved in-place or moved to Phase 3 replacement plan with a clear note.
- Ampere GPU CI runner available and running an empty smoke test.
- Backend-parity, golden-corpus, fallback-warning scaffolding in place.
- G0 benchmark complete; decision logged in §10.
- Repository layout scaffolded, _refstubs/ excluded from build.
Entry criteria. Phase 0 exit checklist clear. G0 outcome known.
Exit criteria. nitrix.linalg, nitrix.stats, nitrix.signal, nitrix.numerics
consolidated and importable. Downstream libraries can replace
from hypercoil.functional import X with from nitrix.{linalg,stats,signal} import X
for the migrated surface.
Why this phase exists separately. This is the consolidation track (Track B in §1.2), running in parallel with Phase 2 substrate work. It's listed as its own phase only because some downstream libraries will not need substrate (Phase 2) capabilities and their unblock condition is this phase ending. Internally it's largely mechanical and can be parallelised across multiple agents / engineers.
These are independent and can be done in any order. Each one is a self-contained PR.
Consolidate current nitrix/functional/matrix.py with hypercoil/functional/matrix.py.
Preserve the custom VJP rules at the current matrix.py:554–568 as the pattern for
future work. Fold hypercoil/init/toeplitz.py initialisation in.
Consolidate current nitrix/functional/residual.py with hypercoil/functional/resid.py.
Fold the entense confound_regression_p numerical core in. The off-diagonal-weight
gap (Phase 0 fix) should already be addressed.
Port hypercoil/functional/kernel.py with the single-dispatch surface. Fold
hypercoil/init/laplace.py.
This one is SERIAL because it requires a numerical stability rewrite per SPEC §4.1.
Port hypercoil/functional/symmap.py and hypercoil/functional/semidefinite.py,
rewriting the SPD implementation for stability. The migration is the natural
opportunity per MIGRATION.md §5. Tests cover the stability regression explicitly.
Consolidate current nitrix/functional/covariance.py with hypercoil/functional/cov.py.
The JIT-trap fix from Phase 0 carries into this module.
Consolidate current nitrix/functional/fourier.py with hypercoil/functional/fourier.py.
Migrate current nitrix/functional/window.py (numpyro-stripped after Phase 0). Fold
hypercoil/functional/window.py.
Port hypercoil/functional/tsconv.py and hypercoil/functional/interpolate.py
(extracting numeric cores, dropping neuro context). Add nitrix.signal.filter with
the entense polynomial_detrend_p numeric core.
Port hypercoil/functional/linear.py to tensor_ops; port the ilex core/adapters.py
pure-array half (≈ lines 150–250) to tensor_ops. Port the ilex synthstrip
intensity_normalize to normalize.
At Phase 1 exit, downstream libraries get the following promise:
- All migrated public symbols are importable from their nitrix destination.
- Hypercoil import paths are deprecated but functional for a transition window (separate concern — the hypercoil-side shim is not nitrix's job, but the migration doc should track it).
- No behaviour change beyond the bug fixes in MIGRATION.md §5; numerics are pinned by the golden corpus.
- Every Phase 1 task above complete and golden-tested.
- Backend-parity tests pass for any kernel touching Pallas (mostly none in Phase 1; matmul-shaped ops in linalg might).
- Downstream smoke test: one downstream library successfully imports from nitrix and runs a representative pipeline.
Entry criteria. Phase 0 exit checklist clear. G0 outcome incorporated into Phase 2 default-backend decisions. Phase 1 may be in progress (parallel track).
Exit criteria. nitrix.semiring with REAL, LOG, TROPICAL_MAX_PLUS,
TROPICAL_MIN_PLUS, BOOLEAN, EUCLIDEAN algebras shipped. nitrix.sparse.ell and
nitrix.sparse.ell.sectioned shipped. All §3.1 and §3.2 success criteria from the
spec met.
Why this phase exists. The marquee bet. Everything downstream of it (geometry mesh ops, smoothing, morphology) specialises onto this substrate. The phase is the long pole of the project.
Tasks within Phase 2 have real dependencies; the order matters more than in Phase 1.
2.A.1 Protocols (Semiring, StrictSemiring, Semigroup, Monoid)
│
├─→ 2.A.2 Reference JAX implementation (semiring_matmul, *_conv, *_ell_matmul)
│ │
│ └─→ 2.A.3 Built-in algebras (REAL, LOG, TROPICAL_*, BOOLEAN)
│ │
│ ├─→ 2.A.4 EUCLIDEAN (relaxed Semiring; first test of the relaxed path)
│ │
│ ├─→ 2.A.5 Backward kernels (per-algebra, JAX-side)
│ │ │
│ │ └─→ G1 — backward-kernel correctness gate
│ │
│ └─→ 2.A.6 Pallas/Triton kernel (if G0 allows default-Pallas)
│ │
│ └─→ 2.A.7 Pallas backward kernels
│
└─→ 2.A.8 nitrix.sparse.ell + sparse.ell.sectioned
│
└─→ 2.A.9 nitrix.sparse.grid, nitrix.sparse.mesh (specialisations)
│
└─→ G3 — fallback observability gate (real-world test case)
Define Semigroup, Monoid, Semiring, StrictSemiring in nitrix.semiring._types.
StrictSemiring is a structural subtype of Semiring (StrictSemiring <: Semiring),
selected with a strict=True constructor flag. Type aliases are exported from
nitrix.semiring. Document the associativity / distributivity expectations per Protocol.
This is small and entirely API-design work; no kernels. Get the Protocol shape right here because changing it later is a public API break.
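The following is a minimal sketch of one possible Protocol shape, written with typing.Protocol. It is not the final nitrix.semiring._types:

- From 2.A.3: the method names init / update / merge / finalize and binary_op.
- Assumptions: the exact attribute set, placing binary_op on Semigroup, and the SumMonoid / Algebra helpers.

```python
from dataclasses import dataclass
from typing import Any, Callable, Protocol, runtime_checkable


@runtime_checkable
class Semigroup(Protocol):
    """An associative binary operation: the (*) step."""
    def binary_op(self, a: Any, b: Any) -> Any: ...


@runtime_checkable
class Monoid(Protocol):
    """Streaming (+) reduction: init -> update* -> merge* -> finalize."""
    def init(self) -> Any: ...
    def update(self, state: Any, value: Any) -> Any: ...
    def merge(self, left: Any, right: Any) -> Any: ...
    def finalize(self, state: Any) -> Any: ...


@runtime_checkable
class Semiring(Semigroup, Protocol):
    """(*) from Semigroup, (+) from a Monoid, plus the (+)-identity."""
    monoid: Monoid
    identity: Any
    strict: bool  # True for a StrictSemiring; False for relaxed algebras such as EUCLIDEAN


@dataclass(frozen=True)
class SumMonoid:
    def init(self):
        return 0.0

    def update(self, state, value):
        return state + value

    def merge(self, left, right):
        return left + right

    def finalize(self, state):
        return state


@dataclass(frozen=True)
class Algebra:
    monoid: Monoid
    mul: Callable[[Any, Any], Any]
    identity: Any
    strict: bool = True

    def binary_op(self, a, b):
        return self.mul(a, b)


REAL = Algebra(SumMonoid(), lambda a, b: a * b, identity=0.0)
```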
reference_semiring_gemm.py ported from the _refstubs brainstorm into the house
style. Implements semiring_matmul, semiring_conv, semiring_ell_matmul purely in
JAX via lax.fori_loop over the K block with the Monoid pytree state. No Pallas yet.
This is the correctness floor for everything in Phase 2. Every Pallas kernel that comes later is checked against this. It needs to be right before anything else gets built on top of it.
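The streaming pattern can be sketched in pure Python. Each K block reduces into a fresh Monoid state, which is then merged into the running state. This mirrors the lax.fori_loop-over-K structure described above but uses plain loops over lists. The Algebra container and its field names are illustrative.

```python
import math
from dataclasses import dataclass
from typing import Any, Callable


@dataclass(frozen=True)
class Algebra:
    init: Callable[[], Any]
    update: Callable[[Any, Any], Any]
    merge: Callable[[Any, Any], Any]
    finalize: Callable[[Any], Any]
    binary_op: Callable[[Any, Any], Any]


REAL = Algebra(
    init=lambda: 0.0,
    update=lambda s, v: s + v,
    merge=lambda a, b: a + b,
    finalize=lambda s: s,
    binary_op=lambda a, b: a * b,
)

TROPICAL_MAX_PLUS = Algebra(
    init=lambda: -math.inf,
    update=max,
    merge=max,
    finalize=lambda s: s,
    binary_op=lambda a, b: a + b,
)


def semiring_matmul_ref(a, b, alg, block=2):
    """C[i][j] = (+)_k a[i][k] (*) b[k][j], streamed over K in blocks of `block`."""
    n, k_dim, m = len(a), len(b), len(b[0])
    out = []
    for i in range(n):
        row = []
        for j in range(m):
            state = alg.init()
            for k0 in range(0, k_dim, block):
                partial = alg.init()  # fresh state for this K block
                for k in range(k0, min(k0 + block, k_dim)):
                    partial = alg.update(partial, alg.binary_op(a[i][k], b[k][j]))
                state = alg.merge(state, partial)  # fold the block into the running state
            row.append(alg.finalize(state))
        out.append(row)
    return out


a = [[1, 2, 3], [4, 5, 6]]
b = [[1, 0], [0, 1], [1, 1]]
print(semiring_matmul_ref(a, b, REAL))               # [[4.0, 5.0], [10.0, 11.0]]
print(semiring_matmul_ref(a, b, TROPICAL_MAX_PLUS))  # [[4, 4], [7, 7]]
```

The block size must not change the answer; varying it is a cheap test that merge is consistent with update.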
REAL, LOG, TROPICAL_MAX_PLUS, TROPICAL_MIN_PLUS, BOOLEAN. Each with:
- init, update, merge, finalize over the Monoid state.
- binary_op for the (*) step.
- Identity element.
- Golden test: forward output matches a naive broadcast formulation on small inputs.
- Golden test: identity propagation works correctly (e.g., -inf in TROPICAL_MAX_PLUS annihilates).
- Golden test: numerical stability under adversarial inputs (e.g., LOG with magnitudes spanning ±1000).
The numerical-stability tests are the ones most likely to catch a subtle bug in the streaming-kernel state. Do not skip.
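LOG is where the streaming state matters most. The sketch below gives a numerically stable (+) = logaddexp state in plain Python. It carries a running maximum and a rescaled sum, the standard online-logsumexp technique. The function names and the (max, sum) state layout are assumptions about how the Monoid could be realised, not the shipped code. Only the (+) side is shown; LOG's (*) is ordinary addition.

```python
import math

NEG_INF = -math.inf


def log_init():
    # (running max m, sum of exp(v - m)); the empty state encodes the (+)-identity -inf
    return (NEG_INF, 0.0)


def _rescale(m_old, s_old, m_new):
    # s_old * exp(m_old - m_new), guarding the empty state where m_old is -inf
    return 0.0 if m_old == NEG_INF else s_old * math.exp(m_old - m_new)


def log_update(state, v):
    m, s = state
    if v == NEG_INF:
        return state  # adding the LOG (+)-identity changes nothing
    m_new = max(m, v)
    return (m_new, _rescale(m, s, m_new) + math.exp(v - m_new))


def log_merge(left, right):
    (m1, s1), (m2, s2) = left, right
    m_new = max(m1, m2)
    if m_new == NEG_INF:
        return log_init()
    return (m_new, _rescale(m1, s1, m_new) + _rescale(m2, s2, m_new))


def log_finalize(state):
    m, s = state
    return NEG_INF if s == 0.0 else m + math.log(s)


# Magnitudes spanning ±1000: math.exp(1000.0) alone would raise OverflowError.
state = log_init()
for v in (1000.0, -1000.0, 999.0):
    state = log_update(state, v)
print(log_finalize(state))  # approximately 1000.3133
```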
The first test of the relaxed Semiring (non-StrictSemiring) path. Validates that
the type-system distinction between strict and relaxed actually works, and that
algorithms gated on StrictSemiring reject the relaxed Euclidean as expected. Golden
tests include the sqrt singularity guard at zero.
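A sketch of both checks follows. It assumes each algebra carries a strict flag, and that EUCLIDEAN sums squared differences and finalises with a square root. AlgebraInfo, require_strict, and the zero-subgradient choice at a == b are illustrative. The forward sqrt(0) is already finite. The guard matters for the gradient, where the normalised difference (a − b)/‖a − b‖ becomes 0/0.

```python
import math
from dataclasses import dataclass


@dataclass(frozen=True)
class AlgebraInfo:
    name: str
    strict: bool  # True for a StrictSemiring, False for a relaxed Semiring


REAL_INFO = AlgebraInfo("REAL", strict=True)
EUCLIDEAN_INFO = AlgebraInfo("EUCLIDEAN", strict=False)


def require_strict(algebra):
    """Gate for algorithms that rely on full semiring laws (e.g. annihilating pads)."""
    if not algebra.strict:
        raise TypeError(f"{algebra.name} is a relaxed Semiring; a StrictSemiring is required")
    return algebra


def euclidean(a, b):
    """(*) = squared difference, (+) = sum, finalize = sqrt."""
    return math.sqrt(sum((x - y) ** 2 for x, y in zip(a, b)))


def euclidean_grad_a(a, b):
    """d/da |a - b| = (a - b) / |a - b|, guarded at a == b.

    The normalised difference is 0/0 there; returning the zero subgradient
    keeps the backward finite instead of producing NaN.
    """
    d = euclidean(a, b)
    if d == 0.0:
        return [0.0] * len(a)
    return [(x - y) / d for x, y in zip(a, b)]
```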
Per the §3.1 backward vocabulary:
- REAL: transpose-matmul (reuse the forward kernel with swapped operands).
- LOG: softmax-weighted; the softmax is recomputed in the backward K loop, not materialised.
- TROPICAL_*: argmax/argmin gather, subgradient.
- EUCLIDEAN: normalised difference with √-singularity guard.
- BOOLEAN: not differentiable; backward raises a clear error.
Each backward registered via jax.custom_vjp. Each passes finite-difference checks
at the pinned tolerance.
G1 gate. If any algebra's backward fails the finite-difference check at the pinned
tolerance, do not ship that algebra at Phase 2 exit. The algebra ships forward-only
with a documented gradient raise (matching the BOOLEAN pattern). This is a real
possibility for EUCLIDEAN near the √-singularity; the plan accepts that as a known
risk.
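For the tropical algebras, the backward routes each output's cotangent to its argmax. The plain-Python sketch below shows that subgradient for a max-plus matrix-vector product, together with a central-difference check of the kind G1 runs. The helper names, step size, and tolerance are illustrative, not the pinned values. The second check shows why test inputs must avoid ties: at a tie the max is not differentiable, and the central difference averages the two branches.

```python
def maxplus_matvec(A, x):
    """y[i] = max_k (A[i][k] + x[k])."""
    return [max(a + xk for a, xk in zip(row, x)) for row in A]


def maxplus_matvec_vjp(A, x, g):
    """Subgradient VJP: each output sends its cotangent to its argmax input."""
    grad = [0.0] * len(x)
    for row, g_i in zip(A, g):
        scores = [a + xk for a, xk in zip(row, x)]
        k_star = max(range(len(x)), key=scores.__getitem__)  # first argmax on ties
        grad[k_star] += g_i
    return grad


def finite_difference_check(f, vjp, A, x, g, h=1e-4, tol=1e-6):
    """Compare vjp(A, x, g) with central differences of L(x) = sum_i g_i f(A, x)_i."""
    def loss(xs):
        return sum(g_i * y_i for g_i, y_i in zip(g, f(A, xs)))

    analytic = vjp(A, x, g)
    for k in range(len(x)):
        x_plus, x_minus = list(x), list(x)
        x_plus[k] += h
        x_minus[k] -= h
        numeric = (loss(x_plus) - loss(x_minus)) / (2 * h)
        if abs(numeric - analytic[k]) > tol:
            return False
    return True


A = [[0.0, 1.0, 2.0], [3.0, 0.0, 0.0]]
g = [1.0, 2.0]
print(finite_difference_check(maxplus_matvec, maxplus_matvec_vjp, A, [0.0, 0.5, 0.1], g))  # True
print(finite_difference_check(maxplus_matvec, maxplus_matvec_vjp, A, [0.0, 1.0, 0.0], g))  # False (tie)
```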
If G0 said Triton-default is viable: implement the Pallas / Triton kernel for
semiring_matmul, semiring_conv, semiring_ell_matmul. The kernel is parameterised
over the algebra via Monoid / Semigroup callables passed at kernel compile time.
If G0 said JAX-default: this task slips to Phase 5 / 1.x. The JAX kernel is the production path; the Pallas path is opt-in and may not exist at first GA. Substantial scope cut, but the streaming-kernel design still holds — just less of it ships in optimised form.
Per-algebra Pallas backward kernels, paired with their forwards. Same conditional as 2.A.6.
ELL format: (values, indices, n_rows, identity) with gather / scatter / pad / reshape
primitives. Plus nitrix.sparse.ell.sectioned (the bucketed-row variant for
variable-degree adjacencies). Per SPEC_UPDATE §3.2, sectioned-ELL is CORE.
semiring_ell_matmul and semiring_ell_conv accept either flat or sectioned ELL.
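The sketch below illustrates the bucketing idea behind sectioned ELL. Rows are grouped by degree, so each section pads only to its own k_max instead of every row padding to the global maximum. The power-of-two bucketing rule and the EllSection container are assumptions; SPEC_UPDATE §3.2 governs the real layout.

```python
from dataclasses import dataclass


@dataclass
class EllSection:
    rows: list     # original row ids held by this section
    values: list   # len(rows) x k_max, padded with the semiring identity
    indices: list  # len(rows) x k_max, padded with 0 (any valid index works)
    k_max: int


def _bucket_width(degree):
    # Assumed bucketing rule: round each row degree up to a power of two.
    width = 1
    while width < degree:
        width *= 2
    return width


def to_sectioned_ell(adjacency, identity=0.0):
    """adjacency: one list of (column, value) pairs per row."""
    buckets = {}
    for r, row in enumerate(adjacency):
        buckets.setdefault(_bucket_width(len(row)), []).append(r)
    sections = []
    for k_max in sorted(buckets):
        vals, idxs = [], []
        for r in buckets[k_max]:
            pad = k_max - len(adjacency[r])
            vals.append([v for _, v in adjacency[r]] + [identity] * pad)
            idxs.append([c for c, _ in adjacency[r]] + [0] * pad)
        sections.append(EllSection(buckets[k_max], vals, idxs, k_max))
    return sections


def sectioned_matvec_real(sections, x, n_rows):
    """REAL-semiring matvec, scattering each section's rows back to their ids."""
    y = [0.0] * n_rows
    for sec in sections:
        for r, v_row, i_row in zip(sec.rows, sec.values, sec.indices):
            y[r] = sum(v * x[c] for v, c in zip(v_row, i_row))
    return y


# Star graph: one hub of degree 8, eight leaves of degree 1.
star = [[(j, 1.0) for j in range(1, 9)]] + [[(0, 1.0)] for _ in range(8)]
sections = to_sectioned_ell(star)
print([s.k_max for s in sections], sum(len(s.rows) * s.k_max for s in sections))  # [1, 8] 16
```

Flat ELL would pad all nine rows to k_max = 8, i.e. 72 stored slots, against 16 here.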
Thin specialisations of ELL. grid is the case where every row has the same neighbour
offsets (regular-grid stencils). mesh covers icosphere k-ring adjacency, sparse
Laplacians, and geodesic neighbourhoods. Mostly format-conversion code; the heavy lifting
is in semiring_ell_matmul.
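For grid, every row shares the same offset list, so the ELL arrays can be generated from the stencil alone. The 1-D sketch below assumes out-of-range taps are stored as the identity value with a clamped (arbitrary but valid) index. grid_to_ell is a hypothetical name.

```python
def grid_to_ell(n, offsets, weights, identity=0.0):
    """Build ELL (values, indices) for a 1-D regular-grid stencil."""
    values, indices = [], []
    for i in range(n):
        v_row, i_row = [], []
        for off, w in zip(offsets, weights):
            j = i + off
            if 0 <= j < n:
                v_row.append(w)
                i_row.append(j)
            else:
                v_row.append(identity)              # padded tap: annihilates under (*)
                i_row.append(min(max(j, 0), n - 1))  # any in-range index works
        values.append(v_row)
        indices.append(i_row)
    return values, indices


def ell_matvec_real(values, indices, x):
    return [sum(v * x[j] for v, j in zip(vr, ir)) for vr, ir in zip(values, indices)]


# Second-difference (Laplacian) stencil on 4 points.
vals, idx = grid_to_ell(4, offsets=(-1, 0, 1), weights=(1.0, -2.0, 1.0))
print(ell_matvec_real(vals, idx, [0.0, 1.0, 4.0, 9.0]))  # [1.0, 2.0, 2.0, -14.0]
```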
A test that forces a shape × algebra combination Triton cannot tile (e.g., a k_max
larger than fits in shared memory). Asserts that the NitrixBackendFallback warning
fires exactly once per (function, shape-signature, dtype, backend) and that the JAX
path produces the correct answer. Asserts that NITRIX_STRICT_BACKEND=1 converts the
fallback to an error.
If G3 fails, the fallback infrastructure (Phase 0 scaffolding) is broken and needs fixing before Phase 2 ships.
At Phase 2 exit, downstream gets:
- Matmul, conv, ELL-matmul over the six built-in algebras, differentiable where G1 allows (BOOLEAN is forward-only by design; EUCLIDEAN per its G1 outcome).
- The substrate for any custom algebra (with user-supplied VJP for differentiability).
- ELL format primitives, including the sectioned variant for variable-degree cases.
- Backend selection (auto, pallas-cuda, jax) with loud fallback.
- All SPEC §10 success criteria for §3.1 and §3.2 met (per SPEC + SPEC_UPDATE).
- G1 backward-kernel gate passed for at least REAL, LOG, TROPICAL_*. EUCLIDEAN and BOOLEAN outcomes recorded.
- G3 fallback-observability gate passed.
- Golden corpus populated for every (kernel, dtype, algebra, backend) cell that ships.
- Backend-parity CI green.
Entry criteria. Phase 1 exit and Phase 2 exit both clear (geometry mesh ops depend on Phase 2 substrate; geometry grid ops depend on Phase 1 linalg).
Exit criteria. nitrix.geometry and nitrix.graph shipped. Spherical convolution
re-backed on semiring_ell_conv (the legacy O(N²) inner loop is gone).
Migrate ilex models/voxelmorph/_numerical.py (identity_grid, spatial_transform,
vec_int, rescale). Fold in current geom.py grid bits. Add cmass_regular_grid from
hypercoil cmass.py. Regression tests from voxelmorph travel with the code.
Migrate hypercoil sphere.py + current geom.py sphere bits. Re-back spherical
convolution on semiring_ell_conv over mesh adjacency. The legacy O(N²) inner loop
is dropped. This task validates the §3.1 design bet end-to-end: a spherical conv that
previously had its own bespoke implementation now specialises onto the substrate.
If G0 said JAX-default, the Pallas perf gain on spherical conv is deferred. The spherical conv still ships, just not as fast. That's acceptable.
Coordinate utilities from hypercoil cmass.py + current geom.py coords bits.
Includes the previously non-exported diffuse and cmass_reference_displacement_*
(MIGRATION.md §5).
Port hypercoil metrictensor.py.
Port hypercoil functional/graph.py (Laplacian, modularity, Girvan-Newman null).
Extract from hypercoil functional/connectopy.py. Strip the brainspace dependency.
This is the one place in Phase 3 where the "drop neuro context, keep numerics" rule
needs careful attention — the eigenmap / diffusion-map algorithms are general, but
the existing implementation may have brainspace assumptions baked in.
Port hypercoil community / relaxed-modularity numerics.
- Geometry and graph subpackages importable, golden-tested.
- Spherical conv: numerical agreement with the legacy O(N²) implementation to pinned tolerance.
- Brainspace dependency removed.
Entry criteria. Phase 2 exit clear. (Phase 3 is independent.)
Exit criteria. nitrix.smoothing (gaussian, bilateral_gaussian,
permutohedral_lattice per tripwire), nitrix.morphology shipped.
Separable Gaussian. Pure JAX. The unconditional baseline. Cheap.
Direct N-body bilateral over arbitrary d_f, implemented as a semiring_ell_matmul
over distance-thresholded sectioned-ELL adjacency. This is the marquee capability
delivered regardless of permutohedral risk.
Specific tests: agreement with a reference NumPy direct-N-body implementation;
agreement with gaussian in the limit of large sigma_intensity (where the
intensity-similarity weighting becomes uninformative).
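Below is a direct N-body reference of the kind the first test compares against. It is written in plain Python rather than NumPy and assumes the usual Gaussian spatial and range kernels, normalised per output point. bilateral_direct, its parameters, and the radius cutoff (standing in for the distance-thresholded adjacency) are illustrative. The second test follows from the formula: as sigma_intensity grows, the range term tends to 1 and the filter reduces to a normalised Gaussian.

```python
import math


def _sqdist(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b))


def bilateral_direct(positions, features, sigma_spatial, sigma_intensity, radius):
    """Direct N-body bilateral filter over an arbitrary feature dimension d_f.

    Each output is a normalised weighted mean of neighbour features, with
    w_ij = exp(-|p_i - p_j|^2 / (2 s_s^2) - |f_i - f_j|^2 / (2 s_r^2)).
    Only neighbours within `radius` contribute (the thresholded adjacency).
    """
    out = []
    for p_i, f_i in zip(positions, features):
        acc = [0.0] * len(f_i)
        norm = 0.0
        for p_j, f_j in zip(positions, features):
            d2 = _sqdist(p_i, p_j)
            if d2 > radius ** 2:
                continue
            w = math.exp(-d2 / (2 * sigma_spatial ** 2)
                         - _sqdist(f_i, f_j) / (2 * sigma_intensity ** 2))
            norm += w
            acc = [a + w * f for a, f in zip(acc, f_j)]
        out.append([a / norm for a in acc])  # norm >= 1: j == i always contributes
    return out


positions = [(0.0,), (1.0,), (2.0,), (3.0,)]
step = [(0.0,), (0.0,), (10.0,), (10.0,)]
edge_kept = bilateral_direct(positions, step, 1.0, 1.0, radius=1.5)
blurred = bilateral_direct(positions, step, 1.0, math.inf, radius=1.5)  # plain Gaussian
```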
Specialisations of semiring_conv with TROPICAL_MIN_PLUS / TROPICAL_MAX_PLUS.
Thin wrappers — most code is documentation. Distance transforms via the standard
two-pass min-plus algorithm.
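In 1-D the three ideas fit in a few lines of plain Python. Flat erosion and dilation are min-plus / max-plus convolutions with a zero-valued structuring element. The distance transform is a forward and a backward min-plus sweep. The function names are illustrative, and the sketch is 1-D only.

```python
import math

INF = math.inf


def erode_1d(f, radius):
    """Flat erosion: out[i] = min over the window; out-of-range taps are +inf."""
    n = len(f)
    out = []
    for i in range(n):
        taps = [f[i + s] if 0 <= i + s < n else INF for s in range(-radius, radius + 1)]
        out.append(min(taps))  # min is the min-plus (+); the flat element adds 0
    return out


def dilate_1d(f, radius):
    """Flat dilation: the max-plus counterpart, padding with its identity -inf."""
    n = len(f)
    out = []
    for i in range(n):
        taps = [f[i + s] if 0 <= i + s < n else -INF for s in range(-radius, radius + 1)]
        out.append(max(taps))
    return out


def distance_transform_1d(mask):
    """Two-pass min-plus distance to the nearest True entry (exact in 1-D)."""
    d = [0.0 if m else INF for m in mask]
    for i in range(1, len(d)):            # forward pass
        d[i] = min(d[i], d[i - 1] + 1.0)
    for i in range(len(d) - 2, -1, -1):   # backward pass
        d[i] = min(d[i], d[i + 1] + 1.0)
    return d


print(distance_transform_1d([False, False, True, False, False, False, True]))
# [2.0, 1.0, 0.0, 1.0, 2.0, 1.0, 0.0]
```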
Gather-based op, not a semiring op. Implemented as gather → jnp.median over the
neighbourhood. Parity test against scipy.ndimage.median_filter within pinned
tolerance.
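The following 1-D sketch shows the gather-then-median shape. It assumes borders are handled like scipy.ndimage's default 'reflect' mode, so the parity test compares like with like. median_filter_1d and _reflect are hypothetical helpers.

```python
import statistics


def _reflect(i, n):
    # scipy.ndimage 'reflect' boundary: (d c b a | a b c d | d c b a);
    # valid while the window overhang is at most n.
    if i < 0:
        return -i - 1
    if i >= n:
        return 2 * n - i - 1
    return i


def median_filter_1d(f, radius):
    """Gather each (2 * radius + 1)-wide neighbourhood, then take its median."""
    n = len(f)
    out = []
    for i in range(n):
        window = [f[_reflect(i + s, n)] for s in range(-radius, radius + 1)]  # gather
        out.append(statistics.median(window))
    return out


print(median_filter_1d([1, 9, 2, 8, 3], radius=1))  # [1, 2, 8, 3, 3]
```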
Convenience wrapper composing bilateral_gaussian + median_filter. Docstring
explicitly documents the behavioural deltas from FSL SUSAN (no auto-flat-kernel at
small extents).
The high-risk item. Attempt the implementation in this order:
- Pure JAX reference. Working implementation of splat / blur / slice in JAX. Slow but correct. Reference for the optimised path.
- Optimisation pass. JAX with hand-rolled gather patterns, or JAX+Pallas hybrid for the splat/slice hash table operations. Pallas-pure is explicitly not required (SPEC_UPDATE §3.3).
- G2 tripwire evaluation. Per SPEC_UPDATE_v0.2 §3.3, evaluate against the four criteria (PSNR > 40 dB, < 10× Gaussian wall time, < 30 s first compile, gradient passes finite-diff). Pin the actual numbers from benchmark before evaluation.
G2 outcomes:
- All four criteria met: ship permutohedral_lattice. Done.
- Criterion 1 (parity) or 4 (gradient) fails: correctness problem; fix or revisit. Do not ship a permutohedral with wrong outputs or wrong gradients.
- Criterion 2 (perf) or 3 (compile) fails: the implementation is correct but doesn't clear the perf bar. The symbol raises NotImplementedError pointing to bilateral_gaussian for d_f ≤ 5. The implementation is checked in under _experimental/ for future work. Revisit at 1.x.
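The evaluation can be sketched as code using the four thresholds listed above. These thresholds are placeholders until the pinned benchmark numbers exist. The PSNR formula is the standard 10·log10(range²/MSE). Treating the direct bilateral_gaussian output as the PSNR reference, and the outcome labels, are assumptions.

```python
import math
from dataclasses import dataclass


def psnr(reference, candidate, data_range):
    """Peak signal-to-noise ratio in dB: 10 * log10(data_range^2 / MSE)."""
    mse = sum((r - c) ** 2 for r, c in zip(reference, candidate)) / len(reference)
    return math.inf if mse == 0 else 10.0 * math.log10(data_range ** 2 / mse)


@dataclass
class TripwireMeasurement:
    psnr_db: float                 # criterion 1: parity vs the reference output
    wall_ratio_vs_gaussian: float  # criterion 2: permutohedral wall time / gaussian
    first_compile_s: float         # criterion 3: first-call compile time
    gradient_ok: bool              # criterion 4: finite-difference check passed


def g2_outcome(m):
    correct = m.psnr_db > 40.0 and m.gradient_ok
    fast = m.wall_ratio_vs_gaussian < 10.0 and m.first_compile_s < 30.0
    if not correct:
        return "fix-or-revisit"             # never ship wrong outputs or gradients
    if not fast:
        return "raise-NotImplementedError"  # park under _experimental/, point to bilateral_gaussian
    return "ship"
```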
- gaussian, bilateral_gaussian shipped unconditionally.
- Morphology shipped: erode, dilate, open, close, distance_transform via tropical semiring; median_filter via gather; susan_emulator composing both.
- G2 evaluated; permutohedral_lattice either shipped or raises with clear pointer.
Entry criteria. Phases 1, 2, 3, 4 exit checklists clear.
Exit criteria. First GA released.
- nitrix.stats.lme namespace reserved. Stub module with NotImplementedError raises and a clear roadmap docstring. No implementation per SPEC §3.5.
- Documentation pass. Every public symbol has a docstring including a one-line summary, signature in jaxtyping, example, and (where relevant) backend notes.
- Benchmark suite. Reusable benchmark scripts checked in under bench/, covering the marquee operations on Ampere. Results checked in as Markdown reports per hardware generation.
- Tutorials. A small set of "how to use the substrate" notebooks: writing a custom Semiring, using ELL for mesh ops, choosing between gaussian and bilateral_gaussian.
- Migration guide for downstream libraries. A separate document mapping hypercoil.functional.X and nitrix-old.X to their new locations.
- CHANGELOG and versioning policy. Pin the "stable kernels, breakable APIs" contract concretely.
- Release process documentation. How to cut a release, how to update the pinned jax minimum, how to add a new backend (future-proofing for TPU).
All of:
- Phases 0–4 exit checklists clear.
- All §10 success criteria from SPEC and SPEC_UPDATEs met.
- Backend-parity CI green on Ampere. Hopper / Blackwell coverage is 1.x.
- Golden corpus populated for every (kernel, dtype, algebra, backend) cell.
- Documentation pass complete.
- One downstream library (thrux, nimox, ilex, entense, or bitsjax) successfully uses nitrix end-to-end for a real workload.
This section is the reference for agents picking up work on a specific subpackage. It restates the relevant material from the spec in a "what does done look like" format, without duplicating spec text.
For brevity, the checklist below is abbreviated; the full per-symbol checklist lives
in the subpackage's own _PLAN.md (to be created in Phase 0 scaffolding).
- Protocols defined: Semigroup, Monoid, Semiring, StrictSemiring.
- Reference JAX semiring_matmul, semiring_conv, semiring_ell_matmul working.
- Built-in algebras: REAL, LOG, TROPICAL_MAX_PLUS, TROPICAL_MIN_PLUS, BOOLEAN, EUCLIDEAN.
- Backward kernels per algebra, registered via jax.custom_vjp.
- Pallas / Triton kernel (if G0 viable).
- Golden corpus + backend parity + identity propagation + numerical stability tests.
- Documented user-extension path.
- sparse.ell primitives.
- sparse.ell.sectioned for variable-degree.
- sparse.grid, sparse.mesh as specialisations.
- No jax.experimental.sparse import anywhere.
- gaussian unconditional.
- bilateral_gaussian unconditional.
- permutohedral_lattice evaluated at G2; ships or raises with pointer.
- erode, dilate, open, close on tropical semiring.
- distance_transform via two-pass min-plus.
- median_filter via gather.
- susan_emulator composing bilateral_gaussian + median_filter.
- matrix, residual, kernel, spd consolidated.
- SPD numerical-stability rewrite complete.
- covariance with JIT-trap fix.
- fourier consolidated.
- lme namespace reserved (no implementation).
- window numpyro-stripped.
- filter, tsconv, interpolate ported.
- grid, sphere, coords, metrictensor.
- Spherical conv re-backed on semiring.
- laplacian, connectopy, community.
- No brainspace dependency.
- tensor_ops, normalize.
Maintain this section as work proceeds. Every deviation from the plan — both gate outcomes and shipped-under-pressure capabilities — gets a row.
### YYYY-MM-DD — Short title
- **Type:** Gate outcome | Downstream deviation | Plan revision
- **Triggered by:** (consumer, gate, or planning decision)
- **Description:** What happened.
- **Capability shipped:** (if deviation) What downstream got.
- **Shape:** (if deviation) JAX-only | reduced coverage | other
- **Deferred work:** What's still owed.
- **Expected resolution phase:** (if deviation)
- **Non-negotiables held:** Confirmation that §2.2 was respected.
### TBD — G0 Ampere ELL benchmark outcome
- **Type:** Gate outcome
- **Triggered by:** Phase 0 gate G0
- **Description:** (Benchmark result; Triton vs JAX wall-time.)
- **Decision:** Triton-default | JAX-default with Triton opt-in | Reconsider §3.1
- **Impact on Phase 2:** (Carry into 2.A.6 task scope.)
2026-05-20 — SUGAR feedback batch: edge attributes, row-softmax, mean-pool, external topology, masking
- **Type:** Downstream deviation
- **Triggered by:** ilex/SUGAR port (`NITRIX_FEEDBACK_ILEX.md`, 2026-05-18), the second surface-domain consumer of the ELL mesh-graph-conv substrate after Topofit.
- **Description:** Five additive, substrate-aligned changes.
  1. `edge_attr=` kwarg on `semiring_ell_edge_aggregate`: when set, `edge_fn` receives a 5th arg `a = edge_attr[i,p,:]` while keeping the scalar `w` (the padding signal). This covers GATv2's `edge_dim` Fourier embedding. It refines the feedback's Option A (which would have displaced `w`) and is backward-compatible.
  2. `ell_row_softmax(scores, ell)`: GAT attention pre-pass, masking pads from `ell.values == ell.identity` (the feedback's first real consumer of this proposal).
  3. `mesh_coarsen_meanpool`: mean-pool sibling of `mesh_pool_max`. `icosphere_cross_level_adjacency` now stores a 1.0/0.0 validity indicator in `values` (identity 0.0), so the mean falls out as `sum(v·x)/sum(v)`. `mesh_pool_max` overrides values internally, so it is unaffected.
  4. `icosphere_hierarchy_from_levels(meshes, parents)`: packages caller-supplied topology into the existing `IcosphereHierarchy`, so FreeSurfer `fsaverage` hierarchies run through every cross-level operator with no topology-source branching.
  5. `ell_mask(ell, valid, *, identity)`: masks incomplete geometries (medial wall / grey matter) by setting masked edges to the semiring identity (consumer-raised; see the masking note below).
- **Capability shipped:** GATv2/edge-attributed mesh convs, GAT attention, surface mean-pooling, external (FreeSurfer) topology hierarchies, and masked reductions, all on the existing semiring/ELL substrate.
- **Shape:** Pure-JAX (the substrate's current state; ELL Pallas is gated on G0). Full forward/backward, golden + property tests, CPU correctness floor.
- **Rejected (concern leakage):** the feedback's Delta-3 options A/B, which would read FreeSurfer `.sphere` binaries (nibabel, `$SUBJECTS_DIR`) inside nitrix. That violates SPEC §5.2 / non-negotiable §2.2.1. nitrix stays array-only; the consumer/`thrux` does the I/O and hands in plain arrays via `icosphere_hierarchy_from_levels`.
- **Deferred work:** Pallas dispatch for `semiring_ell_edge_aggregate` (BACKLOG B3); LOG/EUCLIDEAN edge-aggregate semirings (B4); bench at ico_6/ico_7 (B2).
- **Non-negotiables held:** No new deps; pure-array signatures (NamedTuple/dataclass containers only); JAX floor exercised in CI; golden/property tests added.
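To make items (2) and (3) concrete, here is a minimal pure-JAX sketch of a pad-aware row softmax and a validity-weighted mean pool over a toy ELL layout. It assumes a plain `(indices, values)` pair of shape `(n, k)` with pads marked by `values == 0.0`. The names `masked_row_softmax` and `masked_mean_pool` are hypothetical; this is not the nitrix implementation of `ell_row_softmax` or `mesh_coarsen_meanpool`.

```python
import jax.numpy as jnp


# Hypothetical helpers on a toy ELL layout (not the nitrix API).
def masked_row_softmax(scores, values, identity=0.0):
    """Softmax over each row's valid slots; pad slots (value == identity) get weight 0."""
    valid = values != identity
    logits = jnp.where(valid, scores, -jnp.inf)
    row_max = jnp.max(logits, axis=-1, keepdims=True)
    row_max = jnp.where(jnp.isfinite(row_max), row_max, 0.0)  # guard all-pad rows
    e = jnp.where(valid, jnp.exp(logits - row_max), 0.0)
    denom = jnp.sum(e, axis=-1, keepdims=True)
    return e / jnp.where(denom > 0, denom, 1.0)


def masked_mean_pool(x, indices, values):
    """Mean of x over each coarse row's valid children: sum(v * x) / sum(v).

    x: (n_fine, d); indices, values: (n_coarse, k) with values in {0.0, 1.0}.
    """
    gathered = x[indices]                                  # (n_coarse, k, d)
    num = jnp.einsum("ck,ckd->cd", values, gathered)
    den = jnp.sum(values, axis=-1, keepdims=True)          # number of valid children
    return num / jnp.maximum(den, 1.0)                     # childless rows pool to 0


indices = jnp.array([[0, 1, 2], [3, 0, 0]])                # row 1's pads point at vertex 0
values = jnp.array([[1.0, 1.0, 1.0], [1.0, 0.0, 0.0]])
x = jnp.array([[1.0], [2.0], [3.0], [10.0]])
scores = jnp.array([[0.5, 1.0, 2.0], [4.0, 9.0, 9.0]])

attn = masked_row_softmax(scores, values)
pooled = masked_mean_pool(x, indices, values)
```

The pads in row 1 carry weight 0, so their indices (which point at vertex 0) never leak into either result, even though their scores are the largest in the row.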
- **Type:** Downstream deviation
- **Triggered by:** consumer question: medial-wall (surface) and grey-matter (volume) masks must make absent edges no-ops without blurring in masked signal.
- **Description:** Verified that the substrate already supports this. A missing edge is a no-op iff its `values` entry is the algebra's `(*)`-annihilator, which equals `semiring.identity` for REAL (0), LOG/TROPICAL_MAX_PLUS (−∞), TROPICAL_MIN_PLUS (+∞), and BOOLEAN (False). The no-op holds regardless of where the padded index points. EUCLIDEAN is the documented exception: `(a−b)²` has no annihilator, so EUCLIDEAN neighbourhoods must be masked by dropping columns structurally, not via a value. Shipped `nitrix.sparse.ell_mask(ell, valid, *, identity)` (column- or edge-mask) plus a parametrised verification suite (`tests/test_ell_masking_semirings.py`) covering the no-op property, the EUCLIDEAN limitation, and the "wrong identity under max-plus leaks" footgun. Also made the four cross-level mesh wrappers batch-safe via a shared vmap-over-leading-dims helper (they claimed `(..., n, d)` but only handled 2-D).
- **Capability shipped:** correct, semiring-aware masking of incomplete brain geometries; honest batch support on the cross-level wrappers.
- **Shape:** Pure-JAX; differentiable; full parity tests.
- **Non-negotiables held:** array-only; no deps; `nitrix.sparse.ell` stays free of a `nitrix.semiring` import (identity passed explicitly).
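The no-op argument can be checked numerically. The sketch below uses toy `(plus, times)` pairs in a hypothetical `SEMIRINGS` table (not nitrix's `Semiring` objects) on a single ELL row whose third slot is masked but whose index still points at real data (value 10.0). With the annihilator as the mask value, the masked reduction equals structurally dropping the slot. With `0.0` under EUCLIDEAN, or `0.0` instead of −∞ under max-plus, the masked slot leaks.

```python
import jax.numpy as jnp

# Toy semirings (hypothetical table): (plus-reduction, times, (*)-annihilator or None).
SEMIRINGS = {
    "real": (jnp.sum, lambda w, x: w * x, 0.0),
    "max_plus": (jnp.max, lambda w, x: w + x, -jnp.inf),
    "min_plus": (jnp.min, lambda w, x: w + x, jnp.inf),
    "euclidean": (jnp.sum, lambda w, x: (w - x) ** 2, None),  # no annihilator exists
}


def row_reduce(name, values, gathered):
    plus, times, _ = SEMIRINGS[name]
    return plus(times(values, gathered), axis=-1)


def masked_vs_dropped(name, mask_value):
    """Reduce one ELL row with its last slot masked, and again with that slot dropped."""
    values = jnp.array([2.0, 3.0, mask_value])
    gathered = jnp.array([1.0, 4.0, 10.0])  # the masked slot still points at real data
    masked = row_reduce(name, values, gathered)
    dropped = row_reduce(name, values[:2], gathered[:2])
    return float(masked), float(dropped)


for name, (_, _, annihilator) in SEMIRINGS.items():
    mask = 0.0 if annihilator is None else annihilator  # EUCLIDEAN falls back to identity 0
    print(name, masked_vs_dropped(name, mask))
# The max-plus footgun: masking with 0.0 (the REAL identity) instead of -inf.
print("max_plus, wrong mask", masked_vs_dropped("max_plus", 0.0))
```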
- **Type:** Plan revision / dependency hygiene
- **Triggered by:** the uv-managed `.venv` (and `uv.lock`) had drifted to jax 0.4.35, while the Dockerfile / validated baseline is `jax[cuda12]==0.10.0` (the G0 report and kernels were all developed against 0.10.0). The test suite had therefore been running on the wrong jax: `jax.random.multinomial` (used by `signal/window.py`) is absent before 0.10.0, so `signal`/`window`/`lme` failed, and `numpyro` collection broke.
- **Root cause:** `pyproject` declared `jax >= 0.4.30` with `requires-python = ">=3.10"`. jax 0.10.x dropped Python 3.10, so uv resolved down to the last 3.10-compatible jax (0.4.35). The nox matrix is already 3.11/3.12/3.13.
- **Fix:** bumped the floor to `jax >= 0.10.0` and `requires-python = ">=3.11"`, re-locked (jax/jaxlib pinned to 0.10.0 to match the Docker baseline), and bumped the test-only `numpyro` 0.18.0 → 0.21.0 (0.18 imported `jax.experimental.pjit.pjit_p`, removed in jax 0.10). After the fix, `signal`/`window`/`lme`/`geom` pass.
- **Non-negotiables held:** `numpyro` remains test-only (absent from the Docker runtime env; SPEC §5.2); no new runtime deps.
- **Type:** Robustness hardening
- **Triggered by:** the pin correction surfaced a hard XLA CPU compiler abort (`AlgebraicSimplifier::HandleReverse`: "Invalid binary instruction opcode map") while compiling `toeplitz_2d`, which built its matrix with `jnp.flip` (a `reverse` HLO). It worked on 0.4.35 but crashes on 0.10.x CPU.
- **Fix:** replaced `jnp.flip(c_arg, -1)` with an index-based reverse (`c_arg[..., jnp.arange(d-1, -1, -1)]`, a gather) in both copies (`functional/matrix.py`, `linalg/matrix.py`). Output is identical (parity with `scipy.linalg.toeplitz` on square / rectangular-extend / fill cases), the cost is negligible, and with no `reverse` HLO there is no crash. `test_matrix` passes (14).
- **Note:** the remaining 0.10.x suite failures (`test_util`, `test_resid`) are hypothesis test-harness flakiness (`FlakyFailure` from overflow-prone generated inputs; `FailedHealthCheck: data generation extremely slow` on a loaded box) in untouched property tests, not core bugs or version-API brittleness. Tracked as a separate test-quality item (constrain the strategies / add a CI hypothesis profile with `deadline=None` + health-check suppression).
- **Non-negotiables held:** numerics unchanged within tolerance; no deps.
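A minimal sketch of the workaround, assuming a scipy-style `toeplitz(c, r)` (first column `c`, first row `r`, `r[0]` ignored). `reverse_last` and `toeplitz` are illustrative names, not the nitrix `toeplitz_2d` signature. The point is only that the reversal is written as integer-index fancy indexing, which JAX lowers as a gather rather than the `reverse` op that `jnp.flip` emits.

```python
import jax.numpy as jnp


def reverse_last(x):
    """Reverse the last axis with an index gather instead of jnp.flip (hypothetical helper)."""
    d = x.shape[-1]
    return x[..., jnp.arange(d - 1, -1, -1)]


def toeplitz(c, r):
    """T[i, j] = c[i - j] if i >= j else r[j - i]; r[0] is ignored (scipy convention)."""
    c, r = jnp.asarray(c), jnp.asarray(r)
    # vals = [r[m-1], ..., r[1], c[0], ..., c[n-1]]
    vals = jnp.concatenate([reverse_last(r[1:]), c])
    i = jnp.arange(c.shape[0])[:, None]
    j = jnp.arange(r.shape[0])[None, :]
    return vals[(r.shape[0] - 1) + i - j]  # another gather; still no reverse op


T = toeplitz(jnp.array([1.0, 2.0, 3.0]), jnp.array([9.0, 4.0, 5.0, 6.0]))
```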
- **Type:** Gate outcome / plan revision
- **Triggered by:** consumer ask for a 3-D trilinear resampling Pallas kernel.
- **Description:** Trilinear resampling is structurally a gather (8 data-dependent corner loads), the same primitive G0 found Pallas Triton cannot lower on the pinned JAX. Rather than write a kernel speculatively, shipped a baseline bench (`bench/trilinear_resample.py` → `bench/PERF_TRILINEAR.md`) and parked the kernel in BACKLOG B7 behind a two-part gate: (a) the path is a real training-loop bottleneck, and (b) a pointer-load Pallas prototype clears the gather-lowering risk. The Gaussian-blur Pallas request (low priority) is parked in B6 (stencil, not gather; the cuDNN baseline is strong; only a fused-passes win exists).
- **Decision:** JAX-default (current state) until the gate clears. No kernel shipped.
- **Non-negotiables held:** the `map_coordinates` JAX path remains the contractual floor.
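For reference, the gather structure looks like this in plain JAX: each output sample reads the 8 corners of its enclosing voxel at data-dependent indices and blends them with separable linear weights. This is an illustrative sketch; the name `trilinear_sample` and its clamp-to-edge boundary handling are assumptions. It is checked against `jax.scipy.ndimage.map_coordinates(..., order=1)` for interior points only, where the two boundary conventions agree.

```python
import jax
import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def trilinear_sample(vol, coords):
    """Sample vol (D, H, W) at float voxel coords (3, N) via 8 corner gathers (hypothetical)."""
    lo_f = jnp.floor(coords)
    frac = coords - lo_f                                  # (3, N), each in [0, 1)
    lo = lo_f.astype(jnp.int32)
    upper = jnp.array(vol.shape)[:, None] - 1             # clamp-to-edge bound per axis
    out = jnp.zeros(coords.shape[1], vol.dtype)
    for dz in (0, 1):
        for dy in (0, 1):
            for dx in (0, 1):
                offs = jnp.array([dz, dy, dx])[:, None]
                idx = jnp.clip(lo + offs, 0, upper)        # data-dependent corner index
                w = jnp.prod(jnp.where(offs == 1, frac, 1.0 - frac), axis=0)
                out = out + w * vol[idx[0], idx[1], idx[2]]  # the gather
    return out


kv, kc = jax.random.split(jax.random.PRNGKey(0))
vol = jax.random.normal(kv, (4, 5, 6))
# Interior sample points only, so boundary conventions do not matter.
coords = jax.random.uniform(kc, (3, 7)) * (jnp.array(vol.shape)[:, None] - 1.001)
out = trilinear_sample(vol, coords)
ref = map_coordinates(vol, [coords[0], coords[1], coords[2]], order=1)
```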
- **Type:** Plan revision / cleanup
- **Triggered by:** `functional/` flagged as leftover legacy (already migrated).
- **Description:** Removed `src/nitrix/functional/` entirely. It was runtime-dead (no `src` import; only legacy tests referenced it) and every symbol was migrated: `covariance`/`fourier` → `stats`, `matrix`/`residual` → `linalg`, `window` → `signal`, `geom` → `geometry` (renames: `sphere_to_normals` → `latlong_to_cartesian`, `sphere_to_latlong` → `cartesian_to_latlong`, `spherical_geodesic` → `spherical_geodesic_distance`). Legacy tests were handled by coverage comparison (collected case counts):
  - Deleted `test_matrix` (14), `test_window` (2), `test_geom` (17): the new `test_linalg` (29) / `test_signal` (6) / `test_geometry` (53) are supersets.
  - Deleted `test_cov`: it tested the old covariance API (single `weight` param + private `_prepare_*` helpers), which the migration redesigned to `weights=` / `weight_matrix=` with new internals; `test_stats` covers the new API. Added a non-diagonal-`weight_matrix` regression to `test_stats` (the SPEC §8 mandate, now that the behaviour is to compute correctly rather than raise).
  - Repointed `test_fourier` → `stats` and `test_resid` → `linalg`. Verified that `residualise` is numerically identical old vs new. The repoint surfaced two intended migration API changes (validating the "run to catch drift" instinct), adapted in the test: `analytic_signal` now raises `TypeError` (not `ValueError`) on complex input and takes `axis` keyword-only.
- **Non-negotiables held:** no runtime deps added/removed; migrated impls unchanged; new modules + tests are the canonical coverage.
- **Type:** Test-quality / robustness
- **Triggered by:** the three originally flaky tests (`test_util`, `test_geom`, `test_resid`).
- **Description:** Added `tests/conftest.py` with a hypothesis profile (`deadline=None`; suppress `too_slow`/`data_too_large`). JAX's first-call JIT compilation makes per-example deadlines unreliable (`DeadlineExceeded` → `FlakyFailure`); disabling them removes the timing flakiness suite-wide with zero input-space change, since every example is still drawn and asserted. Relaxed the explicit `deadline=500` in `test_util`. `test_geom`'s flakiness (unseeded random + exact `==0` truncation boundary) is moot because it was deleted above (`test_geometry` covers spherical conv). The deflake's fuller exploration unmasked a known, author-documented `residualise` limitation: the exact `residual + projection == Y` decomposition breaks at `1e-5` (float32) for ill-conditioned designs (`p -> obs`). It is pre-existing (identical on old and new `residualise`); see BACKLOG B9. Per decision, constrained the exact-decomposition property tests to the well-conditioned domain (`generate_valid_arrays(well_conditioned=True)`, `p <= obs/2`) so they are honest and green, and BACKLOG'd the real numerical fix (SVD/QR projector).
- **Non-negotiables held:** the deflake loses no input coverage (only timing assertions dropped); the ill-conditioned limitation is documented + tracked (B9), not silently skipped.
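A `conftest.py` profile of the kind described might look like the following. The profile name `jax-jit` is hypothetical, and loading it unconditionally (rather than selecting it with `--hypothesis-profile`) is one choice among several; the settings that matter are `deadline=None` and the two suppressed health checks.

```python
# tests/conftest.py (sketch; the profile name is hypothetical)
from hypothesis import HealthCheck, settings

settings.register_profile(
    "jax-jit",
    # JAX compiles on the first call, so per-example wall-clock deadlines
    # mostly measure compilation rather than the property under test.
    deadline=None,
    # Slow or oversized data generation on a loaded box is not a correctness signal.
    suppress_health_check=[HealthCheck.too_slow, HealthCheck.data_too_large],
)
settings.load_profile("jax-jit")
```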
- **Type:** Learning / future-API note
- **Description:** `Semiring.identity` is the monoid identity; padding / masking (`sparse.ell_mask`) needs the `(*)`-annihilator, which coincides with `identity` for all built-ins except `EUCLIDEAN` (no annihilator; `identity=0` does not mask). Recorded in `docs/design/semiring-protocols.md` and BACKLOG B8 (consider an explicit `annihilator` field rather than overloading `identity`).
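One possible shape for the B8 idea, as a hypothetical sketch only (`SemiringSpec` and `mask_fill` are illustrative names, not part of nitrix): carry the annihilator as its own optional field, so masking code asks for it explicitly and fails loudly when the algebra has none, instead of silently reusing `identity`.

```python
from dataclasses import dataclass
from typing import Optional

import jax.numpy as jnp


@dataclass(frozen=True)
class SemiringSpec:  # hypothetical container, not nitrix.semiring.Semiring
    name: str
    identity: float               # (+)-monoid identity: the reduction's starting value
    annihilator: Optional[float]  # (*)-annihilator: the only value that masks an edge


def mask_fill(values, valid, s: SemiringSpec):
    """Replace masked edges with the annihilator; refuse algebras that lack one."""
    if s.annihilator is None:
        raise ValueError(f"{s.name} has no annihilator; drop masked columns structurally")
    return jnp.where(valid, values, s.annihilator)


MAX_PLUS = SemiringSpec("max_plus", identity=-jnp.inf, annihilator=-jnp.inf)
EUCLIDEAN = SemiringSpec("euclidean", identity=0.0, annihilator=None)
```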
- **Type:** Plan revision (new gate)
- **Triggered by:** review found many `nitrix` functions un- or under-typed; aligns the surface to the `thrux` static-typing standard.
- **Description:** Lifted `[tool.mypy]` to the thrux bar (`disallow_untyped_defs`, `disallow_incomplete_defs`, `warn_unused_ignores` on top of the existing strict base) and added a `typecheck` nox session running `mypy src/nitrix` (now in `nox.options.sessions`). The base config already existed but was never run: 136 latent errors under the old settings, 332 under the new. Drove `mypy src/nitrix` to 0 errors across 65 files: every def annotated; jaxtyping-native array types throughout (`Float`/`Num`/`Int`/`Bool`/`Shaped`/`Complex[Array, '...']`); the legacy `Tensor = Union[jax.Array, NDArray]` alias removed (confined to `_internal/util.py` + one re-export). Contracts lean on protocols: a `TypeIs` guard (`graph._is_sparse`) narrows `Array | ELL | SectionedELL` in both branches across laplacian / connectopy, and `Monoid`/`Semigroup`/`Semiring[S]` generics are threaded through the algebra surface. No `Any`-silencing and no new `# type: ignore`; type loss across untyped JAX boundaries (`jit`/`vmap`/`custom_vjp`/`fori_loop`/`pallas_call`/`jnp.linalg.*`) is restored with zero-runtime-cost `typing.cast`. Full per-file test suite green (one regression caught and fixed: `cast()` to a two-variadic jaxtyping shape fails at runtime, since cast targets are runtime-evaluated; use `'...'`). Resolved the ruff/jaxtyping mismatch by ignoring `F722`/`F821` (ruff reads jaxtyping shape strings as forward-ref annotations); mypy's `[name-defined]` remains the real undefined-name backstop.
- **Deferred work:** pre-existing, non-jaxtyping ruff debt remains (`I001` unsorted imports, `F401` dead imports, `F841`, `E702`, `E402`; ~212 total, 86 auto-fixable). That is a separate cleanup, untouched here. `typing_extensions` (for `TypeIs`) is relied on as a guaranteed-transitive dep via jaxtyping; declaring it directly is optional.
- **Non-negotiables held:** typing-only changes (no numerics / control-flow edits); every cluster's tests re-run green; no silent `Any` / ignore escapes.
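The `TypeIs` pattern, in miniature. Unlike `TypeGuard`, a `TypeIs[ELL]` return narrows the argument to `ELL` in the true branch and to the remaining union members in the false branch, so both branches type-check without casts. `ELL` here is a hypothetical two-field stand-in and `degree` an illustrative function; neither is the nitrix container or `graph._is_sparse` itself.

```python
from typing import NamedTuple, Union

import jax.numpy as jnp
from jax import Array

try:  # Python >= 3.13
    from typing import TypeIs
except ImportError:  # older Pythons: typing_extensions (arrives transitively via jaxtyping)
    from typing_extensions import TypeIs


class ELL(NamedTuple):  # hypothetical stand-in for the nitrix ELL container
    values: Array   # (n, k)
    indices: Array  # (n, k)


def _is_sparse(g: Union[Array, ELL]) -> TypeIs[ELL]:
    return isinstance(g, ELL)


def degree(g: Union[Array, ELL]) -> Array:
    if _is_sparse(g):
        return g.values.sum(axis=-1)  # mypy sees g: ELL here
    return g.sum(axis=-1)             # ...and g: Array here (TypeIs narrows both branches)


A = jnp.array([[0.0, 1.0, 1.0], [1.0, 0.0, 0.0], [1.0, 0.0, 0.0]])
ell = ELL(values=jnp.array([[1.0, 1.0], [1.0, 0.0], [1.0, 0.0]]),
          indices=jnp.array([[1, 2], [0, 0], [0, 0]]))
```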
- **Type:** Downstream deviation
- **Triggered by:** ilex/SUGAR feedback (`NITRIX_FEEDBACK_ILEX.md`, 2026-05-21): a GATv2 port built on the ELL-edge surface ran at plausible magnitude but miscomputed because the surface had no self-loop step.
- **Description:** Graph attention attends each vertex to itself (the neighbourhood in Velickovic et al. (2018) explicitly includes node `i`), and the GCN renormalisation trick (Kipf & Welling 2017) adds the self-connection `A_hat = A + I`. The geometric mesh adjacency (`mesh_k_ring_adjacency`) is self-loop-free, so a literature-correct GAT / GCN-renorm conv must add the self-edge before aggregating. Shipped `nitrix.sparse.ell_add_self_loops(ell, edge_attr=None, *, fill='mean'|'add'|'zero', self_value=1.0)`: it appends a per-row `(i, i)` slot (sibling of `ell_pad` / `ell_mask`), and for per-edge attributes it fills the self-edge from the row's valid (non-pad) edges with `'mean'` (the natural default when no intrinsic self-feature exists), `'add'`, or `'zero'`. Also corrected the GATv2 worked example in `semiring/ell_edge.py`, which had omitted the self-loop. Framed throughout via the literature, not parity with any particular GNN library. Pure-JAX, jit-safe, differentiable, additive (mypy-clean under the new gate; 5 new tests in `tests/test_ell.py`).
- **Capability shipped:** literature-correct self-attention / renormalisation on the ELL-edge surface, so a GAT / GCN consumer adds one explicit call instead of re-vendoring the self-loop + masked-mean fill (the part consumers get wrong).
- **Deferred work:** an aggregation-side convenience wrapper bundling add-self-loops + aggregate was deliberately not added. Self-loops are architecture-specific (EdgeConv / DGCNN, MoNet, and plain GCN omit them), so bundling would promote a non-universal default and re-create the silent-default footgun this finding is about. Revisit only on demonstrated demand, or host the GAT composition downstream (a `nimox` `ELLGAT` module).
- **Non-negotiables held:** additive (no change to `semiring_ell_edge_aggregate`'s signature or any existing caller); no framework dependency or concern leakage (no `torch_geometric`, no PyG-named API); docs grounded in the literature.
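The self-loop step with masked-mean fill, sketched on plain arrays. This assumes pads are the slots whose value equals `identity`. The name `add_self_loops` and its `(values, indices, edge_attr)` signature are hypothetical (the shipped function takes an `ELL` container), and only the `'mean'` and `'zero'` fills are shown.

```python
import jax.numpy as jnp


def add_self_loops(values, indices, edge_attr=None, *, fill="mean",
                   identity=0.0, self_value=1.0):
    """Append one (i, i) slot per row; optionally fill its edge attribute (hypothetical).

    values, indices: (n, k); edge_attr: (n, k, f) or None.
    """
    n = values.shape[0]
    self_idx = jnp.arange(n, dtype=indices.dtype)[:, None]
    new_indices = jnp.concatenate([indices, self_idx], axis=1)
    new_values = jnp.concatenate(
        [values, jnp.full((n, 1), self_value, dtype=values.dtype)], axis=1)
    if edge_attr is None:
        return new_values, new_indices, None

    valid = (values != identity)[..., None].astype(edge_attr.dtype)  # (n, k, 1)
    if fill == "mean":
        total = jnp.sum(edge_attr * valid, axis=1, keepdims=True)     # (n, 1, f), pads excluded
        count = jnp.sum(valid, axis=1, keepdims=True)                 # (n, 1, 1)
        self_attr = total / jnp.maximum(count, 1.0)                   # empty rows fill with 0
    elif fill == "zero":
        self_attr = jnp.zeros_like(edge_attr[:, :1, :])
    else:
        raise ValueError(f"unsupported fill: {fill!r}")
    return new_values, new_indices, jnp.concatenate([edge_attr, self_attr], axis=1)


values = jnp.array([[1.0, 1.0, 0.0], [1.0, 0.0, 0.0]])
indices = jnp.array([[1, 2, 0], [0, 0, 0]])
edge_attr = jnp.array([[[2.0], [4.0], [99.0]], [[6.0], [99.0], [99.0]]])  # 99.0 marks pad junk
v, idx, a = add_self_loops(values, indices, edge_attr)
```

The `99.0` entries sit in pad slots, so the masked mean ignores them and each self-edge gets the mean of its row's real edge attributes.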
- **Type:** Robustness / test-quality
- **Triggered by:** BACKLOG B9. The 2026-05-21 deflake surfaced that `linalg.residualise` loses the exact `residual + projection == Y` decomposition for ill-conditioned designs; the property tests were capped to the well-conditioned regime to stay green.
- **Description:** Root-caused: the default `method='cholesky'` (Cholesky of the Gram `X Xᵀ`) returns NaN for rank-deficient `X` (`p > obs` / collinear columns), because the singular Gram has no Cholesky factor, while the already-shipped `method='svd'` path (`jnp.linalg.lstsq`) is exact there. So the fix is verification + documentation, not new numerics. (1) `tests/test_resid.py` now exercises the SVD path across the full `p -> obs` and `p > obs` regime (`test_residual_decomposition_svd_robust`) and pins the cholesky-NaN / svd-finite contract (`test_svd_robust_where_cholesky_degenerates`); `lstsq` was confirmed to vmap over batch dims. (2) The `method` docstring now documents the min-norm least-squares semantics, the unique-projection guarantee (why the decomposition is stable even though the coefficients are not), the cholesky NaN failure mode, and the `lstsq(rcond=None)` cutoff pitfall (prefer `l2 > 0` for deterministic shrinkage of weak directions).
- **Decision:** the default stays `cholesky`, which is ≈2× faster on the common, well-conditioned case (fMRI confound regression is `obs >> p`); `svd` is the documented robust opt-in. Making `svd` the default was considered and rejected on the perf-vs-common-case trade-off (a silent 2× regression for every caller, to fix a regime that already has an explicit escape hatch).
- **Non-negotiables held:** no numerics change to either solve path; default behaviour unchanged; the well-conditioned cholesky property tests stay as the fast-path guard while the new svd tests cover the wide regime.
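A minimal sketch of the two solve paths. It assumes regressors are stored as rows of `X` (shape `(p, obs)`, matching the Gram `X Xᵀ` above) and signals as rows of `Y`. The function names are illustrative, not the `linalg.residualise` signature, and there is no `l2` term. With a duplicated regressor the Gram is singular, so the Cholesky path has no valid factor (the entry above reports NaN in that case). The `lstsq` path still returns a residual orthogonal to every regressor, because the projection is unique even though the coefficients are not.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import cho_solve


def residualise_cholesky(Y, X):
    """Fast path (hypothetical helper): normal equations via the Gram X Xᵀ; needs full row rank."""
    L = jnp.linalg.cholesky(X @ X.T)                   # (p, p), lower triangular
    coef = cho_solve((L, True), X @ Y.T)               # (p, n)
    return Y - coef.T @ X


def residualise_svd(Y, X):
    """Robust path (hypothetical helper): min-norm least squares via SVD-based lstsq."""
    coef, *_ = jnp.linalg.lstsq(X.T, Y.T, rcond=None)  # (p, n); tiny singular values cut off
    return Y - coef.T @ X


kx, ky = jax.random.split(jax.random.PRNGKey(0))
X = jax.random.normal(kx, (3, 50))
X = jnp.concatenate([X, X[:1]], axis=0)                # duplicated regressor: p = 4, rank 3
Y = jax.random.normal(ky, (2, 50))

r_svd = residualise_svd(Y, X)
r_chol = residualise_cholesky(Y, X)
print("cholesky finite:", bool(jnp.isfinite(r_chol).all()))
print("svd max |X r^T|:", float(jnp.abs(X @ r_svd.T).max()))
```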
A few patterns worth knowing if you're picking up partway through:
- The plan trusts the spec. When in doubt about a design question, the SPEC and SPEC_UPDATEs are authoritative. The plan tells you when to build; the spec tells you what to build.
- Phase 1 (Track B) tasks are good first PRs. They are mechanical, small, and unblock downstream consumers immediately. If an agent is new to the codebase, start there.
- Phase 0 fixes must precede the migration that touches them. Don't migrate `covariance.py` before the JIT-trap fix is in.
- `_refstubs/semiring_gemm.py` is design input only. It is in the wrong house style and does not reflect the strict/relaxed Protocol split from SPEC_UPDATE §3.1. Re-implement.
- The deviation log is a contract, not a confessional. It exists so the next agent (in 3 weeks or 3 months) knows what was shipped under pressure and what's still owed. Treat entries as commitments to follow up, not as records of failure.
- CI failures on Pallas / Triton paths are not always your fault. Pallas Triton is best-effort per JAX. If a CI failure correlates with a `jax` version bump and reproduces only on Pallas, file the upstream issue and route the affected kernel to the JAX fallback per SPEC_UPDATE_v0.2 §7.2. The plan does not require you to fix Pallas regressions.
- Calendar dates. Phase durations are approximate; the plan is deviation-tolerant by design. Calendar-binding would defeat that.
- Number of people / agents. The two-track structure scales from one agent (alternating phases) to many (running Phase 1 tasks in parallel). The plan is invariant.
- Code style beyond "the house style." That lives in a separate STYLE.md / CONTRIBUTING.md that the plan references but does not duplicate.
- Hypercoil-side migration shims. Whatever hypercoil ships to redirect imports to nitrix during the transition is hypercoil's concern. The nitrix plan ends at "nitrix exports the symbol."
- Downstream library timelines. thrux, nimox, ilex, entense, bitsjax have their own plans; this plan ends at the contract boundary in SPEC §5.3.