From 4024107646140be9e9bbc1f05102da3007f7546f Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 2 May 2026 07:02:47 +0000 Subject: [PATCH 1/2] feat: built-in LAPACK / cuSOLVER linalg custom calls Adds qr, lu, svd, eigh as pjrt-owned FFI handlers registered under those target names from R/zzz.R. Downstream packages (anvl) consume them via stablehlo.custom_call and no longer need to ship their own LAPACK linkage. Shared FFI kit: - ffi_common.h status macro, dim-to-int, dispatch dtype - ffi_lapack.h LAPACK extern decls + Lapack trait that absorbs the Windows f32-promotion (Rlapack.dll has no single-precision routines) - ffi_cusolver.{h,cpp} GpuLibs (dlopen-loaded function table), DeviceMem RAII, per-stream HandleGuard pool, Solver prologue Per kernel: 80-130 line .cpp using the kit. CUDA kernels are always defined (Windows path returns Unimplemented) so the Rcpp::export wrappers resolve cleanly without #ifdefs. LAPACK link added to src/Makevars.in via $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS). Property tests in tests/testthat/test-linalg.R: 207 expectations covering square / tall / wide / 1x1 / identity / forced-pivot / ill-conditioned inputs in both f32 and f64, base-R reference comparisons, and handle-pool reuse. Vignette at vignettes/articles/custom-calls-lapack-cusolver.Rmd documents the recipe for adding new built-in custom calls. Co-Authored-By: Claude Opus 4.7 (1M context) --- R/RcppExports.R | 32 + R/zzz.R | 14 + src/Makevars.in | 2 +- src/RcppExports.cpp | 88 +++ src/eigh.cpp | 92 +++ src/eigh_cuda.cpp | 97 +++ src/ffi_common.h | 53 ++ src/ffi_cusolver.cpp | 183 ++++++ src/ffi_cusolver.h | 195 ++++++ src/ffi_lapack.h | 172 +++++ src/lu.cpp | 70 +++ src/lu_cuda.cpp | 95 +++ src/qr.cpp | 100 +++ src/qr_cuda.cpp | 147 +++++ src/svd.cpp | 93 +++ src/svd_cuda.cpp | 113 ++++ tests/testthat/helper-linalg.R | 128 ++++ tests/testthat/test-linalg.R | 385 ++++++++++++ .../articles/custom-calls-lapack-cusolver.Rmd | 589 ++++++++++++++++++ 19 files changed, 2647 insertions(+), 1 deletion(-) create mode 100644 src/eigh.cpp create mode 100644 src/eigh_cuda.cpp create mode 100644 src/ffi_common.h create mode 100644 src/ffi_cusolver.cpp create mode 100644 src/ffi_cusolver.h create mode 100644 src/ffi_lapack.h create mode 100644 src/lu.cpp create mode 100644 src/lu_cuda.cpp create mode 100644 src/qr.cpp create mode 100644 src/qr_cuda.cpp create mode 100644 src/svd.cpp create mode 100644 src/svd_cuda.cpp create mode 100644 tests/testthat/helper-linalg.R create mode 100644 tests/testthat/test-linalg.R create mode 100644 vignettes/articles/custom-calls-lapack-cusolver.Rmd diff --git a/R/RcppExports.R b/R/RcppExports.R index 53756690..136035e8 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -1,6 +1,14 @@ # Generated by using Rcpp::compileAttributes() -> do not edit by hand # Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393 +get_eigh_handler <- function() { + .Call(`_pjrt_get_eigh_handler`) +} + +get_eigh_handler_cuda <- function() { + .Call(`_pjrt_get_eigh_handler_cuda`) +} + impl_register_custom_call <- function(plugin, target_name, handler_ptr, platform_name) { invisible(.Call(`_pjrt_impl_register_custom_call`, plugin, target_name, handler_ptr, platform_name)) } @@ -21,6 +29,14 @@ format_raw_buffer_cpp <- function(data, dtype, shape) { .Call(`_pjrt_format_raw_buffer_cpp`, data, dtype, shape) } +get_lu_handler <- function() { + .Call(`_pjrt_get_lu_handler`) +} + +get_lu_handler_cuda <- function() { + .Call(`_pjrt_get_lu_handler_cuda`) +} + impl_plugin_load <- function(path) { .Call(`_pjrt_impl_plugin_load`, path) } @@ -177,3 +193,19 @@ impl_client_buffer_from_double <- function(client, device, data, dims, dtype) { .Call(`_pjrt_impl_client_buffer_from_double`, client, device, data, dims, dtype) } +get_qr_handler <- function() { + .Call(`_pjrt_get_qr_handler`) +} + +get_qr_handler_cuda <- function() { + .Call(`_pjrt_get_qr_handler_cuda`) +} + +get_svd_handler <- function() { + .Call(`_pjrt_get_svd_handler`) +} + +get_svd_handler_cuda <- function() { + .Call(`_pjrt_get_svd_handler_cuda`) +} + diff --git a/R/zzz.R b/R/zzz.R index 484803cd..f55ea98e 100644 --- a/R/zzz.R +++ b/R/zzz.R @@ -52,6 +52,20 @@ register_namespace_callback <- function(pkgname, namespace, callback) { .package = pkgname ) + # Register the built-in LAPACK / cuSOLVER linear-algebra handlers. These + # are pjrt-owned custom calls that any downstream package (anvl, future + # bindings) can invoke via `stablehlo.custom_call @(...)` without + # having to ship its own LAPACK linkage. + register_linalg_handler <- function(target, host, cuda) { + handlers <- list(host = host) + if (!is.null(cuda)) handlers$cuda <- cuda + pjrt_register_custom_call(target, handlers, .package = pkgname) + } + register_linalg_handler("qr", get_qr_handler(), get_qr_handler_cuda()) + register_linalg_handler("lu", get_lu_handler(), get_lu_handler_cuda()) + register_linalg_handler("svd", get_svd_handler(), get_svd_handler_cuda()) + register_linalg_handler("eigh", get_eigh_handler(), get_eigh_handler_cuda()) + register_namespace_callback(pkgname, "safetensors", function(...) { frameworks <- utils::getFromNamespace( "safetensors_frameworks", diff --git a/src/Makevars.in b/src/Makevars.in index 61749926..218e4dfe 100644 --- a/src/Makevars.in +++ b/src/Makevars.in @@ -2,7 +2,7 @@ CXX_STD=CXX20 PKG_CPPFLAGS=@cflags@ -Iproto -I$(R_PACKAGE_DIR)/include -I../inst/include PKG_CXXFLAGS=$(C_VISIBILITY) -PKG_LIBS=@libs@ +PKG_LIBS=@libs@ $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS) CPP_SOURCES=@cppsrc@ CC_SOURCES=@pbsrc@ diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index 60824e78..327b5ba1 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -11,6 +11,26 @@ Rcpp::Rostream& Rcpp::Rcout = Rcpp::Rcpp_cout_get(); Rcpp::Rostream& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get(); #endif +// get_eigh_handler +SEXP get_eigh_handler(); +RcppExport SEXP _pjrt_get_eigh_handler() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(get_eigh_handler()); + return rcpp_result_gen; +END_RCPP +} +// get_eigh_handler_cuda +SEXP get_eigh_handler_cuda(); +RcppExport SEXP _pjrt_get_eigh_handler_cuda() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(get_eigh_handler_cuda()); + return rcpp_result_gen; +END_RCPP +} // impl_register_custom_call void impl_register_custom_call(Rcpp::XPtr plugin, const std::string& target_name, SEXP handler_ptr, const std::string& platform_name); RcppExport SEXP _pjrt_impl_register_custom_call(SEXP pluginSEXP, SEXP target_nameSEXP, SEXP handler_ptrSEXP, SEXP platform_nameSEXP) { @@ -69,6 +89,26 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// get_lu_handler +SEXP get_lu_handler(); +RcppExport SEXP _pjrt_get_lu_handler() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(get_lu_handler()); + return rcpp_result_gen; +END_RCPP +} +// get_lu_handler_cuda +SEXP get_lu_handler_cuda(); +RcppExport SEXP _pjrt_get_lu_handler_cuda() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(get_lu_handler_cuda()); + return rcpp_result_gen; +END_RCPP +} // impl_plugin_load Rcpp::XPtr impl_plugin_load(const std::string& path); RcppExport SEXP _pjrt_impl_plugin_load(SEXP pathSEXP) { @@ -531,13 +571,57 @@ BEGIN_RCPP return rcpp_result_gen; END_RCPP } +// get_qr_handler +SEXP get_qr_handler(); +RcppExport SEXP _pjrt_get_qr_handler() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(get_qr_handler()); + return rcpp_result_gen; +END_RCPP +} +// get_qr_handler_cuda +SEXP get_qr_handler_cuda(); +RcppExport SEXP _pjrt_get_qr_handler_cuda() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(get_qr_handler_cuda()); + return rcpp_result_gen; +END_RCPP +} +// get_svd_handler +SEXP get_svd_handler(); +RcppExport SEXP _pjrt_get_svd_handler() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(get_svd_handler()); + return rcpp_result_gen; +END_RCPP +} +// get_svd_handler_cuda +SEXP get_svd_handler_cuda(); +RcppExport SEXP _pjrt_get_svd_handler_cuda() { +BEGIN_RCPP + Rcpp::RObject rcpp_result_gen; + Rcpp::RNGScope rcpp_rngScope_gen; + rcpp_result_gen = Rcpp::wrap(get_svd_handler_cuda()); + return rcpp_result_gen; +END_RCPP +} static const R_CallMethodDef CallEntries[] = { + {"_pjrt_get_eigh_handler", (DL_FUNC) &_pjrt_get_eigh_handler, 0}, + {"_pjrt_get_eigh_handler_cuda", (DL_FUNC) &_pjrt_get_eigh_handler_cuda, 0}, {"_pjrt_impl_register_custom_call", (DL_FUNC) &_pjrt_impl_register_custom_call, 4}, {"_pjrt_get_print_handler", (DL_FUNC) &_pjrt_get_print_handler, 0}, {"_pjrt_get_print_handler_cuda", (DL_FUNC) &_pjrt_get_print_handler_cuda, 0}, {"_pjrt_test_get_extension", (DL_FUNC) &_pjrt_test_get_extension, 2}, {"_pjrt_format_raw_buffer_cpp", (DL_FUNC) &_pjrt_format_raw_buffer_cpp, 3}, + {"_pjrt_get_lu_handler", (DL_FUNC) &_pjrt_get_lu_handler, 0}, + {"_pjrt_get_lu_handler_cuda", (DL_FUNC) &_pjrt_get_lu_handler_cuda, 0}, {"_pjrt_impl_plugin_load", (DL_FUNC) &_pjrt_impl_plugin_load, 1}, {"_pjrt_impl_plugin_client_create", (DL_FUNC) &_pjrt_impl_plugin_client_create, 2}, {"_pjrt_impl_program_load", (DL_FUNC) &_pjrt_impl_program_load, 2}, @@ -577,6 +661,10 @@ static const R_CallMethodDef CallEntries[] = { {"_pjrt_impl_client_buffer_from_integer", (DL_FUNC) &_pjrt_impl_client_buffer_from_integer, 5}, {"_pjrt_impl_client_buffer_from_logical", (DL_FUNC) &_pjrt_impl_client_buffer_from_logical, 5}, {"_pjrt_impl_client_buffer_from_double", (DL_FUNC) &_pjrt_impl_client_buffer_from_double, 5}, + {"_pjrt_get_qr_handler", (DL_FUNC) &_pjrt_get_qr_handler, 0}, + {"_pjrt_get_qr_handler_cuda", (DL_FUNC) &_pjrt_get_qr_handler_cuda, 0}, + {"_pjrt_get_svd_handler", (DL_FUNC) &_pjrt_get_svd_handler, 0}, + {"_pjrt_get_svd_handler_cuda", (DL_FUNC) &_pjrt_get_svd_handler_cuda, 0}, {NULL, NULL, 0} }; diff --git a/src/eigh.cpp b/src/eigh.cpp new file mode 100644 index 00000000..9fd0de45 --- /dev/null +++ b/src/eigh.cpp @@ -0,0 +1,92 @@ +// Symmetric eigendecomposition via LAPACK syevd. +// +// Input is interpreted as symmetric using only its lower triangle (we pass +// uplo = 'L'). The factorisation produces: +// V : (n, n) eigenvectors as columns (orthonormal) +// W : (n,) eigenvalues in ascending order +// such that A = V diag(W) V^T (using the lower triangle of A). +// +// jobz = 'V' (compute eigenvectors). For values-only the user can drop V. +// +// syevd needs *two* workspaces (real `work` of size lwork, integer `iwork` of +// size liwork) -- both queried with the lwork = -1 idiom. +#include + +#include "ffi_common.h" +#include "ffi_lapack.h" + +#include +#include + +using namespace xla::ffi; + +namespace rpjrt { + +template +static Error eigh_impl(AnyBuffer input, Result v_out, + Result w_out) { + using S = typename Lapack::S; + + auto dims = input.dimensions(); + int m, n; + PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m)); + PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n)); + if (m != n) + return Error::InvalidArgument("eigh requires a square matrix"); + + // syevd overwrites A with eigenvectors. Copy input -> V output buffer and + // factorise in place there. + std::vector a(static_cast(n) * n); + const T *in = static_cast(input.untyped_data()); + for (std::size_t i = 0; i < a.size(); i++) + a[i] = static_cast(in[i]); + + std::vector w(n); + + const char jobz = 'V'; + const char uplo = 'L'; + int info; + + int lwork = -1, liwork = -1; + S work_size; + int iwork_size; + Lapack::syevd(&jobz, &uplo, &n, a.data(), &n, w.data(), &work_size, &lwork, + &iwork_size, &liwork, &info); + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "syevd workspace query")); + + lwork = static_cast(work_size); + liwork = iwork_size; + std::vector work(lwork); + std::vector iwork(liwork); + Lapack::syevd(&jobz, &uplo, &n, a.data(), &n, w.data(), work.data(), + &lwork, iwork.data(), &liwork, &info); + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "syevd")); + + T *v_data = static_cast((*v_out).untyped_data()); + for (std::size_t i = 0; i < a.size(); i++) + v_data[i] = static_cast(a[i]); + + T *w_data = static_cast((*w_out).untyped_data()); + for (int i = 0; i < n; i++) + w_data[i] = static_cast(w[i]); + + return Error::Success(); +} + +static Error do_eigh(AnyBuffer input, Result v_out, + Result w_out) { + PJRT_DISPATCH_FLOAT(input.element_type(), eigh_impl, input, v_out, w_out); +} + +XLA_FFI_DEFINE_HANDLER(eigh_handler, do_eigh, + Ffi::Bind() + .Arg() // symmetric matrix + .Ret() // eigenvectors (n, n) + .Ret()); // eigenvalues (n,) + +} // namespace rpjrt + +// [[Rcpp::export]] +SEXP get_eigh_handler() { + return R_MakeExternalPtr((void *)rpjrt::eigh_handler, R_NilValue, R_NilValue); +} diff --git a/src/eigh_cuda.cpp b/src/eigh_cuda.cpp new file mode 100644 index 00000000..ef696889 --- /dev/null +++ b/src/eigh_cuda.cpp @@ -0,0 +1,97 @@ +// CUDA symmetric eigendecomposition via cuSOLVER syevd. +// +// jobz / uplo are passed as int enum values: +// jobz = CUSOLVER_EIG_MODE_VECTOR (1) +// uplo = CUBLAS_FILL_MODE_LOWER (0) +// matching the LAPACK 'V' / 'L' choice in src/eigh.cpp. +#include + +#include "ffi_common.h" + +#ifndef _WIN32 +#include "ffi_cusolver.h" + +#include +#endif + +using namespace xla::ffi; + +namespace rpjrt { + +#ifndef _WIN32 +template +static Error eigh_cuda_impl(void *stream, AnyBuffer input, + Result v_out, Result w_out) { + Solver solver(get_gpu_libs()); + PJRT_RETURN_IF_ERROR(solver.begin(stream)); + auto &g = solver.g; + + auto dims = input.dimensions(); + int m, n; + PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m)); + PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n)); + if (m != n) + return Error::InvalidArgument("eigh requires a square matrix"); + + auto input_ptr = reinterpret_cast(input.untyped_data()); + auto v_ptr = reinterpret_cast((*v_out).untyped_data()); + auto w_ptr = reinterpret_cast((*w_out).untyped_data()); + + std::size_t a_bytes = static_cast(n) * n * sizeof(T); + + // syevd overwrites its A argument with eigenvectors -- copy input into V + // first, factorise in place there. + PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(v_ptr, input_ptr, a_bytes, stream), + "cuMemcpyDtoDAsync (input -> V)"); + + const int jobz = 1; // CUSOLVER_EIG_MODE_VECTOR + const int uplo = 0; // CUBLAS_FILL_MODE_LOWER + + int lwork = 0; + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::syevd_bs(g)(solver.handle.get(), jobz, uplo, n, + reinterpret_cast(v_ptr), n, + reinterpret_cast(w_ptr), &lwork), + "cusolverDn?syevd_bufferSize"); + + DeviceMem d_work(g); + PJRT_RETURN_IF_ERROR( + allocate_workspace(lwork, "cuMemAlloc (syevd workspace)", d_work)); + + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::syevd(g)(solver.handle.get(), jobz, uplo, n, + reinterpret_cast(v_ptr), n, + reinterpret_cast(w_ptr), + reinterpret_cast(d_work.ptr), lwork, + reinterpret_cast(solver.info.ptr)), + "cusolverDn?syevd"); + + return Error::Success(); +} +#endif // _WIN32 + +static Error do_eigh_cuda(void *stream, AnyBuffer input, + Result v_out, Result w_out) { +#ifdef _WIN32 + return Error(ErrorCode::kUnimplemented, + "CUDA eigh is not supported on Windows"); +#else + PJRT_DISPATCH_FLOAT(input.element_type(), eigh_cuda_impl, stream, input, + v_out, w_out); +#endif +} + +XLA_FFI_DEFINE_HANDLER(eigh_handler_cuda, do_eigh_cuda, + Ffi::Bind() + .Ctx>() + .Arg() + .Ret() + .Ret()); + +} // namespace rpjrt + +// [[Rcpp::export]] +SEXP get_eigh_handler_cuda() { + return R_MakeExternalPtr((void *)rpjrt::eigh_handler_cuda, R_NilValue, + R_NilValue); +} diff --git a/src/ffi_common.h b/src/ffi_common.h new file mode 100644 index 00000000..aef40c9d --- /dev/null +++ b/src/ffi_common.h @@ -0,0 +1,53 @@ +// Shared utilities for XLA FFI handlers backed by LAPACK / cuSOLVER. +// +// The patterns here mirror the ones in jaxlib (jaxlib/cpu/lapack_kernels.cc and +// jaxlib/gpu/solver_kernels_ffi.cc): an Error-returning macro for +// status-returning calls, a bounds-checked int64 -> int cast, and a +// dispatch-by-element-type macro that reduces an op's f32/f64 boilerplate to a +// one-line switch. +#pragma once + +#include "xla/ffi/api/ffi.h" + +#include +#include +#include + +#define PJRT_RETURN_IF_ERROR(expr) \ + do { \ + auto _e = (expr); \ + if (!_e.success()) \ + return _e; \ + } while (0) + +namespace rpjrt { + +// Validate an int64_t dimension fits in int (LAPACK / cuSOLVER both use int). +// Mirrors jaxlib's MaybeCastNoOverflow(). +inline xla::ffi::Error dim_to_int(std::int64_t v, const char *name, int &out) { + if (v < 0 || v > std::numeric_limits::max()) { + return xla::ffi::Error::InvalidArgument(std::string(name) + + " dimension out of int range"); + } + out = static_cast(v); + return xla::ffi::Error::Success(); +} + +} // namespace rpjrt + +// Dispatch on a buffer's element_type for f32/f64 only (the only float +// precisions our LAPACK / cuSOLVER paths support). Use as: +// PJRT_DISPATCH_FLOAT(input.element_type(), op_impl, args...) +// where op_impl is a template returning xla::ffi::Error. +#define PJRT_DISPATCH_FLOAT(et, IMPL, ...) \ + do { \ + switch (et) { \ + case xla::ffi::DataType::F32: \ + return IMPL(__VA_ARGS__); \ + case xla::ffi::DataType::F64: \ + return IMPL(__VA_ARGS__); \ + default: \ + return xla::ffi::Error::InvalidArgument( \ + "operation only supports f32 and f64"); \ + } \ + } while (0) diff --git a/src/ffi_cusolver.cpp b/src/ffi_cusolver.cpp new file mode 100644 index 00000000..f5d5f8b6 --- /dev/null +++ b/src/ffi_cusolver.cpp @@ -0,0 +1,183 @@ +// Implementation of the shared cuSOLVER infrastructure: dlopen-based loader, +// device-memory RAII, and a per-stream handle pool. All cuSOLVER-backed +// kernels (qr, lu, svd, eigh) use these singletons so they share one set of +// loaded function pointers and one handle pool. +#include "ffi_cusolver.h" + +#ifndef _WIN32 + +#include + +#include +#include +#include + +using namespace xla::ffi; + +namespace rpjrt { + +template static T load_sym(void *lib, const char *name) { + return reinterpret_cast(dlsym(lib, name)); +} + +GpuLibs &get_gpu_libs() { + static GpuLibs g; + if (g.loaded) + return g; + + // Probe a couple of candidates: SDK installs ship the unversioned symlink, + // runtime-only installs (typical in containers) only ship the SONAME. + void *cusolver = dlopen("libcusolver.so", RTLD_LAZY); + if (!cusolver) + cusolver = dlopen("libcusolver.so.11", RTLD_LAZY); + if (!cusolver) + return g; + + void *cuda = dlopen("libcuda.so.1", RTLD_LAZY); + if (!cuda) + return g; + + g.dn_create = load_sym(cusolver, "cusolverDnCreate"); + g.dn_destroy = + load_sym(cusolver, "cusolverDnDestroy"); + g.dn_set_stream = + load_sym(cusolver, "cusolverDnSetStream"); + + g.s_geqrf_bs = + load_sym(cusolver, "cusolverDnSgeqrf_bufferSize"); + g.s_geqrf = load_sym(cusolver, "cusolverDnSgeqrf"); + g.d_geqrf_bs = + load_sym(cusolver, "cusolverDnDgeqrf_bufferSize"); + g.d_geqrf = load_sym(cusolver, "cusolverDnDgeqrf"); + g.s_orgqr_bs = + load_sym(cusolver, "cusolverDnSorgqr_bufferSize"); + g.s_orgqr = load_sym(cusolver, "cusolverDnSorgqr"); + g.d_orgqr_bs = + load_sym(cusolver, "cusolverDnDorgqr_bufferSize"); + g.d_orgqr = load_sym(cusolver, "cusolverDnDorgqr"); + + g.s_getrf_bs = + load_sym(cusolver, "cusolverDnSgetrf_bufferSize"); + g.s_getrf = load_sym(cusolver, "cusolverDnSgetrf"); + g.d_getrf_bs = + load_sym(cusolver, "cusolverDnDgetrf_bufferSize"); + g.d_getrf = load_sym(cusolver, "cusolverDnDgetrf"); + + g.s_gesvd_bs = + load_sym(cusolver, "cusolverDnSgesvd_bufferSize"); + g.d_gesvd_bs = + load_sym(cusolver, "cusolverDnDgesvd_bufferSize"); + g.s_gesvd = load_sym(cusolver, "cusolverDnSgesvd"); + g.d_gesvd = load_sym(cusolver, "cusolverDnDgesvd"); + + g.s_syevd_bs = + load_sym(cusolver, "cusolverDnSsyevd_bufferSize"); + g.d_syevd_bs = + load_sym(cusolver, "cusolverDnDsyevd_bufferSize"); + g.s_syevd = load_sym(cusolver, "cusolverDnSsyevd"); + g.d_syevd = load_sym(cusolver, "cusolverDnDsyevd"); + + g.mem_alloc = load_sym(cuda, "cuMemAlloc_v2"); + g.mem_free = load_sym(cuda, "cuMemFree_v2"); + g.memcpy_dtod = + load_sym(cuda, "cuMemcpyDtoDAsync_v2"); + g.memset_d8 = load_sym(cuda, "cuMemsetD8Async"); + g.stream_sync = + load_sym(cuda, "cuStreamSynchronize"); + + g.loaded = true; + return g; +} + +DeviceMem::~DeviceMem() { + if (ptr) + g.mem_free(ptr); +} + +int DeviceMem::alloc(std::size_t bytes) { return g.mem_alloc(&ptr, bytes); } + +// Per-stream cuSOLVER handle pool. +// +// cuSOLVER handles are not safe to share across streams: cusolverDnSetStream +// rebinds the handle, racing with concurrent launches issued from another +// stream. We mirror jaxlib's SolverHandlePool (jaxlib/gpu/solver_handle_pool.cc): +// a mutex-guarded free-list of handles per stream, RAII-returned to the pool +// when the borrow goes out of scope. Handles are pooled forever, never +// destroyed (acceptable for a process-wide resource). +namespace { +struct SolverHandlePool { + std::mutex mu; + std::map> free_handles; + + static SolverHandlePool &instance() { + static SolverHandlePool p; + return p; + } +}; +} // namespace + +HandleGuard::HandleGuard(HandleGuard &&o) noexcept + : stream_(o.stream_), handle_(o.handle_) { + o.handle_ = nullptr; +} + +HandleGuard &HandleGuard::operator=(HandleGuard &&o) noexcept { + if (this != &o) { + release(); + stream_ = o.stream_; + handle_ = o.handle_; + o.handle_ = nullptr; + } + return *this; +} + +HandleGuard::~HandleGuard() { release(); } + +void HandleGuard::release() { + if (!handle_) + return; + auto &pool = SolverHandlePool::instance(); + std::lock_guard lock(pool.mu); + pool.free_handles[stream_].push_back(handle_); + handle_ = nullptr; +} + +Error borrow_solver_handle(GpuLibs &g, void *stream, HandleGuard &out) { + auto &pool = SolverHandlePool::instance(); + void *handle = nullptr; + { + std::lock_guard lock(pool.mu); + auto &vec = pool.free_handles[stream]; + if (!vec.empty()) { + handle = vec.back(); + vec.pop_back(); + } + } + if (!handle) { + PJRT_RETURN_IF_GPU_ERROR(g.dn_create(&handle), "cusolverDnCreate"); + } + if (stream) { + int s = g.dn_set_stream(handle, stream); + if (s != 0) { + // Return the handle to the pool so it isn't lost on this error path. + std::lock_guard lock(pool.mu); + pool.free_handles[stream].push_back(handle); + return Error::Internal("cusolverDnSetStream failed with status = " + + std::to_string(s)); + } + } + out = HandleGuard(stream, handle); + return Error::Success(); +} + +Error Solver::begin(void *stream) { + if (!g.loaded) + return Error::Internal("CUDA/cuSOLVER libraries not available"); + PJRT_RETURN_IF_ERROR(borrow_solver_handle(g, stream, handle)); + PJRT_RETURN_IF_GPU_ERROR(info.alloc(sizeof(int)), "cuMemAlloc (devInfo)"); + return Error::Success(); +} + +} // namespace rpjrt + +#endif // _WIN32 diff --git a/src/ffi_cusolver.h b/src/ffi_cusolver.h new file mode 100644 index 00000000..0d5debd7 --- /dev/null +++ b/src/ffi_cusolver.h @@ -0,0 +1,195 @@ +// cuSOLVER + CUDA-driver function table, dynamically loaded via dlopen so the +// R package can be built without the CUDA SDK headers and run on machines +// without a CUDA install. Mirrors the role of jaxlib/gpu/solver_kernels_ffi.cc +// + jaxlib/gpu/solver_handle_pool.cc, adapted to a runtime-link-only model. +// +// Only the non-Windows half is meaningful; on Windows there is no CUDA, and +// dlopen is POSIX-only. +#pragma once + +#include "ffi_common.h" +#include "xla/ffi/api/ffi.h" + +#ifndef _WIN32 + +#include +#include +#include + +namespace rpjrt { + +// Opaque CUDA / cuSOLVER types (no SDK headers needed). uintptr_t matches +// the typedef used elsewhere in pjrt (see ffi.cpp). +using CUdeviceptr = std::uintptr_t; + +// Status-check helper for CUDA driver / cuSOLVER calls. Every API call returns +// an int; we propagate non-zero values as Error::Internal annotated with the +// site name. Mirrors jaxlib's JAX_FFI_RETURN_IF_GPU_ERROR. +#define PJRT_RETURN_IF_GPU_ERROR(expr, what) \ + do { \ + int _status = (expr); \ + if (_status != 0) { \ + return xla::ffi::Error::Internal(std::string(what) + \ + " failed with status = " + \ + std::to_string(_status)); \ + } \ + } while (0) + +// Function pointers for the cuSOLVER + CUDA driver entry points the package +// uses. New ops add their entries here and to the loader in ffi_cusolver.cpp. +struct GpuLibs { + // cuSOLVER handle management. + int (*dn_create)(void **); + int (*dn_destroy)(void *); + int (*dn_set_stream)(void *, void *); + + // QR. + int (*s_geqrf_bs)(void *, int, int, float *, int, int *); + int (*s_geqrf)(void *, int, int, float *, int, float *, float *, int, int *); + int (*d_geqrf_bs)(void *, int, int, double *, int, int *); + int (*d_geqrf)(void *, int, int, double *, int, double *, double *, int, + int *); + int (*s_orgqr_bs)(void *, int, int, int, const float *, int, const float *, + int *); + int (*s_orgqr)(void *, int, int, int, float *, int, const float *, float *, + int, int *); + int (*d_orgqr_bs)(void *, int, int, int, const double *, int, const double *, + int *); + int (*d_orgqr)(void *, int, int, int, double *, int, const double *, + double *, int, int *); + + // LU. ipiv and devInfo are device int32; ipiv is 1-based row indices. + int (*s_getrf_bs)(void *, int, int, float *, int, int *); + int (*s_getrf)(void *, int, int, float *, int, float *, int *, int *); + int (*d_getrf_bs)(void *, int, int, double *, int, int *); + int (*d_getrf)(void *, int, int, double *, int, double *, int *, int *); + + // SVD via cusolverDn?gesvd. The bufferSize variant takes only (handle, m, n) + // and returns the worst-case workspace; jobu/jobvt are not part of the + // workspace query. rwork is unused for real precisions (pass nullptr). + int (*s_gesvd_bs)(void *, int, int, int *); + int (*d_gesvd_bs)(void *, int, int, int *); + int (*s_gesvd)(void *, signed char, signed char, int, int, float *, int, + float *, float *, int, float *, int, float *, int, float *, + int *); + int (*d_gesvd)(void *, signed char, signed char, int, int, double *, int, + double *, double *, int, double *, int, double *, int, + double *, int *); + + // Symmetric/Hermitian eigendecomposition. jobz/uplo are cusolverEigMode_t / + // cublasFillMode_t (both int enums). We always pass jobz = 1 (vectors) and + // uplo = 0 (lower) -- see eigh_cuda.cpp. + int (*s_syevd_bs)(void *, int, int, int, const float *, int, const float *, + int *); + int (*d_syevd_bs)(void *, int, int, int, const double *, int, const double *, + int *); + int (*s_syevd)(void *, int, int, int, float *, int, float *, float *, int, + int *); + int (*d_syevd)(void *, int, int, int, double *, int, double *, double *, int, + int *); + + // CUDA driver. + int (*mem_alloc)(CUdeviceptr *, std::size_t); + int (*mem_free)(CUdeviceptr); + int (*memcpy_dtod)(CUdeviceptr, CUdeviceptr, std::size_t, void *); + int (*memset_d8)(CUdeviceptr, unsigned char, std::size_t, void *); + int (*stream_sync)(void *); + + bool loaded = false; +}; + +GpuLibs &get_gpu_libs(); + +// RAII wrapper for cuMemAlloc'd device memory. +struct DeviceMem { + CUdeviceptr ptr = 0; + GpuLibs &g; + explicit DeviceMem(GpuLibs &g) : g(g) {} + ~DeviceMem(); + DeviceMem(const DeviceMem &) = delete; + DeviceMem &operator=(const DeviceMem &) = delete; + int alloc(std::size_t bytes); +}; + +// Borrowed cuSOLVER handle, returned to the per-stream pool on destruction. +class HandleGuard { +public: + HandleGuard() = default; + HandleGuard(void *stream, void *handle) : stream_(stream), handle_(handle) {} + HandleGuard(HandleGuard &&o) noexcept; + HandleGuard &operator=(HandleGuard &&o) noexcept; + ~HandleGuard(); + HandleGuard(const HandleGuard &) = delete; + HandleGuard &operator=(const HandleGuard &) = delete; + void *get() const { return handle_; } + +private: + void release(); + void *stream_ = nullptr; + void *handle_ = nullptr; +}; + +xla::ffi::Error borrow_solver_handle(GpuLibs &g, void *stream, + HandleGuard &out); + +// Bundled prologue for a CUDA linalg kernel: a borrowed cuSOLVER handle on +// `stream`, plus a pre-allocated device `int` for `devInfo` (every cuSOLVER +// routine wants one). All four built-in linalg kernels open with the same +// three steps -- loaded-check, handle borrow, info alloc -- and `Solver` +// rolls them into one initialiser. `g` and `info` mirror the shape of +// jaxlib's GeqrfImpl prologue (cf. solver_kernels_ffi.cc). +struct Solver { + GpuLibs &g; + HandleGuard handle; + DeviceMem info; + explicit Solver(GpuLibs &g) : g(g), info(g) {} + + // Borrow a handle for `stream` and allocate devInfo. Call once per + // kernel invocation, before any cuSOLVER calls. + xla::ffi::Error begin(void *stream); +}; + +// Allocate `lwork * sizeof(T)` bytes of device memory into `out`, with a +// site-name annotation. Centralises the int -> size_t widening so each +// kernel doesn't open-code it per workspace. +template +xla::ffi::Error allocate_workspace(int lwork, const char *name, + DeviceMem &out) { + std::size_t bytes = static_cast(lwork) * sizeof(T); + PJRT_RETURN_IF_GPU_ERROR(out.alloc(bytes), name); + return xla::ffi::Error::Success(); +} + +// Per-precision dispatch trait for cuSOLVER routines. Modelled on jaxlib's +// `solver::Geqrf` / `solver::Getrf` ... wrappers. +template struct CuSolver; + +template <> struct CuSolver { + static auto geqrf_bs(GpuLibs &g) { return g.s_geqrf_bs; } + static auto geqrf(GpuLibs &g) { return g.s_geqrf; } + static auto orgqr_bs(GpuLibs &g) { return g.s_orgqr_bs; } + static auto orgqr(GpuLibs &g) { return g.s_orgqr; } + static auto getrf_bs(GpuLibs &g) { return g.s_getrf_bs; } + static auto getrf(GpuLibs &g) { return g.s_getrf; } + static auto gesvd_bs(GpuLibs &g) { return g.s_gesvd_bs; } + static auto gesvd(GpuLibs &g) { return g.s_gesvd; } + static auto syevd_bs(GpuLibs &g) { return g.s_syevd_bs; } + static auto syevd(GpuLibs &g) { return g.s_syevd; } +}; + +template <> struct CuSolver { + static auto geqrf_bs(GpuLibs &g) { return g.d_geqrf_bs; } + static auto geqrf(GpuLibs &g) { return g.d_geqrf; } + static auto orgqr_bs(GpuLibs &g) { return g.d_orgqr_bs; } + static auto orgqr(GpuLibs &g) { return g.d_orgqr; } + static auto getrf_bs(GpuLibs &g) { return g.d_getrf_bs; } + static auto getrf(GpuLibs &g) { return g.d_getrf; } + static auto gesvd_bs(GpuLibs &g) { return g.d_gesvd_bs; } + static auto gesvd(GpuLibs &g) { return g.d_gesvd; } + static auto syevd_bs(GpuLibs &g) { return g.d_syevd_bs; } + static auto syevd(GpuLibs &g) { return g.d_syevd; } +}; + +} // namespace rpjrt + +#endif // _WIN32 diff --git a/src/ffi_lapack.h b/src/ffi_lapack.h new file mode 100644 index 00000000..28db91f0 --- /dev/null +++ b/src/ffi_lapack.h @@ -0,0 +1,172 @@ +// LAPACK extern declarations and per-precision dispatch traits. +// +// On macOS/Linux we link against system LAPACK (Accelerate, OpenBLAS, MKL), +// which provides single- and double-precision routines. On Windows R bundles +// its own Rlapack.dll, which only ships the double-precision variants. To +// keep the per-op kernels precision-agnostic, we expose a Lapack trait +// whose ::S typedef is the precision actually used for the LAPACK call: +// - non-Windows: Lapack::S = float, Lapack::S = double +// - Windows: Lapack::S = double (promote), Lapack::S = double +// +// Each kernel writes: +// +// using S = typename Lapack::S; +// std::vector a(...); // promoted copy of input +// Lapack::geqrf(...); +// ... back-cast to T on output ... +// +// and gets the right behaviour on both platforms with no #ifdefs in the +// kernel body. Mirrors the dispatch trait pattern in jaxlib's +// lapack_kernels.cc. +#pragma once + +#include "ffi_common.h" + +extern "C" { +// QR: factorisation + Q materialisation. +void dgeqrf_(const int *m, const int *n, double *a, const int *lda, double *tau, + double *work, const int *lwork, int *info); +void dorgqr_(const int *m, const int *n, const int *k, double *a, + const int *lda, const double *tau, double *work, const int *lwork, + int *info); + +// LU: partial-pivoting LU. ipiv is 1-based row indices; info > 0 means a +// pivot was zero (matrix singular). +void dgetrf_(const int *m, const int *n, double *a, const int *lda, int *ipiv, + int *info); + +// SVD: divide-and-conquer (gesdd) is generally faster than the QR-based +// gesvd for medium/large matrices and is what jaxlib uses on CPU. jobz +// selects which singular vectors to compute: 'A' = full U, V; 'S' = reduced; +// 'O' = overwrite A; 'N' = singular values only. +void dgesdd_(const char *jobz, const int *m, const int *n, double *a, + const int *lda, double *s, double *u, const int *ldu, double *vt, + const int *ldvt, double *work, const int *lwork, int *iwork, + int *info); + +// Symmetric/Hermitian eigendecomposition (real). jobz: 'N' eigenvalues only, +// 'V' eigenvectors too. uplo: 'L' / 'U' selects which triangle of A holds +// the input. +void dsyevd_(const char *jobz, const char *uplo, const int *n, double *a, + const int *lda, double *w, double *work, const int *lwork, + int *iwork, const int *liwork, int *info); + +#ifndef _WIN32 +void sgeqrf_(const int *m, const int *n, float *a, const int *lda, float *tau, + float *work, const int *lwork, int *info); +void sorgqr_(const int *m, const int *n, const int *k, float *a, const int *lda, + const float *tau, float *work, const int *lwork, int *info); +void sgetrf_(const int *m, const int *n, float *a, const int *lda, int *ipiv, + int *info); +void sgesdd_(const char *jobz, const int *m, const int *n, float *a, + const int *lda, float *s, float *u, const int *ldu, float *vt, + const int *ldvt, float *work, const int *lwork, int *iwork, + int *info); +void ssyevd_(const char *jobz, const char *uplo, const int *n, float *a, + const int *lda, float *w, float *work, const int *lwork, + int *iwork, const int *liwork, int *info); +#endif +} + +namespace rpjrt { + +inline xla::ffi::Error lapack_check_info(int info, const char *routine) { + if (info == 0) + return xla::ffi::Error::Success(); + return xla::ffi::Error::Internal(std::string(routine) + + " failed with info = " + + std::to_string(info)); +} + +// Per-precision dispatch trait. ::S is the storage type that the LAPACK +// routines actually see; on Windows this is always double. +template struct Lapack; + +template <> struct Lapack { + using S = double; + static void geqrf(const int *m, const int *n, S *a, const int *lda, S *tau, + S *work, const int *lwork, int *info) { + dgeqrf_(m, n, a, lda, tau, work, lwork, info); + } + static void orgqr(const int *m, const int *n, const int *k, S *a, + const int *lda, const S *tau, S *work, const int *lwork, + int *info) { + dorgqr_(m, n, k, a, lda, tau, work, lwork, info); + } + static void getrf(const int *m, const int *n, S *a, const int *lda, + int *ipiv, int *info) { + dgetrf_(m, n, a, lda, ipiv, info); + } + static void gesdd(const char *jobz, const int *m, const int *n, S *a, + const int *lda, S *s, S *u, const int *ldu, S *vt, + const int *ldvt, S *work, const int *lwork, int *iwork, + int *info) { + dgesdd_(jobz, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, iwork, info); + } + static void syevd(const char *jobz, const char *uplo, const int *n, S *a, + const int *lda, S *w, S *work, const int *lwork, int *iwork, + const int *liwork, int *info) { + dsyevd_(jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info); + } +}; + +#ifndef _WIN32 +template <> struct Lapack { + using S = float; + static void geqrf(const int *m, const int *n, S *a, const int *lda, S *tau, + S *work, const int *lwork, int *info) { + sgeqrf_(m, n, a, lda, tau, work, lwork, info); + } + static void orgqr(const int *m, const int *n, const int *k, S *a, + const int *lda, const S *tau, S *work, const int *lwork, + int *info) { + sorgqr_(m, n, k, a, lda, tau, work, lwork, info); + } + static void getrf(const int *m, const int *n, S *a, const int *lda, + int *ipiv, int *info) { + sgetrf_(m, n, a, lda, ipiv, info); + } + static void gesdd(const char *jobz, const int *m, const int *n, S *a, + const int *lda, S *s, S *u, const int *ldu, S *vt, + const int *ldvt, S *work, const int *lwork, int *iwork, + int *info) { + sgesdd_(jobz, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, iwork, info); + } + static void syevd(const char *jobz, const char *uplo, const int *n, S *a, + const int *lda, S *w, S *work, const int *lwork, int *iwork, + const int *liwork, int *info) { + ssyevd_(jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info); + } +}; +#else +// Windows: promote f32 -> f64 -> f32 around the LAPACK call. +template <> struct Lapack { + using S = double; + static void geqrf(const int *m, const int *n, S *a, const int *lda, S *tau, + S *work, const int *lwork, int *info) { + dgeqrf_(m, n, a, lda, tau, work, lwork, info); + } + static void orgqr(const int *m, const int *n, const int *k, S *a, + const int *lda, const S *tau, S *work, const int *lwork, + int *info) { + dorgqr_(m, n, k, a, lda, tau, work, lwork, info); + } + static void getrf(const int *m, const int *n, S *a, const int *lda, + int *ipiv, int *info) { + dgetrf_(m, n, a, lda, ipiv, info); + } + static void gesdd(const char *jobz, const int *m, const int *n, S *a, + const int *lda, S *s, S *u, const int *ldu, S *vt, + const int *ldvt, S *work, const int *lwork, int *iwork, + int *info) { + dgesdd_(jobz, m, n, a, lda, s, u, ldu, vt, ldvt, work, lwork, iwork, info); + } + static void syevd(const char *jobz, const char *uplo, const int *n, S *a, + const int *lda, S *w, S *work, const int *lwork, int *iwork, + const int *liwork, int *info) { + dsyevd_(jobz, uplo, n, a, lda, w, work, lwork, iwork, liwork, info); + } +}; +#endif + +} // namespace rpjrt diff --git a/src/lu.cpp b/src/lu.cpp new file mode 100644 index 00000000..14092c85 --- /dev/null +++ b/src/lu.cpp @@ -0,0 +1,70 @@ +// LU decomposition with partial pivoting: P A = L U. +// +// LAPACK getrf overwrites A with L (strictly below the diagonal, unit +// diagonal implicit) and U (on and above the diagonal). The k = min(m, n) +// pivot indices are returned in `ipiv`, 1-based row swaps applied during +// elimination. +// +// Outputs: +// LU : (m, n), same dtype as input +// pivots : (k,) int32 +#include + +#include "ffi_common.h" +#include "ffi_lapack.h" + +#include +#include +#include + +using namespace xla::ffi; + +namespace rpjrt { + +template +static Error lu_impl(AnyBuffer input, Result lu_out, + Result piv_out) { + using S = typename Lapack::S; + + auto dims = input.dimensions(); + int m, n; + PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m)); + PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n)); + + std::vector a(static_cast(m) * n); + const T *in = static_cast(input.untyped_data()); + for (std::size_t i = 0; i < a.size(); i++) + a[i] = static_cast(in[i]); + + // getrf has no workspace argument -- pivoting is in place. + int *ipiv = static_cast((*piv_out).untyped_data()); + int info; + Lapack::getrf(&m, &n, a.data(), &m, ipiv, &info); + // info > 0 means U(info, info) = 0 (singular). Surface that as an error; + // the user has no way to act on a "factorised but singular" return. + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "getrf")); + + T *lu_data = static_cast((*lu_out).untyped_data()); + for (std::size_t i = 0; i < a.size(); i++) + lu_data[i] = static_cast(a[i]); + + return Error::Success(); +} + +static Error do_lu(AnyBuffer input, Result lu_out, + Result piv_out) { + PJRT_DISPATCH_FLOAT(input.element_type(), lu_impl, input, lu_out, piv_out); +} + +XLA_FFI_DEFINE_HANDLER(lu_handler, do_lu, + Ffi::Bind() + .Arg() // input matrix + .Ret() // LU (same dtype) + .Ret()); // pivots (int32) + +} // namespace rpjrt + +// [[Rcpp::export]] +SEXP get_lu_handler() { + return R_MakeExternalPtr((void *)rpjrt::lu_handler, R_NilValue, R_NilValue); +} diff --git a/src/lu_cuda.cpp b/src/lu_cuda.cpp new file mode 100644 index 00000000..eebc642d --- /dev/null +++ b/src/lu_cuda.cpp @@ -0,0 +1,95 @@ +// CUDA LU decomposition via cuSOLVER. Mirrors src/lu.cpp on the GPU. +// +// cuSOLVER's getrf places ipiv (and devInfo) directly in device memory -- +// the pivots output buffer XLA hands us is already on-device, so we pass +// it through with no extra D2D copy. +#include + +#include "ffi_common.h" + +#ifndef _WIN32 +#include "ffi_cusolver.h" + +#include +#endif + +using namespace xla::ffi; + +namespace rpjrt { + +#ifndef _WIN32 +template +static Error lu_cuda_impl(void *stream, AnyBuffer input, + Result lu_out, + Result piv_out) { + Solver solver(get_gpu_libs()); + PJRT_RETURN_IF_ERROR(solver.begin(stream)); + auto &g = solver.g; + + auto dims = input.dimensions(); + int m, n; + PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m)); + PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n)); + + auto input_ptr = reinterpret_cast(input.untyped_data()); + auto lu_ptr = reinterpret_cast((*lu_out).untyped_data()); + auto piv_ptr = reinterpret_cast((*piv_out).untyped_data()); + + std::size_t a_bytes = static_cast(m) * n * sizeof(T); + + // getrf overwrites A in place. Copy input -> LU output buffer first, then + // factorise the LU buffer so we never touch the input. + PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(lu_ptr, input_ptr, a_bytes, stream), + "cuMemcpyDtoDAsync (input -> LU)"); + + int lwork = 0; + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::getrf_bs(g)(solver.handle.get(), m, n, + reinterpret_cast(lu_ptr), m, &lwork), + "cusolverDn?getrf_bufferSize"); + + DeviceMem d_work(g); + PJRT_RETURN_IF_ERROR( + allocate_workspace(lwork, "cuMemAlloc (getrf workspace)", d_work)); + + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::getrf(g)(solver.handle.get(), m, n, + reinterpret_cast(lu_ptr), m, + reinterpret_cast(d_work.ptr), + reinterpret_cast(piv_ptr), + reinterpret_cast(solver.info.ptr)), + "cusolverDn?getrf"); + + // devInfo is intentionally not read back: a singular matrix surfaces as + // numerical garbage downstream rather than a launch-time error, matching + // jaxlib's getrf path. + + return Error::Success(); +} +#endif // _WIN32 + +static Error do_lu_cuda(void *stream, AnyBuffer input, + Result lu_out, Result piv_out) { +#ifdef _WIN32 + return Error(ErrorCode::kUnimplemented, + "CUDA LU is not supported on Windows"); +#else + PJRT_DISPATCH_FLOAT(input.element_type(), lu_cuda_impl, stream, input, lu_out, + piv_out); +#endif +} + +XLA_FFI_DEFINE_HANDLER(lu_handler_cuda, do_lu_cuda, + Ffi::Bind() + .Ctx>() + .Arg() + .Ret() + .Ret()); + +} // namespace rpjrt + +// [[Rcpp::export]] +SEXP get_lu_handler_cuda() { + return R_MakeExternalPtr((void *)rpjrt::lu_handler_cuda, R_NilValue, + R_NilValue); +} diff --git a/src/qr.cpp b/src/qr.cpp new file mode 100644 index 00000000..373aa82a --- /dev/null +++ b/src/qr.cpp @@ -0,0 +1,100 @@ +// QR decomposition: A (m x n) -> Q (m x k), R (k x n), where k = min(m, n). +// +// All buffers use column-major layout (specified by operand_layouts / +// result_layouts on the custom_call). XLA handles the row-major <-> +// column-major conversion transparently -- the handler works directly in +// LAPACK-native layout. +// +// The two-phase pattern is geqrf (compute Householder reflectors + R) then +// orgqr (materialise Q from the reflectors). LAPACK's `lwork = -1` workspace +// query is run before each call -- the optimal size depends on +// implementation-specific blocking parameters that we can't compute ahead +// of time. +#include + +#include "ffi_common.h" +#include "ffi_lapack.h" + +#include +#include +#include + +using namespace xla::ffi; + +namespace rpjrt { + +template +static Error qr_impl(AnyBuffer input, Result q_out, + Result r_out) { + // S is the LAPACK storage type: float/double on macOS+Linux, double on + // Windows for both (R's bundled Rlapack.dll has no s* routines, so f32 + // input is promoted to double across the call). + using S = typename Lapack::S; + + auto dims = input.dimensions(); + int m, n; + PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m)); + PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n)); + int k = std::min(m, n); + + std::vector a(static_cast(m) * n); + const T *in = static_cast(input.untyped_data()); + for (std::size_t i = 0; i < a.size(); i++) + a[i] = static_cast(in[i]); + + std::vector tau(k); + int lwork = -1; + S work_size; + int info; + Lapack::geqrf(&m, &n, a.data(), &m, tau.data(), &work_size, &lwork, &info); + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "geqrf workspace query")); + + lwork = static_cast(work_size); + std::vector work(lwork); + Lapack::geqrf(&m, &n, a.data(), &m, tau.data(), work.data(), &lwork, + &info); + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "geqrf")); + + T *r_data = static_cast((*r_out).untyped_data()); + for (int j = 0; j < n; j++) { + for (int i = 0; i < k; i++) { + r_data[j * k + i] = + (i <= j) ? static_cast(a[j * m + i]) : static_cast(0); + } + } + + lwork = -1; + Lapack::orgqr(&m, &k, &k, a.data(), &m, tau.data(), &work_size, &lwork, + &info); + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "orgqr workspace query")); + + lwork = static_cast(work_size); + work.resize(lwork); + Lapack::orgqr(&m, &k, &k, a.data(), &m, tau.data(), work.data(), &lwork, + &info); + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "orgqr")); + + T *q_data = static_cast((*q_out).untyped_data()); + for (std::size_t i = 0; i < static_cast(m) * k; i++) + q_data[i] = static_cast(a[i]); + + return Error::Success(); +} + +static Error do_qr(AnyBuffer input, Result q_out, + Result r_out) { + PJRT_DISPATCH_FLOAT(input.element_type(), qr_impl, input, q_out, r_out); +} + +XLA_FFI_DEFINE_HANDLER(qr_handler, do_qr, + Ffi::Bind() + .Arg() // input matrix (column-major) + .Ret() // Q output (column-major) + .Ret()); // R output (column-major) + +} // namespace rpjrt + +// [[Rcpp::export]] +SEXP get_qr_handler() { + return R_MakeExternalPtr((void *)rpjrt::qr_handler, R_NilValue, R_NilValue); +} diff --git a/src/qr_cuda.cpp b/src/qr_cuda.cpp new file mode 100644 index 00000000..40535e74 --- /dev/null +++ b/src/qr_cuda.cpp @@ -0,0 +1,147 @@ +// CUDA QR decomposition via cuSOLVER. Mirrors src/qr.cpp on the GPU. +// +// The shared dlopen loader, DeviceMem RAII, per-stream HandleGuard, and the +// `Solver` prologue (handle + devInfo) live in ffi_cusolver.h/.cpp; this file +// only contains the QR algorithm itself. +// +// On Windows the handler is still defined but always returns Unimplemented +// -- pjrt has no CUDA support on Windows, but we keep the symbol so the +// Rcpp::export wrapper resolves cleanly without `#ifdef`s. +#include + +#include "ffi_common.h" + +#ifndef _WIN32 +#include "ffi_cusolver.h" + +#include +#include +#endif + +using namespace xla::ffi; + +namespace rpjrt { + +#ifndef _WIN32 +template +static Error qr_cuda_impl(void *stream, AnyBuffer input, + Result q_out, Result r_out) { + Solver solver(get_gpu_libs()); + PJRT_RETURN_IF_ERROR(solver.begin(stream)); + auto &g = solver.g; + + auto dims = input.dimensions(); + int m, n; + PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m)); + PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n)); + int k = std::min(m, n); + + auto input_ptr = reinterpret_cast(input.untyped_data()); + auto q_ptr = reinterpret_cast((*q_out).untyped_data()); + auto r_ptr = reinterpret_cast((*r_out).untyped_data()); + + // Cast to size_t before multiplying to avoid int overflow for large + // matrices (mirrors jaxlib's int64_t widening of stride math). + std::size_t a_bytes = static_cast(m) * n * sizeof(T); + std::size_t r_bytes = static_cast(k) * n * sizeof(T); + std::size_t q_bytes = static_cast(m) * k * sizeof(T); + std::size_t tau_bytes = static_cast(k) * sizeof(T); + + DeviceMem d_a(g), d_tau(g), d_work(g); + PJRT_RETURN_IF_GPU_ERROR(d_a.alloc(a_bytes), "cuMemAlloc (A)"); + PJRT_RETURN_IF_GPU_ERROR(d_tau.alloc(tau_bytes), "cuMemAlloc (tau)"); + + PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(d_a.ptr, input_ptr, a_bytes, stream), + "cuMemcpyDtoDAsync (input -> A)"); + + int lwork = 0; + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::geqrf_bs(g)(solver.handle.get(), m, n, + reinterpret_cast(d_a.ptr), m, &lwork), + "cusolverDn?geqrf_bufferSize"); + PJRT_RETURN_IF_ERROR( + allocate_workspace(lwork, "cuMemAlloc (geqrf workspace)", d_work)); + + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::geqrf(g)(solver.handle.get(), m, n, + reinterpret_cast(d_a.ptr), m, + reinterpret_cast(d_tau.ptr), + reinterpret_cast(d_work.ptr), lwork, + reinterpret_cast(solver.info.ptr)), + "cusolverDn?geqrf"); + + // Extract R: zero the output, then copy upper triangular column by column. + PJRT_RETURN_IF_GPU_ERROR(g.memset_d8(r_ptr, 0, r_bytes, stream), + "cuMemsetD8Async (R)"); + for (int j = 0; j < n; j++) { + int elems = std::min(j + 1, k); + std::size_t r_off = static_cast(j) * k * sizeof(T); + std::size_t a_off = static_cast(j) * m * sizeof(T); + PJRT_RETURN_IF_GPU_ERROR( + g.memcpy_dtod(r_ptr + r_off, d_a.ptr + a_off, + static_cast(elems) * sizeof(T), stream), + "cuMemcpyDtoDAsync (R column)"); + } + + // Copy first k columns of factored A to Q output (column-major, so first + // m*k elements), then run orgqr in-place on Q. + PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(q_ptr, d_a.ptr, q_bytes, stream), + "cuMemcpyDtoDAsync (A -> Q)"); + + int lwork_orgqr = 0; + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::orgqr_bs(g)(solver.handle.get(), m, k, k, + reinterpret_cast(q_ptr), m, + reinterpret_cast(d_tau.ptr), + &lwork_orgqr), + "cusolverDn?orgqr_bufferSize"); + + // Reuse the geqrf workspace if it's already big enough (saves an alloc + // for the common case where geqrf needs more scratch than orgqr). + DeviceMem d_work2(g); + T *work_ptr; + if (lwork_orgqr <= lwork) { + work_ptr = reinterpret_cast(d_work.ptr); + } else { + PJRT_RETURN_IF_ERROR(allocate_workspace( + lwork_orgqr, "cuMemAlloc (orgqr workspace)", d_work2)); + work_ptr = reinterpret_cast(d_work2.ptr); + } + + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::orgqr(g)(solver.handle.get(), m, k, k, + reinterpret_cast(q_ptr), m, + reinterpret_cast(d_tau.ptr), work_ptr, + lwork_orgqr, + reinterpret_cast(solver.info.ptr)), + "cusolverDn?orgqr"); + + return Error::Success(); +} +#endif // _WIN32 + +static Error do_qr_cuda(void *stream, AnyBuffer input, Result q_out, + Result r_out) { +#ifdef _WIN32 + return Error(ErrorCode::kUnimplemented, + "CUDA QR is not supported on Windows"); +#else + PJRT_DISPATCH_FLOAT(input.element_type(), qr_cuda_impl, stream, input, q_out, + r_out); +#endif +} + +XLA_FFI_DEFINE_HANDLER(qr_handler_cuda, do_qr_cuda, + Ffi::Bind() + .Ctx>() + .Arg() + .Ret() + .Ret()); + +} // namespace rpjrt + +// [[Rcpp::export]] +SEXP get_qr_handler_cuda() { + return R_MakeExternalPtr((void *)rpjrt::qr_handler_cuda, R_NilValue, + R_NilValue); +} diff --git a/src/svd.cpp b/src/svd.cpp new file mode 100644 index 00000000..762e3d84 --- /dev/null +++ b/src/svd.cpp @@ -0,0 +1,93 @@ +// Singular value decomposition via LAPACK gesdd (divide-and-conquer). +// +// Reduced ("economy") SVD: jobz = 'S'. For A of shape (m, n) with k = min(m, n): +// U : (m, k) +// S : (k,) (always non-negative, real) +// Vt : (k, n) +// such that A = U diag(S) Vt. +// +// gesdd needs an integer workspace `iwork` of size 8*k in addition to the +// real workspace queried via lwork = -1. +#include + +#include "ffi_common.h" +#include "ffi_lapack.h" + +#include +#include +#include + +using namespace xla::ffi; + +namespace rpjrt { + +template +static Error svd_impl(AnyBuffer input, Result u_out, + Result s_out, Result vt_out) { + using S = typename Lapack::S; + + auto dims = input.dimensions(); + int m, n; + PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m)); + PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n)); + int k = std::min(m, n); + + std::vector a(static_cast(m) * n); + const T *in = static_cast(input.untyped_data()); + for (std::size_t i = 0; i < a.size(); i++) + a[i] = static_cast(in[i]); + + std::vector sv(k); + std::vector u(static_cast(m) * k); + std::vector vt(static_cast(k) * n); + + const char jobz = 'S'; + int ldu = m; + int ldvt = k; + int info; + + int lwork = -1; + S work_size; + std::vector iwork(8 * k); + Lapack::gesdd(&jobz, &m, &n, a.data(), &m, sv.data(), u.data(), &ldu, + vt.data(), &ldvt, &work_size, &lwork, iwork.data(), &info); + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "gesdd workspace query")); + + lwork = static_cast(work_size); + std::vector work(lwork); + Lapack::gesdd(&jobz, &m, &n, a.data(), &m, sv.data(), u.data(), &ldu, + vt.data(), &ldvt, work.data(), &lwork, iwork.data(), &info); + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "gesdd")); + + T *u_data = static_cast((*u_out).untyped_data()); + T *s_data = static_cast((*s_out).untyped_data()); + T *vt_data = static_cast((*vt_out).untyped_data()); + for (std::size_t i = 0; i < u.size(); i++) + u_data[i] = static_cast(u[i]); + for (int i = 0; i < k; i++) + s_data[i] = static_cast(sv[i]); + for (std::size_t i = 0; i < vt.size(); i++) + vt_data[i] = static_cast(vt[i]); + + return Error::Success(); +} + +static Error do_svd(AnyBuffer input, Result u_out, + Result s_out, Result vt_out) { + PJRT_DISPATCH_FLOAT(input.element_type(), svd_impl, input, u_out, s_out, + vt_out); +} + +XLA_FFI_DEFINE_HANDLER(svd_handler, do_svd, + Ffi::Bind() + .Arg() // input matrix + .Ret() // U (m, k) + .Ret() // S (k,) + .Ret()); // Vt (k, n) + +} // namespace rpjrt + +// [[Rcpp::export]] +SEXP get_svd_handler() { + return R_MakeExternalPtr((void *)rpjrt::svd_handler, R_NilValue, R_NilValue); +} diff --git a/src/svd_cuda.cpp b/src/svd_cuda.cpp new file mode 100644 index 00000000..e61fb5c9 --- /dev/null +++ b/src/svd_cuda.cpp @@ -0,0 +1,113 @@ +// CUDA SVD via cuSOLVER gesvd. +// +// cuSOLVER's gesvd requires m >= n. For the m < n case the user can call +// nv_svd on the transpose and swap U <-> V; we surface a clear +// InvalidArgument error rather than silently doing this. JAX's older +// cuSOLVER path has the same restriction. +// +// jobu / jobvt are passed as 'S' (reduced); for m >= n that gives: +// U : (m, n) -- ldu = m +// S : (n,) +// Vt : (n, n) -- ldvt = n +// matching the host gesdd output shapes when k = n. +#include + +#include "ffi_common.h" + +#ifndef _WIN32 +#include "ffi_cusolver.h" + +#include +#endif + +using namespace xla::ffi; + +namespace rpjrt { + +#ifndef _WIN32 +template +static Error svd_cuda_impl(void *stream, AnyBuffer input, + Result u_out, Result s_out, + Result vt_out) { + Solver solver(get_gpu_libs()); + PJRT_RETURN_IF_ERROR(solver.begin(stream)); + auto &g = solver.g; + + auto dims = input.dimensions(); + int m, n; + PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m)); + PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n)); + if (m < n) { + return Error::InvalidArgument( + "CUDA SVD requires m >= n; transpose the input and swap U<->V " + "for the wide case"); + } + + auto input_ptr = reinterpret_cast(input.untyped_data()); + auto u_ptr = reinterpret_cast((*u_out).untyped_data()); + auto s_ptr = reinterpret_cast((*s_out).untyped_data()); + auto vt_ptr = reinterpret_cast((*vt_out).untyped_data()); + + std::size_t a_bytes = static_cast(m) * n * sizeof(T); + + // gesvd overwrites A. Allocate a working copy so the input buffer is + // preserved (XLA may have aliased it elsewhere). + DeviceMem d_a(g); + PJRT_RETURN_IF_GPU_ERROR(d_a.alloc(a_bytes), "cuMemAlloc (A)"); + PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(d_a.ptr, input_ptr, a_bytes, stream), + "cuMemcpyDtoDAsync (input -> A)"); + + int lwork = 0; + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::gesvd_bs(g)(solver.handle.get(), m, n, &lwork), + "cusolverDn?gesvd_bufferSize"); + + DeviceMem d_work(g); + PJRT_RETURN_IF_ERROR( + allocate_workspace(lwork, "cuMemAlloc (gesvd workspace)", d_work)); + + // jobu / jobvt are 'S' (reduced). They're typed as signed char in the + // cuSOLVER ABI; passing the literal char works because of integer + // promotion at the call site. + PJRT_RETURN_IF_GPU_ERROR( + CuSolver::gesvd(g)(solver.handle.get(), 'S', 'S', m, n, + reinterpret_cast(d_a.ptr), m, + reinterpret_cast(s_ptr), + reinterpret_cast(u_ptr), m, + reinterpret_cast(vt_ptr), n, + reinterpret_cast(d_work.ptr), lwork, + /*rwork=*/nullptr, + reinterpret_cast(solver.info.ptr)), + "cusolverDn?gesvd"); + + return Error::Success(); +} +#endif // _WIN32 + +static Error do_svd_cuda(void *stream, AnyBuffer input, + Result u_out, Result s_out, + Result vt_out) { +#ifdef _WIN32 + return Error(ErrorCode::kUnimplemented, + "CUDA SVD is not supported on Windows"); +#else + PJRT_DISPATCH_FLOAT(input.element_type(), svd_cuda_impl, stream, input, u_out, + s_out, vt_out); +#endif +} + +XLA_FFI_DEFINE_HANDLER(svd_handler_cuda, do_svd_cuda, + Ffi::Bind() + .Ctx>() + .Arg() + .Ret() + .Ret() + .Ret()); + +} // namespace rpjrt + +// [[Rcpp::export]] +SEXP get_svd_handler_cuda() { + return R_MakeExternalPtr((void *)rpjrt::svd_handler_cuda, R_NilValue, + R_NilValue); +} diff --git a/tests/testthat/helper-linalg.R b/tests/testthat/helper-linalg.R new file mode 100644 index 00000000..bef583b9 --- /dev/null +++ b/tests/testthat/helper-linalg.R @@ -0,0 +1,128 @@ +# Helpers for the linear-algebra custom_call tests in test-linalg.R. +# +# We test the FFI handlers end-to-end by hand-rolling a tiny StableHLO program +# (one custom_call op) for each test, compiling it, and executing it. This +# keeps pjrt's tests self-contained -- no dependency on the stablehlo R +# package or anvl. +# +# All custom_calls in this file pass `operand_layouts` and `result_layouts` +# in column-major order (`dense<[0, 1]>` for 2D, `dense<[0]>` for 1D). The +# LAPACK/cuSOLVER handlers expect column-major data, and XLA does the +# row-major <-> column-major conversion transparently when these layout +# attributes are set. + +mlir_dtype <- function(dtype) { + switch(dtype, f32 = "f32", f64 = "f64", i32 = "i32", + stop("unsupported dtype: ", dtype)) +} + +# Layout attribute string for an n-D column-major tensor: dense<[0, 1, ...]>. +col_major_layout <- function(ndim) { + dims <- paste(seq_len(ndim) - 1L, collapse = ", ") + sprintf("dense<[%s]> : tensor<%dxindex>", dims, ndim) +} + +# Build the MLIR text for a single-input, multi-output custom_call program. +# `out_specs` is a list of list(dims, dtype) for each result. +build_program <- function(target, in_dims, in_dtype, out_specs) { + in_type <- sprintf( + "tensor<%sx%s>", + paste(in_dims, collapse = "x"), + mlir_dtype(in_dtype) + ) + out_types <- vapply(out_specs, function(s) { + sprintf("tensor<%sx%s>", + paste(s$dims, collapse = "x"), + mlir_dtype(s$dtype)) + }, character(1)) + out_layouts <- vapply(out_specs, function(s) col_major_layout(length(s$dims)), + character(1)) + ret_names <- paste0("%out", seq_along(out_specs)) + + sprintf( + 'func.func @main(%%a: %s) -> (%s) { + %s = stablehlo.custom_call @%s(%%a) { + call_target_name = "%s", + api_version = 4 : i32, + operand_layouts = [%s], + result_layouts = [%s] + } : (%s) -> (%s) + func.return %s : %s +}', + in_type, + paste(out_types, collapse = ", "), + paste(ret_names, collapse = ", "), + target, target, + col_major_layout(length(in_dims)), + paste(out_layouts, collapse = ", "), + in_type, + paste(out_types, collapse = ", "), + paste(ret_names, collapse = ", "), + paste(out_types, collapse = ", ") + ) +} + +# Compile a custom_call program once, run on `a`, return the outputs as +# a list of base-R arrays/vectors. +# +# `dtype` is the dtype of the input AND of any floating-point output. Pass +# explicitly when you want to test the f32 path with R-double input data +# (R has no native f32). Defaults to f64 for double inputs, i32 for int. +run_linalg <- function(target, a, out_specs, dtype = NULL) { + in_dims <- dim(a) + if (is.null(in_dims)) in_dims <- length(a) + if (is.null(dtype)) { + dtype <- if (is.double(a)) "f64" else if (is.integer(a)) "i32" else "f32" + } + + program <- pjrt_program(build_program(target, in_dims, dtype, out_specs)) + program <- pjrt_compile(program) + outs <- pjrt_execute(program, pjrt_buffer(a, dtype = dtype)) + if (!is.list(outs)) outs <- list(outs) + lapply(outs, as_array) +} + +run_qr <- function(a) { + m <- nrow(a); n <- ncol(a); k <- min(m, n) + dtype <- if (is.double(a)) "f64" else "f32" + run_linalg("qr", a, list( + list(dims = c(m, k), dtype = dtype), + list(dims = c(k, n), dtype = dtype) + )) +} + +run_lu <- function(a) { + m <- nrow(a); n <- ncol(a); k <- min(m, n) + dtype <- if (is.double(a)) "f64" else "f32" + run_linalg("lu", a, list( + list(dims = c(m, n), dtype = dtype), + list(dims = k, dtype = "i32") + )) +} + +run_svd <- function(a) { + m <- nrow(a); n <- ncol(a); k <- min(m, n) + dtype <- if (is.double(a)) "f64" else "f32" + run_linalg("svd", a, list( + list(dims = c(m, k), dtype = dtype), + list(dims = k, dtype = dtype), + list(dims = c(k, n), dtype = dtype) + )) +} + +run_eigh <- function(a) { + n <- nrow(a) + dtype <- if (is.double(a)) "f64" else "f32" + run_linalg("eigh", a, list( + list(dims = c(n, n), dtype = dtype), + list(dims = n, dtype = dtype) + )) +} + +# Tolerance helpers. f32 LAPACK tends to round to ~1e-5; f64 to ~1e-10. +# We multiply by sqrt(min(m, n)) to give larger problems a bit more slack +# (roundoff accumulates with problem size). +linalg_tol <- function(a, scale = 1) { + base <- if (is.double(a)) 1e-10 else 1e-4 + scale * base * max(1, sqrt(min(dim(a)))) +} diff --git a/tests/testthat/test-linalg.R b/tests/testthat/test-linalg.R new file mode 100644 index 00000000..a1af2309 --- /dev/null +++ b/tests/testthat/test-linalg.R @@ -0,0 +1,385 @@ +# End-to-end property tests for the built-in LAPACK / cuSOLVER custom_call +# handlers (`qr`, `lu`, `svd`, `eigh`). +# +# Each handler is exercised through a small JIT-compiled stablehlo.custom_call +# program (see helper-linalg.R for the harness). We don't depend on anvl -- +# pjrt is the layer that owns these handlers, so the tests live here. +# +# Coverage strategy: a small set of "shape suites" generated for each op +# (square, tall, wide where the op supports it; 1x1; an identity matrix) +# crossed with f32 and f64. For each generated input we verify the +# defining property of the factorisation (reconstruction, orthogonality, +# triangularity, etc.) and -- where possible -- compare against base R's +# `qr()` / `solve()` / `svd()` / `eigen()`. + +skip_if_metal("linalg custom_calls are CPU/CUDA only") + +# Always test on the host (LAPACK) backend; the CUDA path is exercised when +# PJRT_PLATFORM=cuda is set in the environment. The CUDA SVD has an m >= n +# requirement (cuSOLVER's gesvd); we skip the wide-SVD tests on CUDA below. + +# --------------------------------------------------------------------------- +# QR +# --------------------------------------------------------------------------- + +qr_check <- function(a, label) { + res <- run_qr(a) + q <- res[[1L]]; r <- res[[2L]] + m <- nrow(a); n <- ncol(a); k <- min(m, n) + + expect_equal(dim(q), c(m, k), info = label) + expect_equal(dim(r), c(k, n), info = label) + + tol <- linalg_tol(a) + # Reconstruction + expect_equal(q %*% r, a, tolerance = tol, info = paste(label, "QR=A")) + # Q has orthonormal columns + expect_equal(t(q) %*% q, diag(k), tolerance = tol, + info = paste(label, "Q^T Q = I")) + # R is upper triangular: the leading k x k block has zeros below the diagonal. + if (k > 1L) { + R_block <- r[seq_len(k), seq_len(k), drop = FALSE] + expect_lt(max(abs(R_block[lower.tri(R_block)])), tol * 10, + label = paste(label, "R upper-triangular")) + } +} + +test_that("qr: square / tall / wide / 1x1 / identity (f64)", { + set.seed(1) + cases <- list( + list(label = "1x1", a = matrix(2.5, 1, 1)), + list(label = "2x2", a = matrix(c(1, 2, 3, 4), nrow = 2)), + list(label = "3x3 random", a = matrix(rnorm(9), nrow = 3)), + list(label = "5x5 random", a = matrix(rnorm(25), nrow = 5)), + list(label = "tall 7x3", a = matrix(rnorm(21), nrow = 7)), + list(label = "wide 3x7", a = matrix(rnorm(21), nrow = 3)), + list(label = "identity 4x4", a = diag(4)), + list(label = "20x20 random", a = matrix(rnorm(400), nrow = 20)) + ) + for (cs in cases) qr_check(cs$a, cs$label) +}) + +test_that("qr: f32 path", { + set.seed(2) + a <- matrix(rnorm(12), nrow = 4) + res <- run_linalg( + "qr", a, + list(list(dims = c(4, 3), dtype = "f32"), + list(dims = c(3, 3), dtype = "f32")), + dtype = "f32" + ) + q <- res[[1L]]; r <- res[[2L]] + expect_equal(q %*% r, a, tolerance = 1e-4) + expect_equal(t(q) %*% q, diag(3), tolerance = 1e-4) +}) + +test_that("qr: agrees with base::qr() up to column signs", { + # base::qr returns a packed factorization but we can extract Q, R via + # qr.Q / qr.R for direct comparison. Both LAPACK paths use the same + # algorithm, so the result is unique up to the signs of R's diagonal. + set.seed(3) + for (dims in list(c(5, 3), c(4, 4), c(3, 5))) { + a <- matrix(rnorm(prod(dims)), nrow = dims[1L]) + base_qr <- qr(a) + Q_base <- qr.Q(base_qr); R_base <- qr.R(base_qr) + res <- run_qr(a) + Q <- res[[1L]]; R <- res[[2L]] + # Resolve sign ambiguity: align signs of diagonals of R. + k <- min(dims) + sgn <- sign(diag(R)[seq_len(k)]) + sgn_base <- sign(diag(R_base)[seq_len(k)]) + flip <- sgn * sgn_base + R <- diag(flip, k) %*% R + Q <- Q %*% diag(flip, k) + expect_equal(R, R_base, tolerance = 1e-10) + expect_equal(Q, Q_base, tolerance = 1e-10) + } +}) + +test_that("qr: input buffer is not modified", { + a <- matrix(c(1, 2, 3, 4, 5, 6), nrow = 3) + a0 <- a + run_qr(a) + expect_identical(a, a0) +}) + +# --------------------------------------------------------------------------- +# LU +# --------------------------------------------------------------------------- + +apply_lu_pivots <- function(perm_init, pivots) { + perm <- perm_init + for (i in seq_along(pivots)) { + j <- pivots[[i]] + if (j != i) perm[c(i, j)] <- perm[c(j, i)] + } + perm +} + +lu_check <- function(a, label) { + res <- run_lu(a) + LU <- res[[1L]] + pivots <- as.integer(res[[2L]]) + m <- nrow(a); n <- ncol(a); k <- min(m, n) + + expect_equal(dim(LU), c(m, n), info = label) + expect_equal(length(pivots), k, info = label) + expect_true(all(pivots >= 1L & pivots <= m), + info = paste(label, "pivots in [1, m]")) + + # Reconstruct L (m x k, unit lower-triangular) and U (k x n, upper). + L <- matrix(0, nrow = m, ncol = k) + diag(L) <- 1 + for (j in seq_len(k)) { + if (j + 1L <= m) { + L[(j + 1L):m, j] <- LU[(j + 1L):m, j] + } + } + U <- matrix(0, nrow = k, ncol = n) + for (i in seq_len(k)) { + U[i, i:n] <- LU[i, i:n] + } + PA <- L %*% U + perm <- apply_lu_pivots(seq_len(m), pivots) + expect_equal(PA[order(perm), , drop = FALSE], a, + tolerance = linalg_tol(a), + info = paste(label, "P^-1 L U = A")) +} + +test_that("lu: square / tall / wide / 1x1 / identity / forced pivot", { + set.seed(11) + lu_check(matrix(2.5, 1, 1), "1x1") + lu_check(matrix(c(4, 3, 6, 3), nrow = 2), "2x2 no pivot") + # First row has zero in column 1, so getrf must swap. + lu_check(matrix(c(0, 1, 1, 1), nrow = 2), "2x2 forced pivot") + lu_check(matrix(rnorm(9), nrow = 3), "3x3 random") + lu_check(matrix(rnorm(25), nrow = 5), "5x5 random") + lu_check(matrix(rnorm(21), nrow = 7), "tall 7x3") + lu_check(matrix(rnorm(21), nrow = 3), "wide 3x7") + lu_check(diag(4), "identity 4x4") + lu_check(matrix(rnorm(900), nrow = 30), "30x30 random") +}) + +test_that("lu: f32 reconstruction", { + set.seed(12) + a <- matrix(rnorm(16), nrow = 4) + res <- run_linalg( + "lu", a, + list(list(dims = c(4, 4), dtype = "f32"), + list(dims = 4, dtype = "i32")), + dtype = "f32" + ) + LU <- res[[1L]] + pivots <- as.integer(res[[2L]]) + L <- diag(4) + L[lower.tri(L)] <- LU[lower.tri(LU)] + U <- LU + U[lower.tri(U)] <- 0 + perm <- apply_lu_pivots(seq_len(4), pivots) + expect_equal((L %*% U)[order(perm), ], a, tolerance = 1e-4) +}) + +test_that("lu: pivots are int32 (validates result_layouts on a non-float output)", { + res <- run_lu(matrix(c(0, 1, 1, 1), nrow = 2)) + expect_true(is.integer(as.vector(res[[2L]]))) +}) + +# --------------------------------------------------------------------------- +# SVD +# --------------------------------------------------------------------------- + +svd_check <- function(a, label) { + res <- run_svd(a) + U <- res[[1L]]; S <- as.numeric(res[[2L]]); Vt <- res[[3L]] + m <- nrow(a); n <- ncol(a); k <- min(m, n) + + expect_equal(dim(U), c(m, k), info = label) + expect_equal(length(S), k, info = label) + expect_equal(dim(Vt), c(k, n), info = label) + + tol <- linalg_tol(a) + # Singular values: non-negative, descending + expect_true(all(S >= -tol), + info = paste(label, "S >= 0")) + if (length(S) > 1L) { + expect_true(all(diff(S) <= tol), + info = paste(label, "S descending")) + } + + # U has orthonormal columns + expect_equal(t(U) %*% U, diag(k), tolerance = tol, + info = paste(label, "U^T U = I")) + # V (= t(Vt)) has orthonormal columns + expect_equal(Vt %*% t(Vt), diag(k), tolerance = tol, + info = paste(label, "V^T V = I")) + # Reconstruction + Sd <- if (length(S) == 1L) matrix(S) else diag(S) + expect_equal(U %*% Sd %*% Vt, a, tolerance = tol, + info = paste(label, "U S Vt = A")) + + # Singular values match base::svd + base_S <- svd(a)$d + expect_equal(S, base_S, tolerance = tol, + info = paste(label, "S matches base::svd")) +} + +test_that("svd: square / tall / wide / 1x1 / identity (host)", { + skip_if(is_cuda(), "wide-SVD on CUDA is unsupported (cuSOLVER gesvd: m >= n)") + set.seed(21) + svd_check(matrix(2.5, 1, 1), "1x1") + svd_check(matrix(c(1, 0, 0, 1), nrow = 2), "identity 2x2") + svd_check(matrix(rnorm(9), nrow = 3), "3x3 random") + svd_check(matrix(rnorm(25), nrow = 5), "5x5 random") + svd_check(matrix(rnorm(21), nrow = 7), "tall 7x3") + svd_check(matrix(rnorm(21), nrow = 3), "wide 3x7 (host)") + svd_check(matrix(rnorm(400), nrow = 20), "20x20 random") +}) + +test_that("svd: square / tall / 1x1 (CUDA - no wide)", { + skip_if(!is_cuda()) + set.seed(22) + svd_check(matrix(2.5, 1, 1), "1x1") + svd_check(matrix(rnorm(9), nrow = 3), "3x3 random") + svd_check(matrix(rnorm(21), nrow = 7), "tall 7x3") +}) + +test_that("svd: CUDA rejects m < n", { + skip_if(!is_cuda()) + a <- matrix(rnorm(6), nrow = 2) # 2x3, m < n + expect_error(run_svd(a), "m >= n") +}) + +test_that("svd: f32 reconstruction", { + set.seed(23) + a <- matrix(rnorm(20), nrow = 5) + res <- run_linalg( + "svd", a, + list(list(dims = c(5, 4), dtype = "f32"), + list(dims = 4, dtype = "f32"), + list(dims = c(4, 4), dtype = "f32")), + dtype = "f32" + ) + U <- res[[1L]]; S <- as.numeric(res[[2L]]); Vt <- res[[3L]] + expect_equal(U %*% diag(S) %*% Vt, a, tolerance = 1e-4) +}) + +test_that("svd: ill-conditioned input still factorises correctly", { + # Diagonal matrix with widely varying singular values. + set.seed(24) + s_true <- 10 ^ seq(0, -10, length.out = 5) # 1, 1e-2.5, ..., 1e-10 + a <- diag(s_true) + res <- run_svd(a) + S <- as.numeric(res[[2L]]) + # Check the singular values come back in the right order, with high + # relative accuracy on the dominant ones. + expect_equal(S, sort(s_true, decreasing = TRUE), tolerance = 1e-9) +}) + +# --------------------------------------------------------------------------- +# eigh +# --------------------------------------------------------------------------- + +random_symmetric <- function(n) { + m <- matrix(rnorm(n * n), n, n) + (m + t(m)) / 2 +} +random_spd <- function(n) { + m <- matrix(rnorm(n * n), n, n) + m %*% t(m) + diag(n) * 0.5 # add small ridge +} + +eigh_check <- function(a, label, ascending = TRUE) { + res <- run_eigh(a) + V <- res[[1L]]; W <- as.numeric(res[[2L]]) + n <- nrow(a) + + expect_equal(dim(V), c(n, n), info = label) + expect_equal(length(W), n, info = label) + + tol <- linalg_tol(a, scale = 5) + # Eigenvalues sorted ascending. + if (n > 1L && ascending) { + expect_true(all(diff(W) >= -tol), + info = paste(label, "W ascending")) + } + # Orthonormal eigenvectors. + expect_equal(t(V) %*% V, diag(n), tolerance = tol, + info = paste(label, "V^T V = I")) + # Reconstruction (symmetric A; we only fed the lower triangle but the + # input was symmetric so we can compare against the full matrix). + Wd <- if (n == 1L) matrix(W) else diag(W) + expect_equal(V %*% Wd %*% t(V), a, tolerance = tol, + info = paste(label, "V W V^T = A")) +} + +test_that("eigh: 1x1 / 2x2 / random symmetric of varying sizes", { + set.seed(31) + eigh_check(matrix(2.5, 1, 1), "1x1") + eigh_check(matrix(c(2, 1, 1, 2), nrow = 2), "2x2 known") + eigh_check(random_symmetric(5), "5x5 symmetric") + eigh_check(random_spd(5), "5x5 SPD") + eigh_check(random_symmetric(10), "10x10 symmetric") + eigh_check(random_spd(20), "20x20 SPD") + # Identity: all eigenvalues 1, eigenvectors are any orthonormal basis. + res <- run_eigh(diag(4)) + expect_equal(as.numeric(res[[2L]]), c(1, 1, 1, 1), tolerance = 1e-12) +}) + +test_that("eigh: 2x2 with known eigenvalues", { + # [[2, 1], [1, 2]] has eigenvalues {1, 3} (ascending). + res <- run_eigh(matrix(c(2, 1, 1, 2), nrow = 2)) + expect_equal(as.numeric(res[[2L]]), c(1, 3), tolerance = 1e-12) +}) + +test_that("eigh: matches base::eigen on symmetric input", { + set.seed(32) + for (n in c(3L, 6L, 10L)) { + a <- random_symmetric(n) + base_e <- eigen(a, symmetric = TRUE) # base::eigen returns descending + res <- run_eigh(a) + W <- as.numeric(res[[2L]]) + # Reverse W to match base::eigen ordering. + expect_equal(rev(W), base_e$values, tolerance = 1e-10, + info = paste0("n=", n)) + } +}) + +test_that("eigh: f32 reconstruction", { + set.seed(33) + a <- random_spd(6) + res <- run_linalg( + "eigh", a, + list(list(dims = c(6, 6), dtype = "f32"), + list(dims = 6, dtype = "f32")), + dtype = "f32" + ) + V <- res[[1L]]; W <- as.numeric(res[[2L]]) + expect_equal(V %*% diag(W) %*% t(V), a, tolerance = 1e-4) +}) + +test_that("eigh: rejects non-square input", { + expect_error(run_eigh(matrix(rnorm(6), nrow = 2)), "square") +}) + +# --------------------------------------------------------------------------- +# Cross-cutting concerns +# --------------------------------------------------------------------------- + +test_that("all four built-in linalg handlers are registered", { + registered <- names(the[["custom_calls"]]) + expect_setequal(intersect(c("qr", "lu", "svd", "eigh"), registered), + c("qr", "lu", "svd", "eigh")) +}) + +test_that("repeated calls reuse cuSOLVER handles (no leak / no recreate)", { + # On the host this is essentially a smoke test; on CUDA it exercises the + # SolverHandlePool's borrow-and-return path many times in a row. If the + # pool were broken (handle dropped, double-released, etc.) we'd see + # cuSOLVER status errors before long. + set.seed(99) + a <- matrix(rnorm(25), nrow = 5) + for (i in 1:20) { + res <- run_qr(a) + expect_equal(res[[1L]] %*% res[[2L]], a, tolerance = 1e-10) + } +}) diff --git a/vignettes/articles/custom-calls-lapack-cusolver.Rmd b/vignettes/articles/custom-calls-lapack-cusolver.Rmd new file mode 100644 index 00000000..28ff8b4d --- /dev/null +++ b/vignettes/articles/custom-calls-lapack-cusolver.Rmd @@ -0,0 +1,589 @@ +--- +title: "Adding custom calls backed by LAPACK and cuSOLVER" +output: rmarkdown::html_vignette +vignette: > + %\VignetteIndexEntry{Adding custom calls backed by LAPACK and cuSOLVER} + %\VignetteEngine{knitr::rmarkdown} + %\VignetteEncoding{UTF-8} +--- + + +Some operations -- QR, LU, Cholesky, SVD, eigendecomposition -- cannot be +expressed directly in StableHLO. pjrt provides them as **built-in custom +calls**: C++ FFI handlers registered with the runtime under well-known +target names (`qr`, `lu`, `svd`, `eigh`). Downstream packages (anvl, future +bindings) invoke them via `stablehlo.custom_call @(...)` without +having to ship their own LAPACK linkage. + +This document describes how to add a new built-in custom call with two +backends: a host (LAPACK) implementation and a CUDA (cuSOLVER) +implementation. The existing implementations (`src/qr.cpp`+`qr_cuda.cpp`, +`src/lu.cpp`+`lu_cuda.cpp`, `src/svd.cpp`+`svd_cuda.cpp`, +`src/eigh.cpp`+`eigh_cuda.cpp`) plus the shared FFI kit +(`src/ffi_common.h`, `src/ffi_lapack.h`, `src/ffi_cusolver.{h,cpp}`) are +the reference; this doc generalises them. + +--- + +## 1. End-to-end wiring + +``` + downstream package (e.g. anvl) + │ + │ emits stablehlo.custom_call @foo(operand) {...} + ▼ + StableHLO program with a custom_call op + │ + │ XLA compiler keeps custom_call as-is, links it to a handler + ▼ + PJRT runtime + │ + │ looks up "foo" in the registered-handlers table + ▼ + FFI handler (XLA_FFI_DEFINE_HANDLER) in src/foo.cpp / src/foo_cuda.cpp + │ + │ calls LAPACK or cuSOLVER, writes outputs to provided buffers + ▼ + Outputs returned to the caller +``` + +Handlers are registered from pjrt's own `R/zzz.R` (alongside `print_tensor`) +via `pjrt_register_custom_call` with a per-platform map (`host`, `cuda`). +The `call_target_name` used by the downstream package (e.g. `"qr"`) must +match the name passed to the registration call. + +--- + +## 2. Files you will touch + +For a new built-in custom call `foo`: + +| File | What goes here | +|---|---| +| `src/ffi_lapack.h` | LAPACK extern decl for the new routine + `Lapack::foo` method on each precision specialisation | +| `src/foo.cpp` | Host (LAPACK) FFI handler + `[[Rcpp::export]] get_foo_handler()` | +| `src/ffi_cusolver.h` | New cuSOLVER function pointers in `GpuLibs` + `CuSolver::foo` accessor | +| `src/ffi_cusolver.cpp` | Load the new symbols inside `get_gpu_libs()` | +| `src/foo_cuda.cpp` | CUDA (cuSOLVER) FFI handler + `[[Rcpp::export]] get_foo_handler_cuda()` | +| `R/zzz.R` | Register the handler pair with PJRT inside `.onLoad` | +| `tests/testthat/test-linalg.R` | Property tests against base R / `solve()` / `eigen()` etc. | + +The shared FFI kit -- `src/ffi_common.h`, `src/ffi_lapack.h`, +`src/ffi_cusolver.{h,cpp}` -- holds the pieces every kernel uses (status +macros, dtype dispatch, `dim_to_int`, the LAPACK promotion trait, the +cuSOLVER handle pool and `DeviceMem` RAII). See section 7 below. + +`src/Makevars.in` adds `$(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)` to the link +line; this works on Linux/macOS/Windows out of the box. The CUDA path uses +`dlopen` and needs no extra link flags. + +Don't forget to run `Rscript -e 'Rcpp::compileAttributes(".")'` after +adding new `[[Rcpp::export]]` getters. + +--- + +## 3. The FFI handler signature + +XLA's FFI binds typed contexts and buffers. A typical handler with one input +matrix and two outputs looks like this: + +```c++ +#include "xla/ffi/api/ffi.h" +using namespace xla::ffi; + +static Error do_foo(AnyBuffer input, + Result out_a, + Result out_b) { + // ... + return Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER(foo_handler, do_foo, + Ffi::Bind() + .Arg() // input + .Ret() // first output + .Ret()); // second output + +extern "C" { +void *get_foo_handler(void) { return (void *)foo_handler; } +} +``` + +For CUDA you also bind the platform stream so cuSOLVER can be told where to +launch: + +```c++ +XLA_FFI_DEFINE_HANDLER(foo_handler_cuda, do_foo_cuda, + Ffi::Bind() + .Ctx>() + .Arg() + .Ret() + .Ret()); +``` + +`AnyBuffer` exposes `untyped_data()`, `element_type()`, `dimensions()`, and +`size_bytes()`. `Result` wraps an output buffer; dereference with +`(*out).untyped_data()`. + +--- + +## 4. Buffer layout: column-major + +LAPACK and cuSOLVER both use **column-major** (Fortran) layout. XLA defaults +to row-major, so the StableHLO rule must override layouts on the +`hlo_custom_call`: + +```r +stablehlo::hlo_custom_call( + operand, + call_target_name = "foo", + api_version = 4L, + output_types = list(q_type, r_type), + operand_layouts = list(col_major_layout(2L)), + result_layouts = list(col_major_layout(2L), col_major_layout(2L)) +) +``` + +`col_major_layout(ndim)` returns `c(0L, 1L, ..., ndim-1L)` (minor-to-major +ordering). With these layouts in place, XLA materialises the buffers in +column-major order on entry to the handler, and reads them column-major on +exit. **The handler does not transpose**; it reads/writes directly in +LAPACK-native layout. + +If you forget the layout overrides, the data will silently appear transposed +to LAPACK and you will get wrong answers (no error). + +--- + +## 5. The host (LAPACK) backend + +### Execution model + +LAPACK runs synchronously on the host. The handler: + +1. Reads dimensions from `input.dimensions()` (an `int64_t` span). +2. Copies the input to a writable working buffer (LAPACK overwrites in + place). +3. Calls the workspace-query form of the routine to learn `lwork`. +4. Allocates `work` of size `lwork`. +5. Calls the routine for real. +6. Copies/zeroes the relevant slices into the output buffers. + +There is no async / no streams. The handler returns when the work is done. + +### LAPACK calling convention + +LAPACK is Fortran. Every argument is by pointer, names are lowercased with a +trailing underscore, and the leading dimension `lda` must be passed +explicitly: + +```c++ +extern "C" { +void dgeqrf_(const int *m, const int *n, double *a, const int *lda, + double *tau, double *work, const int *lwork, int *info); +} +``` + +The workspace-query idiom: + +```c++ +int lwork = -1; // sentinel: don't compute, just return optimal lwork +double work_size; +int info; +dgeqrf_(&m, &n, a.data(), &m, tau.data(), &work_size, &lwork, &info); +if (info != 0) return Error::Internal("workspace query failed"); + +lwork = static_cast(work_size); +std::vector work(lwork); +dgeqrf_(&m, &n, a.data(), &m, tau.data(), work.data(), &lwork, &info); +if (info != 0) return Error::Internal("geqrf failed"); +``` + +Always check `info` after every call. `info < 0` means an illegal argument +at position `|info|`; `info > 0` is routine-specific (e.g. for `dpotrf`, +the leading minor that's not positive definite). + +### Windows: f32 promotion + +R on Windows ships its own `Rlapack.dll`, which contains **only +double-precision** routines. There is no `sgeqrf_`, `sorgqr_`, etc. On +macOS and Linux, R links against system LAPACK (Accelerate, OpenBLAS, MKL) +which has both precisions. + +The `Lapack` trait in `src/ffi_lapack.h` absorbs this asymmetry. Each +specialisation exposes a typedef `::S` that is the precision actually +handed to the LAPACK call: + +| target | `Lapack::S` | `Lapack::S` | +|---|---|---| +| Linux / macOS | `float` | `double` | +| Windows | `double` | `double` | + +A kernel writes: + +```c++ +using S = typename Lapack::S; +std::vector a(...); // promoted copy of input +Lapack::geqrf(&m, &n, a.data(), ...); +// ... cast outputs back to T at the boundary +``` + +and gets the right behaviour on both platforms with no `#ifdef` in the +kernel body. See `src/qr.cpp` for the full pattern. + +### What can go wrong (host) + +| Failure mode | Mitigation | +|---|---| +| Forgot `operand_layouts` / `result_layouts` -> data appears transposed | Always set both in the StableHLO rule. | +| Skipped workspace query -> `info = -7` (illegal lwork) | Always do the two-call query/run pattern. | +| Wrote results to a stack buffer and `memcpy`'d -> stack overflow on big inputs | Use `std::vector` (heap). | +| `info` not checked -> garbage outputs, no error | Check `info` after every LAPACK call. | +| Used `int64_t` for `m`/`n` -> mismatch with LAPACK's `int` | LAPACK takes `const int *`; cast and check overflow. | +| f32 on Windows -> link error, undefined symbol | Use the `Lapack::S` trait; never call `s*_` directly outside `ffi_lapack.h`. | +| Aliased input/output buffers -> garbled output | Always copy input to a working buffer (LAPACK overwrites in place). | + +--- + +## 6. The CUDA (cuSOLVER) backend + +### Execution model + +cuSOLVER runs **asynchronously** on a CUDA stream supplied by XLA. The +handler: + +1. Resolves the stream from `Ctx>()`. +2. Borrows a stream-bound cuSOLVER handle from a pool. +3. Allocates device memory for working buffers (input copy, `tau`, + `devInfo`, workspace). +4. Issues async D2D memcpys / kernel launches on the stream. +5. Returns immediately. The caller's stream synchronisation makes the work + visible. + +Inputs and outputs are already on the GPU; the handler only sees device +pointers (cast `untyped_data()` to `CUdeviceptr`, an `unsigned long long`). +Do not synchronise. Do not memcpy to/from host unless you have to. + +### Why dlopen instead of linking against CUDA + +`src/ffi_cusolver.cpp` uses `dlopen("libcusolver.so", ...)` and looks up +symbols with `dlsym`. The reasons: + +- The R package can be built without the CUDA SDK headers being present. +- Users who don't have CUDA installed get a clean fallback (the handler + registration in `R/zzz.R` returns `NULL` and CUDA execution silently + isn't available) instead of a hard load-time link error. +- The cuSOLVER SONAME varies by CUDA version (`.so.11` for CUDA 11+12 at + the time of writing). Probe a list of candidates. + +The trade-off: the function-pointer table (`GpuLibs` in `ffi_cusolver.h`) +has to mirror the cuSOLVER and CUDA driver APIs by hand, and an opaque +`CUdeviceptr` typedef stands in for the real one. New ops add their entries +to `GpuLibs` and to the loader in `ffi_cusolver.cpp::get_gpu_libs()`, then +expose them through a `CuSolver` trait specialisation. + +### The cuSOLVER handle is stateful + +A `cusolverDnHandle_t` is bound to a stream by `cusolverDnSetStream`. If +two FFI calls share a handle but issue work on different streams, the +second call's `cusolverDnSetStream` rebinds the handle while the first +call's launches are still in flight, and they may end up on the wrong +stream. **A single shared handle is unsafe under concurrent execution**. + +The fix, mirroring `jaxlib/gpu/solver_handle_pool.cc`: keep a +mutex-guarded free-list of handles **per stream**. Borrow a handle (create +one if the list is empty), call `SetStream`, use it, return it on scope +exit. Because the handle is keyed by stream, the bind is idempotent and +no race exists. See `ffi_cusolver.cpp::SolverHandlePool` and +`borrow_solver_handle`. Kernels just declare a `HandleGuard handle;` and +call `borrow_solver_handle(g, stream, handle)`. + +### devInfo lives on the device + +cuSOLVER routines take a `devInfo` pointer that must be in device memory. +You must allocate `sizeof(int)` of device space and pass it. The value +itself only flags illegal arguments (mirroring LAPACK's `info < 0` case), +which the handler already validates up front via dimension checks. + +`jaxlib`'s `GeqrfImpl`/`OrgqrImpl` allocate the buffer and never read its +value back. We do the same: don't pay for a D2H copy + sync just to check +something we already validated. Keep the allocation, skip the read-back. + +If you implement a routine where `devInfo` carries information beyond +illegal-argument (e.g. `geqrfBatched`'s per-batch status, or `getrf`'s +zero-pivot row), you do need to read it back -- and that requires a D2H +memcpy + stream sync. `lu_cuda.cpp` deliberately does NOT do this; a +singular factor surfaces as numerical garbage downstream rather than a +launch-time error, matching jaxlib's getrf path. + +### Status codes: check every call + +Every CUDA driver call (`cuMemAlloc`, `cuMemcpyDtoDAsync`, `cuMemsetD8Async`, +`cuStreamSynchronize`) and every cuSOLVER call (including the workspace +size queries and `cusolverDnSetStream`) returns an `int` status. Non-zero +means failure. **Drop none of them.** + +The pattern in `ffi_cusolver.h` (used by every kernel): + +```c++ +#define PJRT_RETURN_IF_GPU_ERROR(expr, what) \ + do { \ + int _status = (expr); \ + if (_status != 0) { \ + return Error::Internal(std::string(what) + " failed with status = " \ + + std::to_string(_status)); \ + } \ + } while (0) +``` + +A failed `cuMemAlloc` whose status is dropped leaves the device pointer at +0; the next memcpy then writes through a null device pointer and corrupts +arbitrary device state. There is no debug message, just garbage results. + +### Workspace queries + +cuSOLVER routines have a `_bufferSize` companion (`cusolverDnSgeqrf_bufferSize`, +etc.) that returns the optimal `lwork` for given dimensions. Same pattern +as LAPACK: query, allocate, run. + +If you call `geqrf` and then `orgqr` on the same problem, their workspace +sizes differ. Query both. If `orgqr_lwork <= geqrf_lwork`, reuse the +buffer; otherwise allocate a second one. See `qr_cuda.cpp` for the +two-workspace pattern. + +### Integer overflow on size math + +Dimensions arrive as `int64_t` from `input.dimensions()`, and cuSOLVER's +API takes `int`. For a 50000x50000 matrix, byte offsets like `j * m * +sizeof(T)` overflow `int` long before being widened to `size_t`. + +Two guards, mirroring `jaxlib`'s `MaybeCastNoOverflow`: + +```c++ +static Error dim_to_int(int64_t v, const char *name, int &out) { + if (v < 0 || v > std::numeric_limits::max()) { + return Error::InvalidArgument(std::string(name) + " out of int range"); + } + out = static_cast(v); + return Error::Success(); +} +``` + +And `static_cast(...)` on at least one operand of every byte-size +computation: + +```c++ +size_t a_bytes = static_cast(m) * n * sizeof(T); +``` + +### What can go wrong (CUDA) + +| Failure mode | Mitigation | +|---|---| +| Status code dropped -> silent garbage outputs | `PJRT_RETURN_IF_GPU_ERROR` on every call. | +| Shared handle + per-call `SetStream` -> race when called from multiple streams | Per-stream handle pool with mutex. | +| `int` overflow on `m * n * sizeof(T)` for large matrices | Cast to `size_t` before multiplying. | +| `int64_t` -> `int` truncation on dimension cast | `dim_to_int` with bounds check. | +| Read `devInfo` D2H without a sync -> get the value before the kernel ran | Either don't read it, or `cuStreamSynchronize` first. | +| `cuMemAlloc(0)` from `lwork = 0` | Guard or accept the allocation error. | +| Handler synchronises the stream -> kills async pipelining | Don't call `cuStreamSynchronize` unless you need a host-visible result. | +| Forgot `Ctx>()` in the FFI bind -> no stream available | Add it; cuSOLVER on the default stream serialises with everything else. | +| `dlopen("libcusolver.so")` fails on a runtime-only CUDA install | Fall back through SONAME variants (`.so.11`, etc.) and degrade gracefully. | +| In-place aliasing of input and output buffers | Always copy input to a working device buffer before factorising. | + +--- + +## 7. The shared FFI kit + +Most of the boilerplate above is in shared headers, so a new linalg op is +~80 lines of host kernel + ~80 lines of CUDA kernel rather than the ~250 + +~400 the original QR implementation needed. + +| Header | Owns | +|---|---| +| `src/ffi_common.h` | `PJRT_RETURN_IF_ERROR`, `PJRT_DISPATCH_FLOAT`, `dim_to_int` | +| `src/ffi_lapack.h` | LAPACK Fortran extern decls (geqrf/orgqr/getrf/gesdd/syevd, plus `s*_` on non-Windows). `Lapack` trait whose `::S` is the LAPACK storage type (Windows: always `double`; elsewhere: matches `T`). `lapack_check_info`. | +| `src/ffi_cusolver.h` | `CUdeviceptr` opaque typedef, `PJRT_RETURN_IF_GPU_ERROR`, `GpuLibs` (function-pointer table), `DeviceMem` RAII, `HandleGuard` (per-stream pool borrow), `borrow_solver_handle`, `CuSolver` dispatch trait. | +| `src/ffi_cusolver.cpp`| `get_gpu_libs()` (singleton dlopen loader) and the SolverHandlePool (mutex-guarded free-list keyed by stream). Shared across all CUDA kernels. | + +### What an op looks like with the kit + +The QR host kernel collapses to one template: + +```cpp +template +static Error qr_impl(AnyBuffer input, Result q_out, + Result r_out) { + using S = typename Lapack::S; // f32 on Linux/macOS, double on Windows + // ...dim_to_int, copy input -> std::vector, workspace query, geqrf, orgqr... + Lapack::geqrf(&m, &n, a.data(), &m, tau.data(), work.data(), &lwork, &info); + PJRT_RETURN_IF_ERROR(lapack_check_info(info, "geqrf")); + // ... +} +static Error do_qr(AnyBuffer input, Result q_out, Result r_out) { + PJRT_DISPATCH_FLOAT(input.element_type(), qr_impl, input, q_out, r_out); +} +``` + +No `#ifdef _WIN32` in the kernel body; the trait absorbs it. + +The CUDA kernel is similarly slimmer because `GpuLibs`, `DeviceMem`, +`HandleGuard`, and the pool live in `ffi_cusolver.cpp`: + +```cpp +template +static Error qr_cuda_impl(void *stream, AnyBuffer input, ...) { + auto &g = get_gpu_libs(); + if (!g.loaded) return Error::Internal("CUDA/cuSOLVER libraries not available"); + HandleGuard handle; + PJRT_RETURN_IF_ERROR(borrow_solver_handle(g, stream, handle)); + DeviceMem d_a(g), d_tau(g), d_info(g), d_work(g); + // ...alloc, memcpy_dtod, CuSolver::geqrf(g)(handle.get(), ...) wrapped in + // PJRT_RETURN_IF_GPU_ERROR... +} +``` + +--- + +## 8. Per-op cheatsheet + +### LU (`?getrf` / `cusolverDn?getrf`) + +- No LAPACK workspace argument (pivoting is in-place); cuSOLVER does have a + bufferSize companion. +- Outputs: `LU` (m, n) packed (L strict-lower with implicit unit diagonal, + U on/above), `pivots` (k = min(m, n)) **int32**, 1-based row swaps. +- `info > 0` from getrf means a pivot was zero (singular). We surface that + as an Internal error on the host. On CUDA `devInfo` lives in device + memory and we do not read it back -- a singular factor degrades + numerically downstream rather than raising a stream-time error. (Matches + jaxlib's `LuDecomposition` behaviour.) +- Reference: `src/lu.cpp`, `src/lu_cuda.cpp`. + +### SVD (`?gesdd` host / `cusolverDn?gesvd` CUDA) + +- Host uses `gesdd` (divide-and-conquer): faster than `gesvd` for medium + and large matrices, what jaxlib uses too. Needs `iwork` of size `8*k` + alongside the real workspace. +- `jobz = 'S'`: reduced ("economy") SVD. Outputs are `U (m, k)`, `S (k,)`, + `Vt (k, n)` with `k = min(m, n)`. +- **CUDA limitation: cuSOLVER's `gesvd` requires `m >= n`.** The wide case + (`m < n`) returns `InvalidArgument` -- the user can SVD the transpose + and swap U <-> V. Jax has the same restriction unless you switch to + `gesvdj`/`gesvdr`. The host backend handles any shape. +- `cusolverDn?gesvd_bufferSize` does **not** take `jobu`/`jobvt` -- the + reported workspace is the worst case. Just call it with `(handle, m, n, + &lwork)`. +- `rwork` is unused for real precisions: pass `nullptr`. +- `jobu` / `jobvt` are `signed char` in the cuSOLVER ABI; passing the + literal `'S'` works thanks to integer promotion at the call site. +- Reference: `src/svd.cpp`, `src/svd_cuda.cpp`. + +### Symmetric eigh (`?syevd` / `cusolverDn?syevd`) + +- We always compute eigenvectors (`jobz = 'V'`/`= 1`) and read the lower + triangle (`uplo = 'L'`/`= 0`). For values-only the user can drop V. +- LAPACK `syevd` needs **two** workspaces: the real `work` and an + integer `iwork`. Both are queried with `lwork = -1` / `liwork = -1`, + reading from `work[0]` / `iwork[0]`. Easy to forget the second one. +- cuSOLVER takes `jobz` and `uplo` as `int` enums (`1` and `0` for our + choice) rather than chars. +- Outputs: `V (n, n)` (eigenvectors as columns, orthonormal), `W (n,)` + (eigenvalues in **ascending** order). Note that R's `eigen()` returns + descending order -- if you want to match it, reverse W and the columns + of V. +- Reference: `src/eigh.cpp`, `src/eigh_cuda.cpp`. + +--- + +## 9. Putting it together: a recipe (pjrt side) + +To add a new built-in factorisation `foo`: + +1. **Add LAPACK extern decls** to `src/ffi_lapack.h` (both `d` and `s`, the + latter under `#ifndef _WIN32`) and a `Lapack::foo` method on each + specialisation. No `#ifdef` in the per-op kernel afterwards. + +2. **Write the host kernel** in `src/foo.cpp`: + - `template static Error foo_impl(...)` using + `using S = typename Lapack::S;`. + - Copy input -> `std::vector`, workspace query (if applicable), run, + extract outputs. `PJRT_RETURN_IF_ERROR(lapack_check_info(...))` after + each LAPACK call. + - `static Error do_foo(...) { PJRT_DISPATCH_FLOAT(et, foo_impl, ...); }`. + - `XLA_FFI_DEFINE_HANDLER(foo_handler, do_foo, ...)` inside + `namespace rpjrt`. + - At the bottom of the file, outside `namespace rpjrt`, expose + `// [[Rcpp::export]] SEXP get_foo_handler() { ... }`. + +3. **Add cuSOLVER function pointers** to `GpuLibs` in + `src/ffi_cusolver.h`, then load them in `src/ffi_cusolver.cpp::get_gpu_libs()`. + Add a `CuSolver::foo()` method on each specialisation. + +4. **Write the CUDA kernel** in `src/foo_cuda.cpp`: + - Implementation body under `#ifndef _WIN32`. + - The handler symbol (`foo_handler_cuda`) is **always** defined; the + `do_foo_cuda` dispatcher returns `Error::Unimplemented` on Windows. + This keeps a single `[[Rcpp::export]] get_foo_handler_cuda()` valid + on every platform without `#ifdef`s around the export. + - Borrow a handle: `HandleGuard handle; borrow_solver_handle(g, stream, handle);` + - `DeviceMem` for working buffers + `devInfo`. Don't read `devInfo` back + unless your op uses values beyond illegal-argument flagging. + - Workspace bufferSize, then `d_work.alloc(...)`, then the call. + - Every CUDA / cuSOLVER call wrapped in `PJRT_RETURN_IF_GPU_ERROR`. + - Bind the FFI with `.Ctx>()`. + +5. **Regenerate Rcpp wrappers**: `Rscript -e 'Rcpp::compileAttributes(".")'`. + +6. **Register handlers** in `R/zzz.R` -- the `register_linalg_handler()` + helper wires up host + (optional) CUDA in one line: + ```r + register_linalg_handler("foo", get_foo_handler(), get_foo_handler_cuda()) + ``` + +7. **Tests** in `tests/testthat/test-linalg.R`. We test the FFI handlers + end-to-end through a small JIT-compiled `stablehlo.custom_call` program + built directly in pjrt -- there is no need to depend on anvl. See the + existing tests for the harness and the property-test patterns + (reconstruction, orthogonality, comparison against base R). + +### On the downstream (anvl) side + +Once the handler is registered in pjrt, anvl exposes it by: + +- adding a `p_foo`/`nvl_foo` primitive in `R/primitives.R` whose `infer_fn` + returns the output `AbstractArray`s, +- writing a `p_foo[["stablehlo"]]` rule in `R/rules-stablehlo.R` that + emits `stablehlo::hlo_custom_call(call_target_name = "foo", ...)` with + the right column-major operand/result layouts, +- adding a `nv_foo` user-facing wrapper in `R/api.R`, +- adding a reverse-mode rule in `R/rules-reverse.R` if differentiable, +- adding `nv_foo` to the Linear Algebra section of `_pkgdown.yml`. + +The pjrt registration call name and the StableHLO `call_target_name` must +match exactly. + +--- + +## 10. Reading list + +- `src/ffi_common.h`, `src/ffi_lapack.h`, `src/ffi_cusolver.h{,.cpp}` -- + the shared kit. Read these first. +- `src/qr.cpp` / `src/qr_cuda.cpp` -- reference host + CUDA implementations + using the kit. +- `src/lu.cpp` / `src/lu_cuda.cpp` -- single-call op with int32 secondary + output. +- `src/svd.cpp` / `src/svd_cuda.cpp` -- multi-output op; CUDA m>=n + restriction. +- `src/eigh.cpp` / `src/eigh_cuda.cpp` -- two-workspace LAPACK call, + `int`-enum cuSOLVER. +- `R/rules-stablehlo.R` (search for `p_qr`) -- example custom_call rule + with column-major layouts. +- `R/zzz.R` -- handler registration with platform map. +- `jaxlib/gpu/solver_kernels_ffi.cc` -- JAX's reference for cuSOLVER FFI + patterns; we mirror its `JAX_FFI_RETURN_IF_GPU_ERROR`, + `MaybeCastNoOverflow`, and `SolverHandlePool` idioms. +- `jaxlib/cpu/lapack_kernels.cc` -- JAX's reference for the host + kernel/trait pattern. Their `Lapack` carries function pointers + bound at startup; we use static `static void foo(...)` methods because + the symbol set is fixed at compile time. From f53047d7a098df65817d4f7a882969bca3b5bcf6 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Sat, 2 May 2026 08:10:16 +0000 Subject: [PATCH 2/2] refactor: route CUDA linalg workspace through ffi::ScratchAllocator Replace per-call cuMemAlloc/cuMemFree (via dlopen'd CUDA driver) with xla::ffi::ScratchAllocator for all device-side workspace, working copies, and devInfo. Allocations now come from XLA's BFC pool -- pooled across calls and visible to XLA's memory accounting -- matching jaxlib's solver_kernels_ffi.cc pattern. ffi_cusolver.h/.cpp: drop DeviceMem RAII; drop mem_alloc/mem_free/ stream_sync from GpuLibs (memcpy_dtod and memset_d8 stay). Solver::begin takes ScratchAllocator& and stores devInfo as int*. allocate_workspace helper now wraps scratch.Allocate and translates nullopt to Error. Each *_cuda.cpp: handler binding adds .Ctx() after .Ctx>(); do_*_cuda takes the allocator by value and threads it as a reference into the templated impl. Every DeviceMem replaced with a typed pointer from allocate_workspace. QR's geqrf/orgqr workspace-reuse logic is preserved. vignettes/articles/custom-calls-lapack-cusolver.Rmd: update execution model, dlopen rationale, devInfo, status-codes, failure-mode table, kit description, and the "add a new linalg op" recipe to reflect the ScratchAllocator pattern. Co-Authored-By: Claude Opus 4.7 (1M context) --- src/eigh_cuda.cpp | 28 ++--- src/ffi_cusolver.cpp | 27 ++--- src/ffi_cusolver.h | 66 +++++------ src/lu_cuda.cpp | 25 ++--- src/qr_cuda.cpp | 72 ++++++------ src/svd_cuda.cpp | 41 +++---- .../articles/custom-calls-lapack-cusolver.Rmd | 103 +++++++++++------- 7 files changed, 192 insertions(+), 170 deletions(-) diff --git a/src/eigh_cuda.cpp b/src/eigh_cuda.cpp index ef696889..fd0d86ce 100644 --- a/src/eigh_cuda.cpp +++ b/src/eigh_cuda.cpp @@ -20,10 +20,11 @@ namespace rpjrt { #ifndef _WIN32 template -static Error eigh_cuda_impl(void *stream, AnyBuffer input, - Result v_out, Result w_out) { +static Error eigh_cuda_impl(void *stream, ScratchAllocator &scratch, + AnyBuffer input, Result v_out, + Result w_out) { Solver solver(get_gpu_libs()); - PJRT_RETURN_IF_ERROR(solver.begin(stream)); + PJRT_RETURN_IF_ERROR(solver.begin(scratch, stream)); auto &g = solver.g; auto dims = input.dimensions(); @@ -54,36 +55,37 @@ static Error eigh_cuda_impl(void *stream, AnyBuffer input, reinterpret_cast(w_ptr), &lwork), "cusolverDn?syevd_bufferSize"); - DeviceMem d_work(g); - PJRT_RETURN_IF_ERROR( - allocate_workspace(lwork, "cuMemAlloc (syevd workspace)", d_work)); + T *d_work; + PJRT_RETURN_IF_ERROR(allocate_workspace( + scratch, static_cast(lwork), "syevd workspace", d_work)); PJRT_RETURN_IF_GPU_ERROR( CuSolver::syevd(g)(solver.handle.get(), jobz, uplo, n, reinterpret_cast(v_ptr), n, - reinterpret_cast(w_ptr), - reinterpret_cast(d_work.ptr), lwork, - reinterpret_cast(solver.info.ptr)), + reinterpret_cast(w_ptr), d_work, lwork, + solver.info), "cusolverDn?syevd"); return Error::Success(); } #endif // _WIN32 -static Error do_eigh_cuda(void *stream, AnyBuffer input, - Result v_out, Result w_out) { +static Error do_eigh_cuda(void *stream, ScratchAllocator scratch, + AnyBuffer input, Result v_out, + Result w_out) { #ifdef _WIN32 return Error(ErrorCode::kUnimplemented, "CUDA eigh is not supported on Windows"); #else - PJRT_DISPATCH_FLOAT(input.element_type(), eigh_cuda_impl, stream, input, - v_out, w_out); + PJRT_DISPATCH_FLOAT(input.element_type(), eigh_cuda_impl, stream, scratch, + input, v_out, w_out); #endif } XLA_FFI_DEFINE_HANDLER(eigh_handler_cuda, do_eigh_cuda, Ffi::Bind() .Ctx>() + .Ctx() .Arg() .Ret() .Ret()); diff --git a/src/ffi_cusolver.cpp b/src/ffi_cusolver.cpp index f5d5f8b6..0e2dbad9 100644 --- a/src/ffi_cusolver.cpp +++ b/src/ffi_cusolver.cpp @@ -1,7 +1,8 @@ -// Implementation of the shared cuSOLVER infrastructure: dlopen-based loader, -// device-memory RAII, and a per-stream handle pool. All cuSOLVER-backed -// kernels (qr, lu, svd, eigh) use these singletons so they share one set of -// loaded function pointers and one handle pool. +// Implementation of the shared cuSOLVER infrastructure: dlopen-based loader +// and a per-stream handle pool. All cuSOLVER-backed kernels (qr, lu, svd, +// eigh) use these singletons so they share one set of loaded function +// pointers and one handle pool. Device-memory allocation goes through XLA's +// ffi::ScratchAllocator, not via this loader. #include "ffi_cusolver.h" #ifndef _WIN32 @@ -77,25 +78,14 @@ GpuLibs &get_gpu_libs() { g.s_syevd = load_sym(cusolver, "cusolverDnSsyevd"); g.d_syevd = load_sym(cusolver, "cusolverDnDsyevd"); - g.mem_alloc = load_sym(cuda, "cuMemAlloc_v2"); - g.mem_free = load_sym(cuda, "cuMemFree_v2"); g.memcpy_dtod = load_sym(cuda, "cuMemcpyDtoDAsync_v2"); g.memset_d8 = load_sym(cuda, "cuMemsetD8Async"); - g.stream_sync = - load_sym(cuda, "cuStreamSynchronize"); g.loaded = true; return g; } -DeviceMem::~DeviceMem() { - if (ptr) - g.mem_free(ptr); -} - -int DeviceMem::alloc(std::size_t bytes) { return g.mem_alloc(&ptr, bytes); } - // Per-stream cuSOLVER handle pool. // // cuSOLVER handles are not safe to share across streams: cusolverDnSetStream @@ -170,11 +160,14 @@ Error borrow_solver_handle(GpuLibs &g, void *stream, HandleGuard &out) { return Error::Success(); } -Error Solver::begin(void *stream) { +Error Solver::begin(ScratchAllocator &scratch, void *stream) { if (!g.loaded) return Error::Internal("CUDA/cuSOLVER libraries not available"); PJRT_RETURN_IF_ERROR(borrow_solver_handle(g, stream, handle)); - PJRT_RETURN_IF_GPU_ERROR(info.alloc(sizeof(int)), "cuMemAlloc (devInfo)"); + auto p = scratch.Allocate(sizeof(int)); + if (!p.has_value()) + return Error::Internal("scratch allocation failed (devInfo)"); + info = static_cast(*p); return Error::Success(); } diff --git a/src/ffi_cusolver.h b/src/ffi_cusolver.h index 0d5debd7..4280bb38 100644 --- a/src/ffi_cusolver.h +++ b/src/ffi_cusolver.h @@ -3,6 +3,12 @@ // without a CUDA install. Mirrors the role of jaxlib/gpu/solver_kernels_ffi.cc // + jaxlib/gpu/solver_handle_pool.cc, adapted to a runtime-link-only model. // +// Workspace, devInfo, and per-call working buffers are allocated through XLA's +// ffi::ScratchAllocator (BFC-pool-backed, integrated with XLA's memory +// accounting). The dlopen surface only covers cuSOLVER itself plus the +// memcpy/memset/sync helpers cuSOLVER doesn't provide; raw cuMemAlloc/cuMemFree +// are no longer needed. +// // Only the non-Windows half is meaningful; on Windows there is no CUDA, and // dlopen is POSIX-only. #pragma once @@ -88,29 +94,16 @@ struct GpuLibs { int (*d_syevd)(void *, int, int, int, double *, int, double *, double *, int, int *); - // CUDA driver. - int (*mem_alloc)(CUdeviceptr *, std::size_t); - int (*mem_free)(CUdeviceptr); + // CUDA driver helpers. Allocation goes through ffi::ScratchAllocator, but + // memcpy / memset / stream-sync still need driver entry points. int (*memcpy_dtod)(CUdeviceptr, CUdeviceptr, std::size_t, void *); int (*memset_d8)(CUdeviceptr, unsigned char, std::size_t, void *); - int (*stream_sync)(void *); bool loaded = false; }; GpuLibs &get_gpu_libs(); -// RAII wrapper for cuMemAlloc'd device memory. -struct DeviceMem { - CUdeviceptr ptr = 0; - GpuLibs &g; - explicit DeviceMem(GpuLibs &g) : g(g) {} - ~DeviceMem(); - DeviceMem(const DeviceMem &) = delete; - DeviceMem &operator=(const DeviceMem &) = delete; - int alloc(std::size_t bytes); -}; - // Borrowed cuSOLVER handle, returned to the per-stream pool on destruction. class HandleGuard { public: @@ -133,30 +126,39 @@ xla::ffi::Error borrow_solver_handle(GpuLibs &g, void *stream, HandleGuard &out); // Bundled prologue for a CUDA linalg kernel: a borrowed cuSOLVER handle on -// `stream`, plus a pre-allocated device `int` for `devInfo` (every cuSOLVER -// routine wants one). All four built-in linalg kernels open with the same -// three steps -- loaded-check, handle borrow, info alloc -- and `Solver` -// rolls them into one initialiser. `g` and `info` mirror the shape of -// jaxlib's GeqrfImpl prologue (cf. solver_kernels_ffi.cc). +// `stream`, plus a scratch-allocated device `int` for `devInfo` (every +// cuSOLVER routine wants one). All four built-in linalg kernels open with the +// same three steps -- loaded-check, handle borrow, info alloc -- and `Solver` +// rolls them into one initialiser. +// +// `info` is owned by the caller's ScratchAllocator and freed when the FFI +// handler returns. Mirrors jaxlib's GeqrfImpl prologue in +// solver_kernels_ffi.cc (which also threads ScratchAllocator through every +// solver kernel). struct Solver { GpuLibs &g; HandleGuard handle; - DeviceMem info; - explicit Solver(GpuLibs &g) : g(g), info(g) {} + int *info = nullptr; + explicit Solver(GpuLibs &g) : g(g) {} - // Borrow a handle for `stream` and allocate devInfo. Call once per - // kernel invocation, before any cuSOLVER calls. - xla::ffi::Error begin(void *stream); + // Borrow a handle for `stream` and allocate devInfo from `scratch`. Call + // once per kernel invocation, before any cuSOLVER calls. + xla::ffi::Error begin(xla::ffi::ScratchAllocator &scratch, void *stream); }; -// Allocate `lwork * sizeof(T)` bytes of device memory into `out`, with a -// site-name annotation. Centralises the int -> size_t widening so each -// kernel doesn't open-code it per workspace. +// Allocate `n_elements * sizeof(T)` bytes from `scratch`, with a site-name +// annotation. Centralises the size widening and the optional -> Error +// translation so each kernel doesn't open-code it per workspace. template -xla::ffi::Error allocate_workspace(int lwork, const char *name, - DeviceMem &out) { - std::size_t bytes = static_cast(lwork) * sizeof(T); - PJRT_RETURN_IF_GPU_ERROR(out.alloc(bytes), name); +xla::ffi::Error allocate_workspace(xla::ffi::ScratchAllocator &scratch, + std::size_t n_elements, const char *name, + T *&out) { + auto p = scratch.Allocate(n_elements * sizeof(T)); + if (!p.has_value()) { + return xla::ffi::Error::Internal(std::string(name) + + " scratch allocation failed"); + } + out = static_cast(*p); return xla::ffi::Error::Success(); } diff --git a/src/lu_cuda.cpp b/src/lu_cuda.cpp index eebc642d..41c1b46f 100644 --- a/src/lu_cuda.cpp +++ b/src/lu_cuda.cpp @@ -19,11 +19,11 @@ namespace rpjrt { #ifndef _WIN32 template -static Error lu_cuda_impl(void *stream, AnyBuffer input, - Result lu_out, +static Error lu_cuda_impl(void *stream, ScratchAllocator &scratch, + AnyBuffer input, Result lu_out, Result piv_out) { Solver solver(get_gpu_libs()); - PJRT_RETURN_IF_ERROR(solver.begin(stream)); + PJRT_RETURN_IF_ERROR(solver.begin(scratch, stream)); auto &g = solver.g; auto dims = input.dimensions(); @@ -48,16 +48,14 @@ static Error lu_cuda_impl(void *stream, AnyBuffer input, reinterpret_cast(lu_ptr), m, &lwork), "cusolverDn?getrf_bufferSize"); - DeviceMem d_work(g); - PJRT_RETURN_IF_ERROR( - allocate_workspace(lwork, "cuMemAlloc (getrf workspace)", d_work)); + T *d_work; + PJRT_RETURN_IF_ERROR(allocate_workspace( + scratch, static_cast(lwork), "getrf workspace", d_work)); PJRT_RETURN_IF_GPU_ERROR( CuSolver::getrf(g)(solver.handle.get(), m, n, - reinterpret_cast(lu_ptr), m, - reinterpret_cast(d_work.ptr), - reinterpret_cast(piv_ptr), - reinterpret_cast(solver.info.ptr)), + reinterpret_cast(lu_ptr), m, d_work, + reinterpret_cast(piv_ptr), solver.info), "cusolverDn?getrf"); // devInfo is intentionally not read back: a singular matrix surfaces as @@ -68,20 +66,21 @@ static Error lu_cuda_impl(void *stream, AnyBuffer input, } #endif // _WIN32 -static Error do_lu_cuda(void *stream, AnyBuffer input, +static Error do_lu_cuda(void *stream, ScratchAllocator scratch, AnyBuffer input, Result lu_out, Result piv_out) { #ifdef _WIN32 return Error(ErrorCode::kUnimplemented, "CUDA LU is not supported on Windows"); #else - PJRT_DISPATCH_FLOAT(input.element_type(), lu_cuda_impl, stream, input, lu_out, - piv_out); + PJRT_DISPATCH_FLOAT(input.element_type(), lu_cuda_impl, stream, scratch, + input, lu_out, piv_out); #endif } XLA_FFI_DEFINE_HANDLER(lu_handler_cuda, do_lu_cuda, Ffi::Bind() .Ctx>() + .Ctx() .Arg() .Ret() .Ret()); diff --git a/src/qr_cuda.cpp b/src/qr_cuda.cpp index 40535e74..14d3a194 100644 --- a/src/qr_cuda.cpp +++ b/src/qr_cuda.cpp @@ -1,7 +1,7 @@ // CUDA QR decomposition via cuSOLVER. Mirrors src/qr.cpp on the GPU. // -// The shared dlopen loader, DeviceMem RAII, per-stream HandleGuard, and the -// `Solver` prologue (handle + devInfo) live in ffi_cusolver.h/.cpp; this file +// The shared dlopen loader, per-stream HandleGuard, and the `Solver` prologue +// (handle + scratch-allocated devInfo) live in ffi_cusolver.h/.cpp; this file // only contains the QR algorithm itself. // // On Windows the handler is still defined but always returns Unimplemented @@ -24,10 +24,11 @@ namespace rpjrt { #ifndef _WIN32 template -static Error qr_cuda_impl(void *stream, AnyBuffer input, - Result q_out, Result r_out) { +static Error qr_cuda_impl(void *stream, ScratchAllocator &scratch, + AnyBuffer input, Result q_out, + Result r_out) { Solver solver(get_gpu_libs()); - PJRT_RETURN_IF_ERROR(solver.begin(stream)); + PJRT_RETURN_IF_ERROR(solver.begin(scratch, stream)); auto &g = solver.g; auto dims = input.dimensions(); @@ -45,95 +46,94 @@ static Error qr_cuda_impl(void *stream, AnyBuffer input, std::size_t a_bytes = static_cast(m) * n * sizeof(T); std::size_t r_bytes = static_cast(k) * n * sizeof(T); std::size_t q_bytes = static_cast(m) * k * sizeof(T); - std::size_t tau_bytes = static_cast(k) * sizeof(T); - DeviceMem d_a(g), d_tau(g), d_work(g); - PJRT_RETURN_IF_GPU_ERROR(d_a.alloc(a_bytes), "cuMemAlloc (A)"); - PJRT_RETURN_IF_GPU_ERROR(d_tau.alloc(tau_bytes), "cuMemAlloc (tau)"); + T *d_a; + T *d_tau; + PJRT_RETURN_IF_ERROR(allocate_workspace( + scratch, static_cast(m) * n, "A copy", d_a)); + PJRT_RETURN_IF_ERROR(allocate_workspace( + scratch, static_cast(k), "tau", d_tau)); - PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(d_a.ptr, input_ptr, a_bytes, stream), + PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(reinterpret_cast(d_a), + input_ptr, a_bytes, stream), "cuMemcpyDtoDAsync (input -> A)"); int lwork = 0; PJRT_RETURN_IF_GPU_ERROR( - CuSolver::geqrf_bs(g)(solver.handle.get(), m, n, - reinterpret_cast(d_a.ptr), m, &lwork), + CuSolver::geqrf_bs(g)(solver.handle.get(), m, n, d_a, m, &lwork), "cusolverDn?geqrf_bufferSize"); - PJRT_RETURN_IF_ERROR( - allocate_workspace(lwork, "cuMemAlloc (geqrf workspace)", d_work)); + + T *d_work; + PJRT_RETURN_IF_ERROR(allocate_workspace( + scratch, static_cast(lwork), "geqrf workspace", d_work)); PJRT_RETURN_IF_GPU_ERROR( - CuSolver::geqrf(g)(solver.handle.get(), m, n, - reinterpret_cast(d_a.ptr), m, - reinterpret_cast(d_tau.ptr), - reinterpret_cast(d_work.ptr), lwork, - reinterpret_cast(solver.info.ptr)), + CuSolver::geqrf(g)(solver.handle.get(), m, n, d_a, m, d_tau, d_work, + lwork, solver.info), "cusolverDn?geqrf"); // Extract R: zero the output, then copy upper triangular column by column. PJRT_RETURN_IF_GPU_ERROR(g.memset_d8(r_ptr, 0, r_bytes, stream), "cuMemsetD8Async (R)"); + CUdeviceptr d_a_ptr = reinterpret_cast(d_a); for (int j = 0; j < n; j++) { int elems = std::min(j + 1, k); std::size_t r_off = static_cast(j) * k * sizeof(T); std::size_t a_off = static_cast(j) * m * sizeof(T); PJRT_RETURN_IF_GPU_ERROR( - g.memcpy_dtod(r_ptr + r_off, d_a.ptr + a_off, + g.memcpy_dtod(r_ptr + r_off, d_a_ptr + a_off, static_cast(elems) * sizeof(T), stream), "cuMemcpyDtoDAsync (R column)"); } // Copy first k columns of factored A to Q output (column-major, so first // m*k elements), then run orgqr in-place on Q. - PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(q_ptr, d_a.ptr, q_bytes, stream), + PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(q_ptr, d_a_ptr, q_bytes, stream), "cuMemcpyDtoDAsync (A -> Q)"); int lwork_orgqr = 0; PJRT_RETURN_IF_GPU_ERROR( CuSolver::orgqr_bs(g)(solver.handle.get(), m, k, k, - reinterpret_cast(q_ptr), m, - reinterpret_cast(d_tau.ptr), + reinterpret_cast(q_ptr), m, d_tau, &lwork_orgqr), "cusolverDn?orgqr_bufferSize"); // Reuse the geqrf workspace if it's already big enough (saves an alloc // for the common case where geqrf needs more scratch than orgqr). - DeviceMem d_work2(g); T *work_ptr; if (lwork_orgqr <= lwork) { - work_ptr = reinterpret_cast(d_work.ptr); + work_ptr = d_work; } else { - PJRT_RETURN_IF_ERROR(allocate_workspace( - lwork_orgqr, "cuMemAlloc (orgqr workspace)", d_work2)); - work_ptr = reinterpret_cast(d_work2.ptr); + PJRT_RETURN_IF_ERROR(allocate_workspace(scratch, + static_cast(lwork_orgqr), + "orgqr workspace", work_ptr)); } PJRT_RETURN_IF_GPU_ERROR( CuSolver::orgqr(g)(solver.handle.get(), m, k, k, - reinterpret_cast(q_ptr), m, - reinterpret_cast(d_tau.ptr), work_ptr, - lwork_orgqr, - reinterpret_cast(solver.info.ptr)), + reinterpret_cast(q_ptr), m, d_tau, work_ptr, + lwork_orgqr, solver.info), "cusolverDn?orgqr"); return Error::Success(); } #endif // _WIN32 -static Error do_qr_cuda(void *stream, AnyBuffer input, Result q_out, - Result r_out) { +static Error do_qr_cuda(void *stream, ScratchAllocator scratch, AnyBuffer input, + Result q_out, Result r_out) { #ifdef _WIN32 return Error(ErrorCode::kUnimplemented, "CUDA QR is not supported on Windows"); #else - PJRT_DISPATCH_FLOAT(input.element_type(), qr_cuda_impl, stream, input, q_out, - r_out); + PJRT_DISPATCH_FLOAT(input.element_type(), qr_cuda_impl, stream, scratch, + input, q_out, r_out); #endif } XLA_FFI_DEFINE_HANDLER(qr_handler_cuda, do_qr_cuda, Ffi::Bind() .Ctx>() + .Ctx() .Arg() .Ret() .Ret()); diff --git a/src/svd_cuda.cpp b/src/svd_cuda.cpp index e61fb5c9..ce0fa664 100644 --- a/src/svd_cuda.cpp +++ b/src/svd_cuda.cpp @@ -26,11 +26,12 @@ namespace rpjrt { #ifndef _WIN32 template -static Error svd_cuda_impl(void *stream, AnyBuffer input, - Result u_out, Result s_out, +static Error svd_cuda_impl(void *stream, ScratchAllocator &scratch, + AnyBuffer input, Result u_out, + Result s_out, Result vt_out) { Solver solver(get_gpu_libs()); - PJRT_RETURN_IF_ERROR(solver.begin(stream)); + PJRT_RETURN_IF_ERROR(solver.begin(scratch, stream)); auto &g = solver.g; auto dims = input.dimensions(); @@ -52,9 +53,11 @@ static Error svd_cuda_impl(void *stream, AnyBuffer input, // gesvd overwrites A. Allocate a working copy so the input buffer is // preserved (XLA may have aliased it elsewhere). - DeviceMem d_a(g); - PJRT_RETURN_IF_GPU_ERROR(d_a.alloc(a_bytes), "cuMemAlloc (A)"); - PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(d_a.ptr, input_ptr, a_bytes, stream), + T *d_a; + PJRT_RETURN_IF_ERROR(allocate_workspace( + scratch, static_cast(m) * n, "A copy", d_a)); + PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(reinterpret_cast(d_a), + input_ptr, a_bytes, stream), "cuMemcpyDtoDAsync (input -> A)"); int lwork = 0; @@ -62,43 +65,41 @@ static Error svd_cuda_impl(void *stream, AnyBuffer input, CuSolver::gesvd_bs(g)(solver.handle.get(), m, n, &lwork), "cusolverDn?gesvd_bufferSize"); - DeviceMem d_work(g); - PJRT_RETURN_IF_ERROR( - allocate_workspace(lwork, "cuMemAlloc (gesvd workspace)", d_work)); + T *d_work; + PJRT_RETURN_IF_ERROR(allocate_workspace( + scratch, static_cast(lwork), "gesvd workspace", d_work)); // jobu / jobvt are 'S' (reduced). They're typed as signed char in the // cuSOLVER ABI; passing the literal char works because of integer // promotion at the call site. PJRT_RETURN_IF_GPU_ERROR( - CuSolver::gesvd(g)(solver.handle.get(), 'S', 'S', m, n, - reinterpret_cast(d_a.ptr), m, + CuSolver::gesvd(g)(solver.handle.get(), 'S', 'S', m, n, d_a, m, reinterpret_cast(s_ptr), reinterpret_cast(u_ptr), m, - reinterpret_cast(vt_ptr), n, - reinterpret_cast(d_work.ptr), lwork, - /*rwork=*/nullptr, - reinterpret_cast(solver.info.ptr)), + reinterpret_cast(vt_ptr), n, d_work, lwork, + /*rwork=*/nullptr, solver.info), "cusolverDn?gesvd"); return Error::Success(); } #endif // _WIN32 -static Error do_svd_cuda(void *stream, AnyBuffer input, - Result u_out, Result s_out, - Result vt_out) { +static Error do_svd_cuda(void *stream, ScratchAllocator scratch, + AnyBuffer input, Result u_out, + Result s_out, Result vt_out) { #ifdef _WIN32 return Error(ErrorCode::kUnimplemented, "CUDA SVD is not supported on Windows"); #else - PJRT_DISPATCH_FLOAT(input.element_type(), svd_cuda_impl, stream, input, u_out, - s_out, vt_out); + PJRT_DISPATCH_FLOAT(input.element_type(), svd_cuda_impl, stream, scratch, + input, u_out, s_out, vt_out); #endif } XLA_FFI_DEFINE_HANDLER(svd_handler_cuda, do_svd_cuda, Ffi::Bind() .Ctx>() + .Ctx() .Arg() .Ret() .Ret() diff --git a/vignettes/articles/custom-calls-lapack-cusolver.Rmd b/vignettes/articles/custom-calls-lapack-cusolver.Rmd index 28ff8b4d..16161eb6 100644 --- a/vignettes/articles/custom-calls-lapack-cusolver.Rmd +++ b/vignettes/articles/custom-calls-lapack-cusolver.Rmd @@ -71,7 +71,8 @@ For a new built-in custom call `foo`: The shared FFI kit -- `src/ffi_common.h`, `src/ffi_lapack.h`, `src/ffi_cusolver.{h,cpp}` -- holds the pieces every kernel uses (status macros, dtype dispatch, `dim_to_int`, the LAPACK promotion trait, the -cuSOLVER handle pool and `DeviceMem` RAII). See section 7 below. +cuSOLVER handle pool, and the scratch-allocator-backed `allocate_workspace` +helper). See section 7 below. `src/Makevars.in` adds `$(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)` to the link line; this works on Linux/macOS/Windows out of the box. The CUDA path uses @@ -109,13 +110,15 @@ void *get_foo_handler(void) { return (void *)foo_handler; } } ``` -For CUDA you also bind the platform stream so cuSOLVER can be told where to -launch: +For CUDA you also bind the platform stream (so cuSOLVER knows where to +launch) and the FFI scratch allocator (where workspace and `devInfo` come +from -- see section 6): ```c++ XLA_FFI_DEFINE_HANDLER(foo_handler_cuda, do_foo_cuda, Ffi::Bind() .Ctx>() + .Ctx() .Arg() .Ret() .Ret()); @@ -253,12 +256,17 @@ cuSOLVER runs **asynchronously** on a CUDA stream supplied by XLA. The handler: 1. Resolves the stream from `Ctx>()`. -2. Borrows a stream-bound cuSOLVER handle from a pool. -3. Allocates device memory for working buffers (input copy, `tau`, - `devInfo`, workspace). -4. Issues async D2D memcpys / kernel launches on the stream. -5. Returns immediately. The caller's stream synchronisation makes the work - visible. +2. Resolves the FFI scratch allocator from `Ctx()`. +3. Borrows a stream-bound cuSOLVER handle from a pool. +4. Asks `ScratchAllocator` for device memory for working buffers (input + copy, `tau`, `devInfo`, workspace). The scratch allocator is backed by + XLA's BFC pool, so allocations are reused across calls and visible to + XLA's memory accounting -- the kernel does not call `cuMemAlloc` itself. +5. Issues async D2D memcpys / kernel launches on the stream. +6. Returns immediately. The caller's stream synchronisation makes the work + visible. `ScratchAllocator` frees its allocations when the handler + returns; the underlying XLA allocator stream-orders the free against the + work the kernel just launched. Inputs and outputs are already on the GPU; the handler only sees device pointers (cast `untyped_data()` to `CUdeviceptr`, an `unsigned long long`). @@ -282,6 +290,11 @@ has to mirror the cuSOLVER and CUDA driver APIs by hand, and an opaque to `GpuLibs` and to the loader in `ffi_cusolver.cpp::get_gpu_libs()`, then expose them through a `CuSolver` trait specialisation. +The dlopen surface is intentionally small: cuSOLVER itself plus the few +CUDA driver helpers cuSOLVER doesn't provide (`cuMemcpyDtoDAsync_v2`, +`cuMemsetD8Async`). Allocation goes through `ffi::ScratchAllocator`, so +`cuMemAlloc`/`cuMemFree` are not loaded. + ### The cuSOLVER handle is stateful A `cusolverDnHandle_t` is bound to a stream by `cusolverDnSetStream`. If @@ -301,9 +314,11 @@ call `borrow_solver_handle(g, stream, handle)`. ### devInfo lives on the device cuSOLVER routines take a `devInfo` pointer that must be in device memory. -You must allocate `sizeof(int)` of device space and pass it. The value -itself only flags illegal arguments (mirroring LAPACK's `info < 0` case), -which the handler already validates up front via dimension checks. +You must allocate `sizeof(int)` of device space and pass it -- the +`Solver::begin(scratch, stream)` prologue handles this for you, so the +kernel just reads `solver.info`. The value itself only flags illegal +arguments (mirroring LAPACK's `info < 0` case), which the handler already +validates up front via dimension checks. `jaxlib`'s `GeqrfImpl`/`OrgqrImpl` allocate the buffer and never read its value back. We do the same: don't pay for a D2H copy + sync just to check @@ -318,10 +333,13 @@ launch-time error, matching jaxlib's getrf path. ### Status codes: check every call -Every CUDA driver call (`cuMemAlloc`, `cuMemcpyDtoDAsync`, `cuMemsetD8Async`, -`cuStreamSynchronize`) and every cuSOLVER call (including the workspace -size queries and `cusolverDnSetStream`) returns an `int` status. Non-zero -means failure. **Drop none of them.** +Every CUDA driver call (`cuMemcpyDtoDAsync`, `cuMemsetD8Async`) and every +cuSOLVER call (including the workspace size queries and +`cusolverDnSetStream`) returns an `int` status. Non-zero means failure. +**Drop none of them.** `ScratchAllocator::Allocate` returns +`std::optional` instead -- `allocate_workspace` translates the +nullopt into an `Error::Internal` for you, so kernels can use +`PJRT_RETURN_IF_ERROR` uniformly. The pattern in `ffi_cusolver.h` (used by every kernel): @@ -336,9 +354,9 @@ The pattern in `ffi_cusolver.h` (used by every kernel): } while (0) ``` -A failed `cuMemAlloc` whose status is dropped leaves the device pointer at -0; the next memcpy then writes through a null device pointer and corrupts -arbitrary device state. There is no debug message, just garbage results. +A failed memcpy or kernel launch whose status is dropped leaves the +output buffer with garbage but produces no error -- the failure surfaces +later as wrong numerical answers. There is no debug message. ### Workspace queries @@ -348,7 +366,7 @@ as LAPACK: query, allocate, run. If you call `geqrf` and then `orgqr` on the same problem, their workspace sizes differ. Query both. If `orgqr_lwork <= geqrf_lwork`, reuse the -buffer; otherwise allocate a second one. See `qr_cuda.cpp` for the +buffer; otherwise ask `scratch` for a second one. See `qr_cuda.cpp` for the two-workspace pattern. ### Integer overflow on size math @@ -385,9 +403,10 @@ size_t a_bytes = static_cast(m) * n * sizeof(T); | `int` overflow on `m * n * sizeof(T)` for large matrices | Cast to `size_t` before multiplying. | | `int64_t` -> `int` truncation on dimension cast | `dim_to_int` with bounds check. | | Read `devInfo` D2H without a sync -> get the value before the kernel ran | Either don't read it, or `cuStreamSynchronize` first. | -| `cuMemAlloc(0)` from `lwork = 0` | Guard or accept the allocation error. | +| `scratch.Allocate(0)` from `lwork = 0` -> `nullopt`, surfaces as Error | Guard zero-sized requests, or accept the (rare) error. | | Handler synchronises the stream -> kills async pipelining | Don't call `cuStreamSynchronize` unless you need a host-visible result. | -| Forgot `Ctx>()` in the FFI bind -> no stream available | Add it; cuSOLVER on the default stream serialises with everything else. | +| Forgot `Ctx>()` or `Ctx()` in the FFI bind | Both are required for a CUDA linalg kernel; the bind order must match the function arg order. | +| Holding a scratch pointer past handler return -> use-after-free | All scratch allocations are freed when the handler returns; never cache them in a static. | | `dlopen("libcusolver.so")` fails on a runtime-only CUDA install | Fall back through SONAME variants (`.so.11`, etc.) and degrade gracefully. | | In-place aliasing of input and output buffers | Always copy input to a working device buffer before factorising. | @@ -403,7 +422,7 @@ Most of the boilerplate above is in shared headers, so a new linalg op is |---|---| | `src/ffi_common.h` | `PJRT_RETURN_IF_ERROR`, `PJRT_DISPATCH_FLOAT`, `dim_to_int` | | `src/ffi_lapack.h` | LAPACK Fortran extern decls (geqrf/orgqr/getrf/gesdd/syevd, plus `s*_` on non-Windows). `Lapack` trait whose `::S` is the LAPACK storage type (Windows: always `double`; elsewhere: matches `T`). `lapack_check_info`. | -| `src/ffi_cusolver.h` | `CUdeviceptr` opaque typedef, `PJRT_RETURN_IF_GPU_ERROR`, `GpuLibs` (function-pointer table), `DeviceMem` RAII, `HandleGuard` (per-stream pool borrow), `borrow_solver_handle`, `CuSolver` dispatch trait. | +| `src/ffi_cusolver.h` | `CUdeviceptr` opaque typedef, `PJRT_RETURN_IF_GPU_ERROR`, `GpuLibs` (function-pointer table for cuSOLVER + memcpy/memset helpers), `HandleGuard` (per-stream pool borrow), `borrow_solver_handle`, `Solver` prologue (handle + scratch-allocated `devInfo`), `allocate_workspace(scratch, n_elements, name, T*&)`, `CuSolver` dispatch trait. | | `src/ffi_cusolver.cpp`| `get_gpu_libs()` (singleton dlopen loader) and the SolverHandlePool (mutex-guarded free-list keyed by stream). Shared across all CUDA kernels. | ### What an op looks like with the kit @@ -427,19 +446,26 @@ static Error do_qr(AnyBuffer input, Result q_out, Result r No `#ifdef _WIN32` in the kernel body; the trait absorbs it. -The CUDA kernel is similarly slimmer because `GpuLibs`, `DeviceMem`, -`HandleGuard`, and the pool live in `ffi_cusolver.cpp`: +The CUDA kernel is similarly slimmer because `GpuLibs`, `Solver`, +`HandleGuard`, and the pool live in `ffi_cusolver.cpp`. Workspace, +`devInfo`, and the input copy all come from the FFI scratch allocator: ```cpp template -static Error qr_cuda_impl(void *stream, AnyBuffer input, ...) { - auto &g = get_gpu_libs(); - if (!g.loaded) return Error::Internal("CUDA/cuSOLVER libraries not available"); - HandleGuard handle; - PJRT_RETURN_IF_ERROR(borrow_solver_handle(g, stream, handle)); - DeviceMem d_a(g), d_tau(g), d_info(g), d_work(g); - // ...alloc, memcpy_dtod, CuSolver::geqrf(g)(handle.get(), ...) wrapped in - // PJRT_RETURN_IF_GPU_ERROR... +static Error qr_cuda_impl(void *stream, ScratchAllocator &scratch, + AnyBuffer input, ...) { + Solver solver(get_gpu_libs()); + PJRT_RETURN_IF_ERROR(solver.begin(scratch, stream)); // handle + devInfo + auto &g = solver.g; + + T *d_a, *d_tau, *d_work; + PJRT_RETURN_IF_ERROR(allocate_workspace(scratch, m * n, "A copy", d_a)); + PJRT_RETURN_IF_ERROR(allocate_workspace(scratch, k, "tau", d_tau)); + // ...workspace bufferSize query... + PJRT_RETURN_IF_ERROR(allocate_workspace(scratch, lwork, "work", d_work)); + + // ...memcpy_dtod, CuSolver::geqrf(g)(solver.handle.get(), ..., solver.info) + // wrapped in PJRT_RETURN_IF_GPU_ERROR... } ``` @@ -526,12 +552,11 @@ To add a new built-in factorisation `foo`: `do_foo_cuda` dispatcher returns `Error::Unimplemented` on Windows. This keeps a single `[[Rcpp::export]] get_foo_handler_cuda()` valid on every platform without `#ifdef`s around the export. - - Borrow a handle: `HandleGuard handle; borrow_solver_handle(g, stream, handle);` - - `DeviceMem` for working buffers + `devInfo`. Don't read `devInfo` back - unless your op uses values beyond illegal-argument flagging. - - Workspace bufferSize, then `d_work.alloc(...)`, then the call. - - Every CUDA / cuSOLVER call wrapped in `PJRT_RETURN_IF_GPU_ERROR`. - - Bind the FFI with `.Ctx>()`. + - Open with the `Solver` prologue: `Solver solver(get_gpu_libs()); solver.begin(scratch, stream);`. This borrows a stream-bound handle and allocates `devInfo` from `scratch`. + - Use `allocate_workspace(scratch, n_elements, name, T*&)` for every working buffer (input copy, `tau`, workspace, anything you'd otherwise `cuMemAlloc`). Don't read `solver.info` back unless your op uses values beyond illegal-argument flagging. + - Workspace bufferSize, then `allocate_workspace(scratch, lwork, ...)`, then the call. + - Every memcpy / memset / cuSOLVER call wrapped in `PJRT_RETURN_IF_GPU_ERROR`; allocate calls go through `PJRT_RETURN_IF_ERROR` (the helper translates `nullopt` into `Error::Internal`). + - Bind the FFI with `.Ctx>().Ctx()` -- both are required, in that order. 5. **Regenerate Rcpp wrappers**: `Rscript -e 'Rcpp::compileAttributes(".")'`.