feat: built-in LAPACK / cuSOLVER linalg custom calls #186
Closed
sebffischer wants to merge 2 commits into
Adds qr, lu, svd, eigh as pjrt-owned FFI handlers registered under those
target names from R/zzz.R. Downstream packages (anvl) consume them via
stablehlo.custom_call and no longer need to ship their own LAPACK linkage.
Shared FFI kit:
- ffi_common.h: status macro, dim-to-int, dtype dispatch
- ffi_lapack.h: LAPACK extern decls + Lapack<T> trait that absorbs the
  Windows f32-promotion (Rlapack.dll has no single-precision routines)
- ffi_cusolver.{h,cpp}: GpuLibs (dlopen-loaded function table), DeviceMem
  RAII, per-stream HandleGuard pool, Solver prologue
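
As a rough illustration of two pieces of the kit listed above, the checked dimension conversion and the f32-promotion trait could look something like the sketch below. This is not the code in ffi_common.h or ffi_lapack.h. The `scale` operation, `backend_scale_f64`, and the `bool`-plus-out-parameter signature of `dim_to_int` are assumptions made for the example, and the promotion here is unconditional, whereas the PR applies it only where single-precision routines are missing.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <vector>

// Hypothetical stand-in for a backend that only ships double-precision
// routines (the real kit would call into LAPACK): scales n values in place.
inline void backend_scale_f64(int n, double alpha, double* x) {
  for (int i = 0; i < n; ++i) x[i] *= alpha;
}

// Checked narrowing of a 64-bit dimension to the `int` that LAPACK-style
// APIs expect. Returns false for negative or too-large dimensions.
inline bool dim_to_int(std::int64_t d, int* out) {
  assert(out != nullptr);
  if (d < 0 || d > std::numeric_limits<int>::max()) return false;
  *out = static_cast<int>(d);
  return true;
}

// Trait sketch: double goes straight to the backend; float is promoted to a
// temporary double buffer, processed there, and demoted back.
template <typename T>
struct Lapack;

template <>
struct Lapack<double> {
  static void scale(int n, double alpha, double* x) {
    backend_scale_f64(n, alpha, x);
  }
};

template <>
struct Lapack<float> {
  static void scale(int n, float alpha, float* x) {
    std::vector<double> tmp(x, x + n);  // promote f32 -> f64
    backend_scale_f64(n, static_cast<double>(alpha), tmp.data());
    for (int i = 0; i < n; ++i) x[i] = static_cast<float>(tmp[i]);  // demote
  }
};
```

Kernels written against `Lapack<T>` stay dtype-generic, and the promotion cost is paid only in the float specialization.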
Per kernel: 80-130 line .cpp using the kit. CUDA kernels are always
defined (Windows path returns Unimplemented) so the Rcpp::export wrappers
resolve cleanly without #ifdefs.
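
A minimal sketch of the "always defined" pattern, assuming a made-up `Status` type and function name (the real handlers would return an FFI error type): the symbol exists on every platform and only the body is conditional, so callers need no preprocessor guards.

```cpp
#include <cassert>
#include <string>

// Hypothetical minimal status type, used only for this illustration.
enum class Code { kOk, kUnimplemented };
struct Status {
  Code code;
  std::string message;
};

// The function is declared and defined everywhere, so wrappers can always
// link against it. Only the body differs per platform.
inline Status qr_cuda_sketch() {
#if defined(_WIN32)
  return {Code::kUnimplemented, "CUDA kernels are not available on Windows"};
#else
  // A real implementation would run the cuSOLVER path here.
  return {Code::kOk, ""};
#endif
}
```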
LAPACK link added to src/Makevars.in via $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS).
Property tests in tests/testthat/test-linalg.R: 207 expectations covering
square / tall / wide / 1x1 / identity / forced-pivot / ill-conditioned
inputs in both f32 and f64, base-R reference comparisons, and handle-pool
reuse.
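
The handle-pool reuse that the tests exercise can be pictured with a sketch like the one below. It is not the HandleGuard from ffi_cusolver.h: `Stream`, `Handle`, and `HandlePool` are hypothetical stand-ins, and a real pool would hold cuSOLVER handles bound to CUDA streams.

```cpp
#include <cassert>
#include <map>
#include <mutex>
#include <vector>

// Hypothetical stand-ins for a CUDA stream and a solver handle.
using Stream = void*;
struct Handle { int id; };

class HandlePool {
 public:
  // RAII guard: borrows a handle on construction, returns it on destruction.
  class Guard {
   public:
    Guard(HandlePool& pool, Stream s)
        : pool_(pool), stream_(s), h_(pool.acquire(s)) {}
    ~Guard() { pool_.release(stream_, h_); }
    Guard(const Guard&) = delete;
    Guard& operator=(const Guard&) = delete;
    Handle get() const { return h_; }

   private:
    HandlePool& pool_;
    Stream stream_;
    Handle h_;
  };

  // Number of handles created so far (not the number currently borrowed).
  int created() const { return next_id_; }

 private:
  Handle acquire(Stream s) {
    std::lock_guard<std::mutex> lock(mu_);
    std::vector<Handle>& free_list = free_[s];
    if (!free_list.empty()) {  // reuse a handle already bound to this stream
      Handle h = free_list.back();
      free_list.pop_back();
      return h;
    }
    return Handle{next_id_++};  // otherwise create a fresh one
  }
  void release(Stream s, Handle h) {
    std::lock_guard<std::mutex> lock(mu_);
    free_[s].push_back(h);
  }

  std::mutex mu_;
  std::map<Stream, std::vector<Handle>> free_;
  int next_id_ = 0;
};
```

Two sequential calls on the same stream get the same handle back, while a call on a different stream creates a new one. That is the kind of behaviour a reuse test would check.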
Vignette at vignettes/articles/custom-calls-lapack-cusolver.Rmd documents
the recipe for adding new built-in custom calls.
Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
sebffischer added a commit to r-xla/anvl that referenced this pull request on May 2, 2026
…v_solve

Companion to r-xla/pjrt#186, which moves the QR custom-call infrastructure out of anvl/src and grows it to cover lu/svd/eigh as pjrt-built-in handlers. This commit removes anvl's src/ entirely (no more Rcpp dep) and consumes those handlers through stablehlo.custom_call, plus adds R-side primitives, generics, autodiff rules, and a few new wrappers.

Primitives (R/primitives.R + R/rules-stablehlo.R):
- nvl_lu / nvl_svd / nvl_eigh, mirroring nvl_qr (existing). Stablehlo rules emit hlo_custom_call against pjrt's "qr" / "lu" / "svd" / "eigh" target names with column-major operand/result layouts.

User-facing wrappers (R/api.R):
- nv_lu / nv_svd / nv_eigh return named lists.
- nv_det / nv_logdet built on nvl_lu (sign of permutation x prod of diag U).
- nv_inv via solve(A, I).
- nv_eigen, a base::eigen-shaped wrapper around nv_eigh (descending eigenvalues, errors on symmetric=FALSE / only.values=TRUE).
- nv_solve switched from Cholesky to LU: the algorithm is now nvl_lu + an in-graph nvl_while permutation + two nvl_triangular_solve. Works on any non-singular square matrix (was SPD-only via Cholesky).

S3 generics on AnvilArray (R/api-generics.R):
- solve() (with solve(A) returning the inverse), qr(), chol(), determinant() now dispatch to the corresponding nv_* implementations. base R's eigen() is not an S3 generic, so use nv_eigen() directly.

Reverse-mode autodiff (R/rules-reverse.R):
- nvl_eigh: Giles (2008) formula, F-matrix construction with the eye-trick to avoid divide-by-zero on the diagonal.
- nvl_svd: Townsend (2016), with rectangular corrections for tall and wide cases.
- nvl_lu: square only; PyTorch-style L^{-T} B U^{-T} with P^T applied by replaying the pivot swaps in reverse.

Tests:
- tests/testthat/test-api.R: 99 PASS, covers the new wrappers + generics.
- tests/testthat/test-primitives-reverse.R: 4 finite-difference checks against analytical gradients for eigh / svd (sq/tall/wide) / lu / nv_solve.
- inst/extra-tests/test-primitives-reverse-torch.R: corresponding torch comparisons (CI-only).

DESCRIPTION:
- Drops Rcpp Imports + LinkingTo (no more anvl C++).
- Pins Remotes to r-xla/pjrt@feat-builtin-linalg until that lands.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
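
The LU-based determinant and solve described in the commit message above can be sketched host-side as follows. This is a plain C++ illustration of the arithmetic (P*A = L*U, determinant as permutation sign times the product of diag U, solve as a permutation followed by two triangular solves). It is not anvl's graph code, which uses nvl_lu, nvl_while, and nvl_triangular_solve. The names `Lu`, `lu_factor`, `lu_det`, and `lu_solve` are made up for the example.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <utility>
#include <vector>

using Matrix = std::vector<std::vector<double>>;

struct Lu {
  Matrix lu;              // unit-lower L below the diagonal, U on and above it
  std::vector<int> perm;  // row i of P*A is row perm[i] of A
  int sign;               // +1 or -1: parity of the row swaps
};

// LU factorization with partial pivoting, P*A = L*U.
inline Lu lu_factor(Matrix a) {
  const int n = static_cast<int>(a.size());
  Lu f{std::move(a), std::vector<int>(n), 1};
  for (int i = 0; i < n; ++i) f.perm[i] = i;
  for (int k = 0; k < n; ++k) {
    int p = k;  // pick the largest-magnitude pivot in column k
    for (int i = k + 1; i < n; ++i)
      if (std::fabs(f.lu[i][k]) > std::fabs(f.lu[p][k])) p = i;
    if (p != k) {
      std::swap(f.lu[p], f.lu[k]);
      std::swap(f.perm[p], f.perm[k]);
      f.sign = -f.sign;
    }
    if (f.lu[k][k] == 0.0) continue;  // column is all zero: matrix is singular
    for (int i = k + 1; i < n; ++i) {
      f.lu[i][k] /= f.lu[k][k];
      for (int j = k + 1; j < n; ++j) f.lu[i][j] -= f.lu[i][k] * f.lu[k][j];
    }
  }
  return f;
}

// det(A) = sign of the permutation times the product of diag(U).
inline double lu_det(const Lu& f) {
  double d = f.sign;
  for (std::size_t i = 0; i < f.lu.size(); ++i) d *= f.lu[i][i];
  return d;
}

// Solve A x = b for non-singular A: permute b, then L y = P b, then U x = y.
inline std::vector<double> lu_solve(const Lu& f, const std::vector<double>& b) {
  const int n = static_cast<int>(f.lu.size());
  assert(static_cast<int>(b.size()) == n);
  std::vector<double> x(n);
  for (int i = 0; i < n; ++i) x[i] = b[f.perm[i]];  // apply P
  for (int i = 0; i < n; ++i)                       // forward substitution
    for (int j = 0; j < i; ++j) x[i] -= f.lu[i][j] * x[j];
  for (int i = n - 1; i >= 0; --i) {                // back substitution
    for (int j = i + 1; j < n; ++j) x[i] -= f.lu[i][j] * x[j];
    x[i] /= f.lu[i][i];
  }
  return x;
}
```

For A = [[0, 2], [1, 3]] the factorization needs one row swap, so the sign is -1 and the determinant comes out as -1 * 1 * 2 = -2. This is the kind of forced-pivot input that a Cholesky-based solve could not handle.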
Replace per-call cuMemAlloc/cuMemFree (via dlopen'd CUDA driver) with xla::ffi::ScratchAllocator for all device-side workspace, working copies, and devInfo. Allocations now come from XLA's BFC pool -- pooled across calls and visible to XLA's memory accounting -- matching jaxlib's solver_kernels_ffi.cc pattern.

ffi_cusolver.h/.cpp: drop DeviceMem RAII; drop mem_alloc/mem_free/stream_sync from GpuLibs (memcpy_dtod and memset_d8 stay). Solver::begin takes ScratchAllocator& and stores devInfo as int*. allocate_workspace<T> helper now wraps scratch.Allocate and translates nullopt to Error.

Each *_cuda.cpp: handler binding adds .Ctx<ScratchAllocator>() after .Ctx<PlatformStream<void*>>(); do_*_cuda takes the allocator by value and threads it as a reference into the templated impl. Every DeviceMem replaced with a typed pointer from allocate_workspace. QR's geqrf/orgqr workspace-reuse logic is preserved.

vignettes/articles/custom-calls-lapack-cusolver.Rmd: update execution model, dlopen rationale, devInfo, status-codes, failure-mode table, kit description, and the "add a new linalg op" recipe to reflect the ScratchAllocator pattern.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
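
The shape of an allocate_workspace<T>-style helper can be sketched without XLA as below. `MockScratch` is a hypothetical host-side stand-in for the scratch allocator (it only mimics an `Allocate` that returns `std::nullopt` on failure), and the `Error` struct and `std::variant` return type are assumptions for the example, not the PR's actual signature. Requires C++17.

```cpp
#include <cassert>
#include <cstddef>
#include <cstdlib>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical mock of a scratch allocator with a fixed byte budget.
// Memory is owned by the allocator and released when it goes out of scope.
class MockScratch {
 public:
  explicit MockScratch(std::size_t budget) : remaining_(budget) {}
  MockScratch(const MockScratch&) = delete;
  MockScratch& operator=(const MockScratch&) = delete;
  ~MockScratch() {
    for (void* p : blocks_) std::free(p);
  }

  std::optional<void*> Allocate(std::size_t bytes) {
    if (bytes > remaining_) return std::nullopt;
    void* p = std::malloc(bytes == 0 ? 1 : bytes);
    if (p == nullptr) return std::nullopt;
    remaining_ -= bytes;
    blocks_.push_back(p);
    return p;
  }

 private:
  std::size_t remaining_;
  std::vector<void*> blocks_;
};

// Hypothetical error type for the sketch.
struct Error {
  std::string message;
};

// Typed allocation: either a T* into scratch memory or an Error describing
// which buffer could not be allocated.
template <typename T>
std::variant<T*, Error> allocate_workspace(MockScratch& scratch,
                                           std::size_t count,
                                           const std::string& what) {
  std::optional<void*> p = scratch.Allocate(count * sizeof(T));
  if (!p.has_value()) return Error{"failed to allocate " + what};
  return static_cast<T*>(*p);
}
```

Because the allocator owns the memory, the kernel code needs no per-buffer RAII wrapper, which is the simplification the commit describes when it drops DeviceMem.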
Summary
- Adds qr, lu, svd, eigh as pjrt-owned FFI handlers registered under those target names from R/zzz.R. Downstream packages consume them via stablehlo.custom_call and no longer need to ship their own LAPACK linkage.
- Shared FFI kit (ffi_common.h, ffi_lapack.h, ffi_cusolver.{h,cpp}): status macros, dtype dispatch, dim_to_int, the Lapack<T> trait that absorbs the Windows f32-promotion (Rlapack.dll has no single-precision routines), and the cuSOLVER Solver prologue (handle pool + devInfo).
- CUDA kernels are always defined (Windows path returns Unimplemented) so the Rcpp::export wrappers resolve cleanly without #ifdefs.
- LAPACK link added to src/Makevars.in via $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS).
- Vignette at vignettes/articles/custom-calls-lapack-cusolver.Rmd documents the recipe for adding new built-in custom calls.

Test plan
- tests/testthat/test-linalg.R (square / tall / wide / 1x1 / identity / forced-pivot / ill-conditioned in both f32 and f64; reconstruction, orthogonality, base-R reference; handle-pool reuse).
- PJRT_PLATFORM=cuda.

🤖 Generated with Claude Code