feat: built-in LAPACK / cuSOLVER linalg custom calls #186

Closed
sebffischer wants to merge 2 commits into main from feat-builtin-linalg


@sebffischer
Collaborator

Summary

  • Adds qr, lu, svd, eigh as pjrt-owned FFI handlers, registered under those target names from R/zzz.R. Downstream packages consume them via stablehlo.custom_call and no longer need to ship their own LAPACK linkage.
  • Introduces a small shared FFI kit (ffi_common.h, ffi_lapack.h, ffi_cusolver.{h,cpp}): status macros, dtype dispatch, dim_to_int, the Lapack<T> trait, which absorbs the Windows f32 promotion (Rlapack.dll has no single-precision routines, so f32 inputs go through the f64 routines), and the cuSOLVER Solver prologue (handle pool + devInfo).
  • Each kernel is 80–130 lines on top of the kit. CUDA kernels are always defined (Windows path returns Unimplemented) so the Rcpp::export wrappers resolve cleanly without #ifdefs.
  • LAPACK link added to src/Makevars.in via $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS).
  • New vignette vignettes/articles/custom-calls-lapack-cusolver.Rmd documents the recipe for adding new built-in custom calls.
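
To illustrate the kind of helper the kit's dim_to_int stands for, here is a minimal sketch, assuming its job is a checked narrowing from 64-bit FFI buffer dimensions to the 32-bit int sizes that LAPACK-style interfaces take. The name dim_to_int_sketch and the std::optional return type are hypothetical; the real helper presumably reports failure through the kit's status macros instead.

```cpp
#include <cassert>
#include <cstdint>
#include <limits>
#include <optional>

// Hypothetical stand-in for the kit's dim_to_int (not the PR's actual code).
// Returns std::nullopt when `dim` is negative or does not fit in an int,
// so the caller can turn that into an error status instead of truncating.
inline std::optional<int> dim_to_int_sketch(std::int64_t dim) {
  if (dim < 0 || dim > std::numeric_limits<int>::max()) {
    return std::nullopt;
  }
  return static_cast<int>(dim);
}
```

A handler would call this once per dimension before touching LAPACK and return early when any conversion fails.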

Test plan

  • 207 expectations in tests/testthat/test-linalg.R (square / tall / wide / 1x1 / identity / forced-pivot / ill-conditioned in both f32 and f64; reconstruction, orthogonality, base-R reference; handle-pool reuse).
  • Existing pjrt suite still passes (612 PASS / 0 FAIL).
  • CUDA path: not exercised in CI here — verify with PJRT_PLATFORM=cuda.
  • Windows: not exercised here — verify in CI.

🤖 Generated with Claude Code

Adds qr, lu, svd, eigh as pjrt-owned FFI handlers registered under those
target names from R/zzz.R. Downstream packages (anvl) consume them via
stablehlo.custom_call and no longer need to ship their own LAPACK linkage.

Shared FFI kit:
- ffi_common.h         status macro, dim-to-int, dispatch dtype
- ffi_lapack.h         LAPACK extern decls + Lapack<T> trait that absorbs
                       the Windows f32-promotion (Rlapack.dll has no
                       single-precision routines)
- ffi_cusolver.{h,cpp} GpuLibs (dlopen-loaded function table), DeviceMem
                       RAII, per-stream HandleGuard pool, Solver prologue
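
The f32 promotion absorbed by the Lapack<T> trait can be sketched as follows. This is an illustration of the pattern only, under stated assumptions: LapackSketch, normalize_f64, and normalize_in_place are hypothetical names, and normalize_f64 is a stand-in for a routine that exists only in double precision (it is not a LAPACK call). On platforms that ship single-precision LAPACK, the float specialization would presumably call the single-precision routine directly rather than promote.

```cpp
#include <cassert>
#include <cmath>
#include <vector>

// Stand-in for a routine available only in double precision (hypothetical):
// scales x in place to unit 2-norm.
inline void normalize_f64(int n, double* x) {
  double sum = 0.0;
  for (int i = 0; i < n; ++i) sum += x[i] * x[i];
  const double norm = std::sqrt(sum);
  if (norm == 0.0) return;  // leave an all-zero vector untouched
  for (int i = 0; i < n; ++i) x[i] /= norm;
}

// Primary template is only declared; each supported dtype gets a specialization.
template <typename T>
struct LapackSketch;

template <>
struct LapackSketch<double> {
  static void normalize(int n, double* x) { normalize_f64(n, x); }
};

// Promoted path: copy f32 -> f64, run the double routine, copy the result back.
template <>
struct LapackSketch<float> {
  static void normalize(int n, float* x) {
    std::vector<double> tmp(x, x + n);
    normalize_f64(n, tmp.data());
    for (int i = 0; i < n; ++i) x[i] = static_cast<float>(tmp[i]);
  }
};

// Generic caller: written once, unaware of whether promotion happens.
template <typename T>
void normalize_in_place(std::vector<T>& v) {
  LapackSketch<T>::normalize(static_cast<int>(v.size()), v.data());
}
```

The point of the trait is that kernel code is templated on T once, and the platform-specific detour through f64 stays inside one specialization.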

Per kernel: 80-130 line .cpp using the kit. CUDA kernels are always
defined (Windows path returns Unimplemented) so the Rcpp::export wrappers
resolve cleanly without #ifdefs.

LAPACK link added to src/Makevars.in via $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS).
Property tests in tests/testthat/test-linalg.R: 207 expectations covering
square / tall / wide / 1x1 / identity / forced-pivot / ill-conditioned
inputs in both f32 and f64, base-R reference comparisons, and handle-pool
reuse.

Vignette at vignettes/articles/custom-calls-lapack-cusolver.Rmd documents
the recipe for adding new built-in custom calls.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
sebffischer added a commit to r-xla/anvl that referenced this pull request May 2, 2026
…v_solve

Companion to r-xla/pjrt#186, which moves the QR custom-call infrastructure
out of anvl/src and grows it to cover lu/svd/eigh as pjrt-built-in handlers.
This commit removes anvl's src/ entirely (no more Rcpp dep) and consumes
those handlers through stablehlo.custom_call, plus adds R-side primitives,
generics, autodiff rules, and a few new wrappers.

Primitives (R/primitives.R + R/rules-stablehlo.R):
- nvl_lu / nvl_svd / nvl_eigh, mirroring nvl_qr (existing). Stablehlo rules
  emit hlo_custom_call against pjrt's "qr" / "lu" / "svd" / "eigh" target
  names with column-major operand/result layouts.

User-facing wrappers (R/api.R):
- nv_lu / nv_svd / nv_eigh return named lists.
- nv_det / nv_logdet are built on nvl_lu (determinant = sign of the permutation times the product of the diagonal of U).
- nv_inv via solve(A, I).
- nv_eigen, a base::eigen-shaped wrapper around nv_eigh (descending
  eigenvalues, errors on symmetric=FALSE / only.values=TRUE).
- nv_solve switched from Cholesky to LU: the algorithm is now
  nvl_lu + an in-graph nvl_while permutation + two nvl_triangular_solve.
  Works on any non-singular square matrix (was SPD-only via Cholesky).
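
The determinant and solve recipes above can be sketched outside the graph in plain C++. This is a minimal illustration, not anvl's implementation: the names LuSketch, lu_factor, det_from_lu, and solve_from_lu are hypothetical, the matrix is stored row-major for readability (the custom calls use column-major layouts), and only a single right-hand side is handled.

```cpp
#include <cassert>
#include <cmath>
#include <utility>
#include <vector>

struct LuSketch {
  std::vector<double> lu;   // n x n row-major: unit-lower L below the diagonal, U on and above
  std::vector<int> pivots;  // pivots[k] = row swapped with row k at step k (0-based)
  int n;
};

// LU factorization with partial pivoting (P A = L U).
LuSketch lu_factor(std::vector<double> a, int n) {
  std::vector<int> piv(n);
  for (int k = 0; k < n; ++k) {
    int p = k;
    for (int i = k + 1; i < n; ++i)
      if (std::fabs(a[i * n + k]) > std::fabs(a[p * n + k])) p = i;
    piv[k] = p;
    if (p != k)
      for (int j = 0; j < n; ++j) std::swap(a[k * n + j], a[p * n + j]);
    const double d = a[k * n + k];
    if (d == 0.0) continue;  // column is zero below the diagonal: singular input
    for (int i = k + 1; i < n; ++i) {
      a[i * n + k] /= d;
      for (int j = k + 1; j < n; ++j) a[i * n + j] -= a[i * n + k] * a[k * n + j];
    }
  }
  return {std::move(a), std::move(piv), n};
}

// det(A) = sign(P) * prod(diag(U)); every actual row swap flips the sign.
double det_from_lu(const LuSketch& f) {
  double det = 1.0;
  for (int k = 0; k < f.n; ++k) {
    if (f.pivots[k] != k) det = -det;
    det *= f.lu[k * f.n + k];
  }
  return det;
}

// Solve A x = b: replay the pivot swaps on b, then two triangular solves.
// Assumes A is non-singular (no zero on the diagonal of U).
std::vector<double> solve_from_lu(const LuSketch& f, std::vector<double> b) {
  const int n = f.n;
  for (int k = 0; k < n; ++k)
    if (f.pivots[k] != k) std::swap(b[k], b[f.pivots[k]]);
  for (int i = 0; i < n; ++i)  // forward substitution: L y = P b
    for (int j = 0; j < i; ++j) b[i] -= f.lu[i * n + j] * b[j];
  for (int i = n - 1; i >= 0; --i) {  // back substitution: U x = y
    for (int j = i + 1; j < n; ++j) b[i] -= f.lu[i * n + j] * b[j];
    b[i] /= f.lu[i * n + i];
  }
  return b;
}
```

For the 2x2 matrix with rows (0, 2) and (1, 3), the zero in the top-left forces a pivot, which is exactly the case a Cholesky-based solve could not handle.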

S3 generics on AnvilArray (R/api-generics.R):
- solve() (with solve(A) returning the inverse), qr(), chol(), and
  determinant() now dispatch to the corresponding nv_* implementations.
  Base R's eigen() is not an S3 generic, so use nv_eigen() directly.

Reverse-mode autodiff (R/rules-reverse.R):
- nvl_eigh: Giles (2008) formula, F-matrix construction with the
  eye-trick to avoid divide-by-zero on the diagonal.
- nvl_svd: Townsend (2016), with rectangular corrections for tall and
  wide cases.
- nvl_lu: square only; PyTorch-style L^{-T} B U^{-T} with P^T applied
  by replaying the pivot swaps in reverse.
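
The F-matrix construction for the eigh rule can be sketched as below. This is a hedged illustration that assumes the "eye-trick" means adding the identity to the matrix of eigenvalue differences before taking the reciprocal and subtracting it afterwards, and that it uses the convention F_ij = 1 / (w_j - w_i) off the diagonal; the function name eigh_f_matrix is hypothetical and the actual rule lives in R, not C++.

```cpp
#include <cassert>
#include <cmath>
#include <cstddef>
#include <vector>

// Builds the n x n matrix F (row-major) with
//   F[i][j] = 1 / (w[j] - w[i])  for i != j,
//   F[i][i] = 0.
// Eye-trick: the diagonal denominator becomes 1 instead of 0, so the
// reciprocal is finite, and subtracting the identity restores the zero.
std::vector<double> eigh_f_matrix(const std::vector<double>& w) {
  const std::size_t n = w.size();
  std::vector<double> f(n * n);
  for (std::size_t i = 0; i < n; ++i) {
    for (std::size_t j = 0; j < n; ++j) {
      const double eye = (i == j) ? 1.0 : 0.0;
      f[i * n + j] = 1.0 / (eye + w[j] - w[i]) - eye;
    }
  }
  return f;
}
```

The trick only protects the diagonal: repeated eigenvalues still produce a zero denominator off the diagonal, so the gradient is only well defined for distinct eigenvalues.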

Tests:
- tests/testthat/test-api.R: 99 PASS, covers the new wrappers + generics.
- tests/testthat/test-primitives-reverse.R: 4 finite-difference checks
  against analytical gradients for eigh / svd (square/tall/wide) / lu / nv_solve.
- inst/extra-tests/test-primitives-reverse-torch.R: corresponding torch
  comparisons (CI-only).

DESCRIPTION:
- Drops Rcpp Imports + LinkingTo (no more anvl C++).
- Pins Remotes to r-xla/pjrt@feat-builtin-linalg until that lands.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
Replace per-call cuMemAlloc/cuMemFree (via dlopen'd CUDA driver) with
xla::ffi::ScratchAllocator for all device-side workspace, working copies,
and devInfo. Allocations now come from XLA's BFC pool -- pooled across
calls and visible to XLA's memory accounting -- matching jaxlib's
solver_kernels_ffi.cc pattern.

ffi_cusolver.h/.cpp: drop DeviceMem RAII; drop mem_alloc/mem_free/
stream_sync from GpuLibs (memcpy_dtod and memset_d8 stay). Solver::begin
now takes a ScratchAllocator& and stores devInfo as int*. The
allocate_workspace<T> helper wraps scratch.Allocate and translates a
nullopt result into an Error.
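
The nullopt-to-Error translation can be sketched without the XLA headers. Everything here is a stand-in so the example compiles on its own: MockScratch imitates only the one behaviour relied on (an Allocate call that yields std::nullopt when the pool is exhausted), SketchError replaces the FFI error type, and allocate_workspace_sketch is a hypothetical name whose real signature may differ.

```cpp
#include <cassert>
#include <cstddef>
#include <optional>
#include <string>
#include <variant>
#include <vector>

// Hypothetical stand-in for xla::ffi::ScratchAllocator (not the real class):
// a bump allocator over a fixed pool that returns nullopt when it runs out.
class MockScratch {
 public:
  explicit MockScratch(std::size_t capacity) : pool_(capacity), used_(0) {}

  std::optional<void*> Allocate(std::size_t bytes) {
    const std::size_t align = alignof(std::max_align_t);
    const std::size_t start = (used_ + align - 1) / align * align;
    if (start > pool_.size() || bytes > pool_.size() - start) return std::nullopt;
    used_ = start + bytes;
    return static_cast<void*>(pool_.data() + start);
  }

 private:
  std::vector<unsigned char> pool_;
  std::size_t used_;
};

// Stand-in for the FFI error type.
struct SketchError {
  std::string message;
};

// Typed allocation: returns a T* on success, or an error naming the buffer
// that could not be allocated. Overflow of count * sizeof(T) is not checked.
template <typename T>
std::variant<T*, SketchError> allocate_workspace_sketch(MockScratch& scratch,
                                                        std::size_t count,
                                                        const char* what) {
  std::optional<void*> raw = scratch.Allocate(count * sizeof(T));
  if (!raw.has_value()) {
    return SketchError{std::string("scratch allocation failed: ") + what};
  }
  return static_cast<T*>(*raw);
}
```

Because the allocator owns the memory, there is no per-buffer free and no RAII wrapper left for the kernel to manage, which is what makes dropping DeviceMem possible.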

Each *_cuda.cpp: the handler binding adds .Ctx<ScratchAllocator>() after
.Ctx<PlatformStream<void*>>(); do_*_cuda takes the allocator by value and
passes it by reference into the templated impl. Every DeviceMem is
replaced with a typed pointer from allocate_workspace. QR's geqrf/orgqr
workspace-reuse logic is preserved.

vignettes/articles/custom-calls-lapack-cusolver.Rmd: update execution
model, dlopen rationale, devInfo, status-codes, failure-mode table, kit
description, and the "add a new linalg op" recipe to reflect the
ScratchAllocator pattern.

Co-Authored-By: Claude Opus 4.7 (1M context) <noreply@anthropic.com>
@sebffischer closed this May 8, 2026