Ideas to revisit when a concrete consumer ask, perf regression, or phase boundary surfaces them. Items here are not commitments — they're parked because (a) no consumer is currently blocked, or (b) the cost / payoff ratio favours waiting for more signal. Each entry should name the trigger that would promote it to a sprint.
- One section per item. Lead with a one-sentence summary, then a Trigger line stating the condition under which we'd revisit, then Notes with the design context that would otherwise evaporate.
- When an item graduates to a sprint, move it to IMPLEMENTATION_PLAN.md (or fold into an open phase) and replace the body here with a one-line pointer at the merging commit.
Add bench scripts under bench/ for
semiring_ell_edge_aggregate (GCN forward + bwd, DGCNN forward + bwd),
max_pool_with_indices_nd / max_unpool_nd, and the mesh wrappers at icosphere-realistic shapes (ico_6 = 40962 verts, ico_7 = 163842 verts). Wire results into the op-matrix perf_ratio column.
Trigger. First concrete consumer running these in a training
loop where wall-time matters (Topofit cascade at ico_7, any
SphereMorph descendant, any encoder-decoder UNet variant from
the PGlandsSeg lineage). Or: a Pallas-dispatch experiment
(B3) needing a baseline.
Notes. The op matrix shows all four transformations passing green on these surfaces; what's missing is the wall-time characterisation against a natural reference. Natural references:
- semiring_ell_edge_aggregate vs PyTorch Geometric MessagePassing on the same graph (host-side parity check, JAX timing on GPU)
- max_pool_with_indices_nd vs torch.nn.MaxPool3d(return_indices=True) (the PGlandsSeg reference)
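
A minimal sketch of a timing helper that such bench scripts could share, assuming JAX's asynchronous dispatch is the main measurement pitfall. The names ico_verts and time_fn are hypothetical, not existing helpers under bench/. The formula 10 * 4**n + 2 reproduces the ico_6 and ico_7 vertex counts quoted above.

```python
import statistics
import time

import jax
import jax.numpy as jnp


def ico_verts(level: int) -> int:
    """Vertex count of a level-`level` icosphere: 10 * 4**level + 2."""
    return 10 * 4**level + 2


def time_fn(fn, *args, warmup: int = 2, repeats: int = 10) -> float:
    """Median wall-time in seconds of the jitted `fn(*args)`."""
    jitted = jax.jit(fn)
    # Warm-up calls absorb compilation so it is not counted as run time.
    for _ in range(warmup):
        jax.block_until_ready(jitted(*args))
    samples = []
    for _ in range(repeats):
        start = time.perf_counter()
        # Block on the result; otherwise only the async dispatch is timed.
        jax.block_until_ready(jitted(*args))
        samples.append(time.perf_counter() - start)
    return statistics.median(samples)


# Small illustrative call (ico_3 keeps the example cheap on any backend).
x = jnp.ones((ico_verts(3), 8), dtype=jnp.float32)
elapsed = time_fn(lambda a: a.sum(axis=1), x)
```

Reporting the median rather than the mean keeps a single slow outlier (e.g. a stray recompilation) from dominating the number.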
The current implementation is JAX-only: gather + nested vmap +
semiring axis-reduction. At Topofit-scale (ico_7, d_in = 64-128) the inner edge-MLP is a tiled matmul that should
map to the existing semiring_matmul Pallas backend. The
feedback flagged Pallas-fast as a stretch goal.
Trigger. B2 baseline shows the JAX path is the bottleneck in a real training loop, or a consumer explicitly asks.
Notes. The hard part is that edge_fn is a user-supplied
Python callable, so we can't precompile a kernel — we'd need
either (a) a Pallas template parameterised by a fixed-shape MLP
descriptor (works for DGCNN-shaped edge_fns; not generic), or
(b) the KeOps-style symbolic formula approach in B5. Decide
between these only after B2 shows the JAX path matters.
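For orientation, the "gather + nested vmap + semiring axis-reduction" structure described above might look roughly like the following. This is a sketch under assumptions, not the nitrix implementation: the function name edge_aggregate, the REDUCERS table, and the edge_fn(x_i, x_j) calling convention are all hypothetical, and ELL padding is ignored.

```python
import jax
import jax.numpy as jnp

# Hypothetical mapping from semiring name to its neighbour-axis reduction.
REDUCERS = {
    "REAL": jnp.sum,
    "TROPICAL_MAX_PLUS": jnp.max,
    "TROPICAL_MIN_PLUS": jnp.min,
}


def edge_aggregate(x, nbr_idx, edge_fn, semiring="REAL"):
    """x: (N, d_in) features; nbr_idx: (N, P) ELL neighbour indices.

    edge_fn(x_i, x_j) maps two (d_in,) vectors to a (d_out,) message.
    """
    x_nbr = x[nbr_idx]                               # gather: (N, P, d_in)
    per_slot = jax.vmap(edge_fn, in_axes=(None, 0))  # over the P neighbour slots
    per_node = jax.vmap(per_slot, in_axes=(0, 0))    # over the N nodes
    messages = per_node(x, x_nbr)                    # (N, P, d_out)
    return REDUCERS[semiring](messages, axis=1)      # reduce the neighbour axis


# Tiny example: 3 nodes, 2 neighbours each, difference messages.
x = jnp.arange(6.0).reshape(3, 2)
nbr_idx = jnp.array([[1, 2], [0, 2], [0, 1]])
out_sum = edge_aggregate(x, nbr_idx, lambda xi, xj: xj - xi)
out_max = edge_aggregate(x, nbr_idx, lambda xi, xj: xj - xi, "TROPICAL_MAX_PLUS")
```

Because edge_fn is an opaque Python callable here, the (N, P, d_out) message tensor is materialised before the reduction, which is the cost a fused kernel would aim to avoid.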
The current implementation supports REAL, TROPICAL_MAX_PLUS,
TROPICAL_MIN_PLUS. LOG / EUCLIDEAN raise
NotImplementedError with a specific message; BOOLEAN is
excluded for footgun reasons (most edge_fns produce floats).
Trigger. A consumer presents a concrete use case where one
of these would compose naturally. The most likely is LOG for
attention-style softmax aggregation, but the
ell_row_softmax-then-REAL composition is usually cleaner
than a fused logsumexp at the edge.
Notes. The shape of the per-edge message is user-controlled,
so the algebra's identity-action on the message must be
user-respected (the same contract as the existing semirings).
LOG: logsumexp over the neighbour axis on
log(edge_fn_output)-type messages. EUCLIDEAN:
sqrt(sum_p edge_fn ** 2)-shaped reductions.
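The two candidate reductions, and the softmax-then-REAL alternative mentioned in the trigger, can be sketched over an (N, P, d_out) message tensor as follows. The function names are hypothetical, and jax.nn.softmax stands in for ell_row_softmax, whose actual signature is not assumed here.

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp


def reduce_log(log_messages, axis=1):
    """LOG-style reduction: logsumexp over the neighbour axis."""
    return logsumexp(log_messages, axis=axis)


def reduce_euclidean(messages, axis=1):
    """EUCLIDEAN-style reduction: sqrt of the sum of squared messages."""
    return jnp.sqrt(jnp.sum(messages**2, axis=axis))


def softmax_then_real(scores, messages, axis=1):
    """Softmax the per-edge scores (N, P), then do a plain REAL (sum) reduction."""
    weights = jax.nn.softmax(scores, axis=axis)
    return jnp.sum(weights[..., None] * messages, axis=axis)


log_out = reduce_log(jnp.log(jnp.array([[1.0, 3.0]])))        # log(1 + 3)
euc_out = reduce_euclidean(jnp.array([[3.0, 4.0]]))           # sqrt(9 + 16)
att_out = softmax_then_real(jnp.zeros((1, 2)), jnp.array([[[2.0], [4.0]]]))
```

One caveat for the EUCLIDEAN form: the gradient of sqrt is unbounded at zero, so a row whose messages are all zero would need a guard before this is differentiated.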
KeOps's Genred (Generic Reductions) takes a user-supplied
mathematical formula expressed in a string DSL and JIT-compiles
a CUDA kernel that streams the reduction without materialising
the per-edge tensor. The SPEC's §3.1 "KeOps-style streaming
kernel" claim already invokes this lineage; what's missing is a
concrete research note on whether a Pallas-backed analogue
could host arbitrary user formulas (the edge_aggregate path's
generic case).
Trigger. B3 forces a decision between "Pallas template per edge_fn shape" and "compile a formula on the fly". Or: a follow-up SPEC update revisits the formula-compilation idea independently.
Notes. The research deliverable is a design doc covering:
- KeOps's formula DSL (LazyTensor + reduction primitives).
- What the JAX / Pallas analogue would look like (jaxpr-level analysis of a user-supplied callable? Pallas templating via concrete-shape inputs?).
- The boundary between formulas that can be compiled (closed polynomial / elementwise) and those that can't (data-dependent control flow, dynamic shapes).
- Whether this belongs in nitrix or as a separate package.
Cross-references.
- docs/design/streaming-kernel.md already documents the current O(M·N) peak-HBM claim and the deliberate avoidance of tensor cores.
- docs/design/ell-on-triton.md for the current Triton / Pallas gap.
- KeOps paper: Charlier et al. 2021, "Kernel Operations on the GPU, with Autodiff, without Memory Overflows", JMLR.
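
As a starting point for the "jaxpr-level analysis of a user-supplied callable" question, one possible probe is to trace the edge_fn with jax.make_jaxpr and check its primitives against an allowlist. This is only a sketch: the ELEMENTWISE allowlist and the helper names are assumptions, and only top-level equations are inspected (nested jaxprs, such as those inside a cond, are not walked).

```python
import jax
import jax.numpy as jnp

# Assumed allowlist of elementwise primitives a streaming template could host.
ELEMENTWISE = frozenset(
    {"add", "sub", "mul", "div", "neg", "exp", "log", "tanh", "max", "min"}
)


def primitive_names(edge_fn, x_i, x_j):
    """Names of the top-level primitives in edge_fn's traced jaxpr."""
    closed = jax.make_jaxpr(edge_fn)(x_i, x_j)
    return [eqn.primitive.name for eqn in closed.jaxpr.eqns]


def looks_streamable(edge_fn, x_i, x_j):
    """True if every top-level primitive is in the elementwise allowlist."""
    return all(name in ELEMENTWISE for name in primitive_names(edge_fn, x_i, x_j))


def gaussian_edge(x_i, x_j):
    # Closed elementwise formula: a candidate for on-the-fly compilation.
    d = x_i - x_j
    return jnp.exp(-(d * d))


def branchy_edge(x_i, x_j):
    # Data-dependent control flow: shows up as a `cond` primitive.
    return jax.lax.cond(jnp.sum(x_i) > 0, lambda a, b: a, lambda a, b: b, x_i, x_j)


x_i = jnp.ones((4,), dtype=jnp.float32)
x_j = jnp.zeros((4,), dtype=jnp.float32)
ok_gaussian = looks_streamable(gaussian_edge, x_i, x_j)
ok_branchy = looks_streamable(branchy_edge, x_i, x_j)
```

A probe like this only classifies formulas; it says nothing yet about how an accepted jaxpr would be lowered into a Pallas kernel, which is the harder half of the research question.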
smoothing.gaussian is a separable n-D convolution implemented
via lax.conv_general_dilated (one 1-D pass per axis), which
lowers to cuDNN on Ampere+ -- a strong baseline. A consumer
flagged a Pallas kernel as a (low-priority) eventual want.
Trigger. A consumer hits a wall-time wall on large 3-D Gaussian blur (e.g. repeated 256³ smoothing inside a training loop), and a benchmark shows the separable-conv path is the bottleneck.
Notes. Unlike the ELL / trilinear gather case, the Gaussian is a
stencil, not a gather, so it is Pallas-friendly (the dense
semiring_matmul Pallas kernel already works). But cuDNN is hard
to beat per-pass; the only real Pallas win is fusing the 3
separable axis passes (and the boundary pad) into one kernel to
save the inter-pass HBM round-trips on large volumes. That is
marginal and bandwidth-bound; do not build speculatively. Any kernel
ships behind the existing backend= dispatch with the
conv_general_dilated path as the JAX floor (non-negotiable
§2.2.3) and a golden-corpus parity test. Bench against
lax.conv_general_dilated (not against a naive loop) so the
comparison is honest.
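For reference, the "one 1-D pass per axis" structure that any fused kernel would be benchmarked against can be sketched as below. This is not the smoothing.gaussian source: the function names, the truncate=4.0 kernel radius, and the zero-padded "SAME" boundary are assumptions made for illustration.

```python
import jax.numpy as jnp
from jax import lax


def gaussian_kernel_1d(sigma, truncate=4.0):
    """Normalised 1-D Gaussian taps; radius is an assumed truncate * sigma."""
    radius = int(truncate * sigma + 0.5)
    t = jnp.arange(-radius, radius + 1, dtype=jnp.float32)
    k = jnp.exp(-0.5 * (t / sigma) ** 2)
    return k / k.sum()


def separable_gaussian(vol, sigma):
    """Blur an n-D float32 volume with one 1-D convolution per axis."""
    k = gaussian_kernel_1d(sigma)
    ndim = vol.ndim
    out = vol[None, None]  # add batch and channel dims: (1, 1, *spatial)
    for axis in range(ndim):
        shape = [1] * ndim
        shape[axis] = k.shape[0]
        rhs = k.reshape((1, 1, *shape))  # (O=1, I=1, *kernel_spatial)
        # Each pass reads and writes the whole volume; these are the
        # inter-pass round-trips a fused kernel would try to save.
        out = lax.conv_general_dilated(
            out, rhs, window_strides=(1,) * ndim, padding="SAME"
        )
    return out[0, 0]


# Impulse response: a unit voxel spreads into the separable kernel.
vol = jnp.zeros((9, 9, 9), dtype=jnp.float32).at[4, 4, 4].set(1.0)
blurred = separable_gaussian(vol, sigma=0.5)
```

conv_general_dilated computes a cross-correlation, which coincides with convolution here because the Gaussian taps are symmetric.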
geometry.spatial_transform / integrate_velocity_field resample
via jax.scipy.ndimage.map_coordinates(order=1). A consumer asked
for a Pallas kernel; the "benchmark first" baseline lives in
bench/trilinear_resample.py / bench/PERF_TRILINEAR.md.
Trigger. Both of: (a) the baseline shows the resampling path is a
real bottleneck in a consumer training loop, and (b) a pointer-load
(pl.load / pl.ds) Pallas prototype clears the gather-lowering
risk -- i.e. it avoids the HLO gather primitive that Triton on the
pinned JAX cannot lower (the same blocker as ELL; see
bench/G0_ELL_REPORT.md).
Notes. Trilinear resampling is structurally a gather (8 corner
voxels at data-dependent positions), so it inherits the ELL gather
blocker. The realistic near-term outcome is JAX-default (current
state) until upstream Pallas grows gather, unless the pointer-load
formulation works. A kernel must register a custom_vjp (the
backward is a scatter-add) and keep the map_coordinates path as
the JAX floor. Cross-reference: bench/G0_ELL_REPORT.md for the
gather-lowering precedent.
Baseline (A10G, jax 0.10.0; bench/PERF_TRILINEAR.md). Forward
spatial_transform is fast in absolute terms: 256³ in ~3.1 ms,
192³ in ~1.4 ms; fwd+bwd 256³ in ~13 ms. Cheap interim win, no
Pallas: an explicit pure-JAX 8-corner gather is ~1.5–1.7× faster
than map_coordinates (256³: 1.80 ms vs 3.14 ms), because
map_coordinates carries dispatch / generality overhead this op
doesn't need. If a consumer hits a resampling wall before the Pallas
gate clears, swap _gather_coords_linear to the explicit 8-corner
form first (it's the same gather XLA already lowers, just leaner).
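A sketch of what an explicit 8-corner form could look like, for comparison with map_coordinates(order=1). It is not the existing _gather_coords_linear: the name trilinear_gather is hypothetical, and out-of-range corners are clamped (equivalent to mode="nearest"), which may differ from the boundary handling the real op needs.

```python
import itertools

import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates


def trilinear_gather(vol, coords):
    """vol: (D, H, W); coords: (3, M) voxel coordinates. Returns (M,) samples."""
    upper = jnp.array(vol.shape, dtype=jnp.int32)[:, None] - 1  # (3, 1)
    base = jnp.floor(coords).astype(jnp.int32)                   # (3, M)
    frac = coords - base                                          # (3, M)
    out = jnp.zeros(coords.shape[1], dtype=vol.dtype)
    for corner in itertools.product((0, 1), repeat=3):
        offs = jnp.array(corner, dtype=jnp.int32)[:, None]       # (3, 1)
        idx = jnp.clip(base + offs, 0, upper)                     # clamp to bounds
        # Per-axis weight is frac for the upper corner, 1 - frac for the lower.
        weight = jnp.prod(jnp.where(offs == 1, frac, 1.0 - frac), axis=0)
        out = out + weight * vol[idx[0], idx[1], idx[2]]          # one gather
    return out


# vol[z, y, x] = 9z + 3y + x is linear, so interpolation is exact.
vol = jnp.arange(27, dtype=jnp.float32).reshape(3, 3, 3)
coords = jnp.array([[0.5, 1.0], [1.25, 0.0], [1.75, 2.0]], dtype=jnp.float32)
out = trilinear_gather(vol, coords)
ref = map_coordinates(vol, list(coords), order=1, mode="nearest")
```

Differentiating this with jax.grad yields the scatter-add backward automatically, so the pure-JAX form needs no custom_vjp; only a Pallas kernel would.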
Semiring.identity is the monoid (additive) identity (the
monoid.init value). ELL padding and sparse.ell_mask
(medial-wall / grey-matter masking) actually need the
(*)-annihilator -- the z with z (*) b = monoid_identity
for all b -- so a masked edge is a no-op in
(+)_p values[i,p] (*) B[...]. For every built-in except
EUCLIDEAN the two coincide (REAL 0, LOG/TROPICAL_MAX -inf,
TROPICAL_MIN +inf, BOOLEAN False), which is why identity
currently doubles as the masking value. EUCLIDEAN's (a-b)**2
has no annihilator yet identity == 0.0, so masking by that
value silently injects B[idx]**2 instead of vanishing.
Consideration. Add an explicit annihilator: Optional[...]
field to Semiring (None for EUCLIDEAN). Then ell_mask
takes a semiring= and pulls semiring.annihilator, raising a
clear error when it is None, instead of overloading identity
and relying on the caller to know the distinction. Overloading
identity for two concepts (monoid neutral vs multiplicative
absorber) is a confusion risk as the algebra set grows.
Trigger. A second masking consumer, a user-defined semiring
whose monoid identity and annihilator differ, or any confusion
report. Until then ell_mask(ell, valid, *, identity=...) takes
an explicit value and the docstring + tests/test_ell_masking_semirings.py
document the distinction (incl. the EUCLIDEAN exception).
Effort. S. One field + a guarded ell_mask(semiring=...)
overload; backward-compatible with the explicit-identity form.
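The proposal can be illustrated with a stand-in dataclass. This is a sketch only: the Semiring class below is a simplified stand-in for the real one, and masking_value is a hypothetical helper showing the guarded lookup that ell_mask(semiring=...) would perform.

```python
import math
from dataclasses import dataclass
from typing import Optional


@dataclass(frozen=True)
class Semiring:
    """Simplified stand-in for the real Semiring (illustration only)."""

    name: str
    identity: float                      # monoid (additive) identity
    annihilator: Optional[float] = None  # z with z (*) b == identity for all b


REAL = Semiring("REAL", identity=0.0, annihilator=0.0)
TROPICAL_MAX = Semiring("TROPICAL_MAX", identity=-math.inf, annihilator=-math.inf)
EUCLIDEAN = Semiring("EUCLIDEAN", identity=0.0)  # (a - b)**2 has no annihilator


def masking_value(semiring: Semiring) -> float:
    """Value to write into masked ELL slots, or a clear error if none exists."""
    if semiring.annihilator is None:
        raise ValueError(
            f"{semiring.name} has no (*)-annihilator; masking by identity "
            "would not make the edge a no-op"
        )
    return semiring.annihilator


b = 3.0
real_masked = masking_value(REAL) * b               # REAL (*) is multiplication
tropical_masked = masking_value(TROPICAL_MAX) + b   # max-plus (*) is addition
euclidean_leak = (EUCLIDEAN.identity - b) ** 2      # injects b**2, not identity
```

For REAL and TROPICAL_MAX the masked edge evaluates to the monoid identity and drops out of the reduction; for EUCLIDEAN the identity value leaks b**2, which is the silent failure the explicit field is meant to surface.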
linalg.residualise loses the exact residual + projection == Y
decomposition (at atol=1e-5, float32) when the design matrix X
is ill-conditioned -- i.e. as the number of regressors p approaches
the observation count obs and the Gram X Xᵀ becomes
near-singular. This is a long-standing, documented limitation (it
predates the functional -> linalg migration; both
implementations are numerically identical and both fail it), recorded
in the original tests/test_resid.py generator comment ("too good at
coming up with adversarial examples ... will likely need a strong
background in error analysis and numerical linear algebra").
It was previously masked because the hypothesis property tests rarely
reached the ill-conditioned regime under the default deadline / health
checks; the 2026-05-21 conftest deflake (deadline=None) made
hypothesis explore fully and surface it deterministically. The tests
now constrain generate_valid_arrays(well_conditioned=True) to
p <= obs/2 (the well-conditioned regime where the property holds),
so the suite is honest and green, and this entry tracks the real fix.
Trigger. A consumer needs residualisation of near-rank-deficient
designs (p close to obs), or we invest in the numerics.
Notes / candidate fix. Replace the Cholesky-of-Gram solve with an
SVD- / QR-based projector (Q Qᵀ from a thin QR of Xᵀ, or an
SVD with a singular-value floor), which is stable for rank-deficient /
ill-conditioned X. residualise already exposes method=;
add method='svd' (or 'qr') as the robust path and consider
making it the default. Pair the fix with re-widening the
well_conditioned cap (or removing it) in test_resid so the
exact-decomposition properties once again exercise p -> obs.
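A minimal sketch of the SVD variant of the candidate fix, assuming X has shape (p, obs) and Y has shape (n, obs), as the X Xᵀ Gram above suggests. The name residualise_svd and the rtol=1e-4 singular-value floor are assumptions, not the linalg.residualise API.

```python
import jax.numpy as jnp


def residualise_svd(Y, X, rtol=1e-4):
    """Project Y onto the row space of X via an SVD with a singular-value floor.

    Y: (n, obs) data rows; X: (p, obs) regressors.
    Returns (residual, projection) with residual + projection == Y.
    """
    # Left singular vectors of X^T form an orthonormal basis for the regressors.
    U, s, _ = jnp.linalg.svd(X.T, full_matrices=False)  # U: (obs, p)
    keep = s > rtol * s[0]       # s is sorted descending; drop tiny directions
    Q = U * keep                 # zero the columns below the floor
    projection = (Y @ Q) @ Q.T
    return Y - projection, projection


# Rank-deficient design: the second regressor is twice the first.
X = jnp.array([[1.0, 1.0, 1.0, 1.0], [2.0, 2.0, 2.0, 2.0]])
Y = jnp.array([[1.0, 2.0, 3.0, 4.0]])
residual, projection = residualise_svd(Y, X)
```

No Gram matrix is formed or inverted, so the conditioning of X Xᵀ never enters the computation. A thin-QR projector would have the same shape for full-rank X, but without pivoting it does not discard dependent columns, so exact rank deficiency is better served by the floor shown here.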
- B1. Move resolved findings in NITRIX_FEEDBACK_ILEX.md — done 2026-05-20. The gaussian-docstring, edge-aggregate/mesh-conv-stack, max_pool_with_indices, and SUGAR-delta findings are now in that doc's "Resolved findings" section; the open entries carry RESOLVED markers. (Commit refs land on merge.)