Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@

## Features

* Added QR, LU, SVD, and symmetric eigendecomposition support on both
CPU and CUDA via the FFI registration mechanism.
* Added a vignette on how to register custom calls via the FFI
registration mechanisms with coverage of both CUDA and CPU-specific
aspects.
* Added support for the `bit64` package to better support long integers.
* `pjrt_buffer()`, `pjrt_scalar()`, and `as_array()` gain a `scan_na`
argument (default `FALSE`). When `TRUE`, host → device transfers error if
Expand Down
40 changes: 40 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

# Generated shim: calls the compiled routine `_pjrt_get_eigh_handler` with no
# arguments and returns its result (an external pointer wrapping the host
# `eigh` FFI handler).
get_eigh_handler <- function() {
.Call(`_pjrt_get_eigh_handler`)
}

# Generated shim: calls the compiled routine `_pjrt_get_eigh_handler_cuda`
# with no arguments and returns whatever that routine returns.
get_eigh_handler_cuda <- function() {
.Call(`_pjrt_get_eigh_handler_cuda`)
}

# Generated shim: forwards `plugin`, `target_name`, `handler_ptr` and
# `platform_name` unchanged to the compiled routine
# `_pjrt_impl_register_custom_call`. The result is returned invisibly.
impl_register_custom_call <- function(plugin, target_name, handler_ptr, platform_name) {
invisible(.Call(`_pjrt_impl_register_custom_call`, plugin, target_name, handler_ptr, platform_name))
}
Expand All @@ -21,6 +29,14 @@ format_raw_buffer_cpp <- function(data, dtype, shape) {
.Call(`_pjrt_format_raw_buffer_cpp`, data, dtype, shape)
}

# Generated shim: calls the compiled routine `_pjrt_get_lu_handler` with no
# arguments and returns whatever that routine returns.
get_lu_handler <- function() {
.Call(`_pjrt_get_lu_handler`)
}

# Generated shim: calls the compiled routine `_pjrt_get_lu_handler_cuda` with
# no arguments and returns whatever that routine returns.
get_lu_handler_cuda <- function() {
.Call(`_pjrt_get_lu_handler_cuda`)
}

# Generated shim: forwards `path` unchanged to the compiled routine
# `_pjrt_impl_plugin_load` and returns its result.
impl_plugin_load <- function(path) {
.Call(`_pjrt_impl_plugin_load`, path)
}
Expand Down Expand Up @@ -181,3 +197,27 @@ impl_client_buffer_from_double <- function(client, device, data, dims, dtype) {
.Call(`_pjrt_impl_client_buffer_from_double`, client, device, data, dims, dtype)
}

# Generated shim: calls the compiled routine `_pjrt_get_geqrf_handler` with no
# arguments and returns whatever that routine returns.
get_geqrf_handler <- function() {
.Call(`_pjrt_get_geqrf_handler`)
}

# Generated shim: calls the compiled routine `_pjrt_get_orgqr_handler` with no
# arguments and returns whatever that routine returns.
get_orgqr_handler <- function() {
.Call(`_pjrt_get_orgqr_handler`)
}

# Generated shim: calls the compiled routine `_pjrt_get_geqrf_handler_cuda`
# with no arguments and returns whatever that routine returns.
get_geqrf_handler_cuda <- function() {
.Call(`_pjrt_get_geqrf_handler_cuda`)
}

# Generated shim: calls the compiled routine `_pjrt_get_orgqr_handler_cuda`
# with no arguments and returns whatever that routine returns.
get_orgqr_handler_cuda <- function() {
.Call(`_pjrt_get_orgqr_handler_cuda`)
}

# Generated shim: calls the compiled routine `_pjrt_get_svd_handler` with no
# arguments and returns whatever that routine returns.
get_svd_handler <- function() {
.Call(`_pjrt_get_svd_handler`)
}

# Generated shim: calls the compiled routine `_pjrt_get_svd_handler_cuda` with
# no arguments and returns whatever that routine returns.
get_svd_handler_cuda <- function() {
.Call(`_pjrt_get_svd_handler_cuda`)
}

30 changes: 30 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,36 @@ register_namespace_callback <- function(pkgname, namespace, callback) {
.package = pkgname
)

# Register the built-in LAPACK / cuSOLVER linear-algebra handlers. These
# are pjrt-owned custom calls that any downstream package (anvl, future
# bindings) can invoke via `stablehlo.custom_call @<target>(...)` without
# having to ship its own LAPACK linkage.
#
# Five targets are registered: "geqrf", "orgqr", "lu", "svd" and "eigh".
# Each call passes a named list with a `host` and a `cuda` entry, filled from
# the matching `get_<target>_handler()` / `get_<target>_handler_cuda()`
# getter, and tags the registration with `.package = pkgname`.
pjrt_register_custom_call(
"geqrf",
list(host = get_geqrf_handler(), cuda = get_geqrf_handler_cuda()),
.package = pkgname
)
pjrt_register_custom_call(
"orgqr",
list(host = get_orgqr_handler(), cuda = get_orgqr_handler_cuda()),
.package = pkgname
)
pjrt_register_custom_call(
"lu",
list(host = get_lu_handler(), cuda = get_lu_handler_cuda()),
.package = pkgname
)
pjrt_register_custom_call(
"svd",
list(host = get_svd_handler(), cuda = get_svd_handler_cuda()),
.package = pkgname
)
pjrt_register_custom_call(
"eigh",
list(host = get_eigh_handler(), cuda = get_eigh_handler_cuda()),
.package = pkgname
)

register_namespace_callback(pkgname, "safetensors", function(...) {
frameworks <- utils::getFromNamespace(
"safetensors_frameworks",
Expand Down
2 changes: 1 addition & 1 deletion src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ CXX_STD=CXX20

PKG_CPPFLAGS=@cflags@ -Iproto -I$(R_PACKAGE_DIR)/include -I../inst/include
PKG_CXXFLAGS=$(C_VISIBILITY)
PKG_LIBS=@libs@
PKG_LIBS=@libs@ $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)

CPP_SOURCES=@cppsrc@
CC_SOURCES=@pbsrc@
Expand Down
110 changes: 110 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@ Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// get_eigh_handler
// Generated .Call entry point `_pjrt_get_eigh_handler`: takes no arguments,
// invokes the forward-declared get_eigh_handler() and returns its result
// after passing it through Rcpp::wrap. An Rcpp::RNGScope is held for the
// duration of the call.
SEXP get_eigh_handler();
RcppExport SEXP _pjrt_get_eigh_handler() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_eigh_handler());
return rcpp_result_gen;
END_RCPP
}
// get_eigh_handler_cuda
// Generated .Call entry point `_pjrt_get_eigh_handler_cuda`: takes no
// arguments, invokes the forward-declared get_eigh_handler_cuda() and returns
// its result after passing it through Rcpp::wrap. An Rcpp::RNGScope is held
// for the duration of the call.
SEXP get_eigh_handler_cuda();
RcppExport SEXP _pjrt_get_eigh_handler_cuda() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_eigh_handler_cuda());
return rcpp_result_gen;
END_RCPP
}
// impl_register_custom_call
void impl_register_custom_call(Rcpp::XPtr<rpjrt::PJRTPlugin> plugin, const std::string& target_name, SEXP handler_ptr, const std::string& platform_name);
RcppExport SEXP _pjrt_impl_register_custom_call(SEXP pluginSEXP, SEXP target_nameSEXP, SEXP handler_ptrSEXP, SEXP platform_nameSEXP) {
Expand Down Expand Up @@ -69,6 +89,26 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// get_lu_handler
// Generated .Call entry point `_pjrt_get_lu_handler`: takes no arguments,
// invokes the forward-declared get_lu_handler() and returns its result after
// passing it through Rcpp::wrap. An Rcpp::RNGScope is held for the duration
// of the call.
SEXP get_lu_handler();
RcppExport SEXP _pjrt_get_lu_handler() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_lu_handler());
return rcpp_result_gen;
END_RCPP
}
// get_lu_handler_cuda
// Generated .Call entry point `_pjrt_get_lu_handler_cuda`: takes no
// arguments, invokes the forward-declared get_lu_handler_cuda() and returns
// its result after passing it through Rcpp::wrap. An Rcpp::RNGScope is held
// for the duration of the call.
SEXP get_lu_handler_cuda();
RcppExport SEXP _pjrt_get_lu_handler_cuda() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_lu_handler_cuda());
return rcpp_result_gen;
END_RCPP
}
// impl_plugin_load
Rcpp::XPtr<rpjrt::PJRTPlugin> impl_plugin_load(const std::string& path);
RcppExport SEXP _pjrt_impl_plugin_load(SEXP pathSEXP) {
Expand Down Expand Up @@ -545,13 +585,77 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// get_geqrf_handler
// Generated .Call entry point `_pjrt_get_geqrf_handler`: takes no arguments,
// invokes the forward-declared get_geqrf_handler() and returns its result
// after passing it through Rcpp::wrap. An Rcpp::RNGScope is held for the
// duration of the call.
SEXP get_geqrf_handler();
RcppExport SEXP _pjrt_get_geqrf_handler() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_geqrf_handler());
return rcpp_result_gen;
END_RCPP
}
// get_orgqr_handler
// Generated .Call entry point `_pjrt_get_orgqr_handler`: takes no arguments,
// invokes the forward-declared get_orgqr_handler() and returns its result
// after passing it through Rcpp::wrap. An Rcpp::RNGScope is held for the
// duration of the call.
SEXP get_orgqr_handler();
RcppExport SEXP _pjrt_get_orgqr_handler() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_orgqr_handler());
return rcpp_result_gen;
END_RCPP
}
// get_geqrf_handler_cuda
// Generated .Call entry point `_pjrt_get_geqrf_handler_cuda`: takes no
// arguments, invokes the forward-declared get_geqrf_handler_cuda() and
// returns its result after passing it through Rcpp::wrap. An Rcpp::RNGScope
// is held for the duration of the call.
SEXP get_geqrf_handler_cuda();
RcppExport SEXP _pjrt_get_geqrf_handler_cuda() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_geqrf_handler_cuda());
return rcpp_result_gen;
END_RCPP
}
// get_orgqr_handler_cuda
// Generated .Call entry point `_pjrt_get_orgqr_handler_cuda`: takes no
// arguments, invokes the forward-declared get_orgqr_handler_cuda() and
// returns its result after passing it through Rcpp::wrap. An Rcpp::RNGScope
// is held for the duration of the call.
SEXP get_orgqr_handler_cuda();
RcppExport SEXP _pjrt_get_orgqr_handler_cuda() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_orgqr_handler_cuda());
return rcpp_result_gen;
END_RCPP
}
// get_svd_handler
// Generated .Call entry point `_pjrt_get_svd_handler`: takes no arguments,
// invokes the forward-declared get_svd_handler() and returns its result after
// passing it through Rcpp::wrap. An Rcpp::RNGScope is held for the duration
// of the call.
SEXP get_svd_handler();
RcppExport SEXP _pjrt_get_svd_handler() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_svd_handler());
return rcpp_result_gen;
END_RCPP
}
// get_svd_handler_cuda
// Generated .Call entry point `_pjrt_get_svd_handler_cuda`: takes no
// arguments, invokes the forward-declared get_svd_handler_cuda() and returns
// its result after passing it through Rcpp::wrap. An Rcpp::RNGScope is held
// for the duration of the call.
SEXP get_svd_handler_cuda();
RcppExport SEXP _pjrt_get_svd_handler_cuda() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_svd_handler_cuda());
return rcpp_result_gen;
END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_pjrt_get_eigh_handler", (DL_FUNC) &_pjrt_get_eigh_handler, 0},
{"_pjrt_get_eigh_handler_cuda", (DL_FUNC) &_pjrt_get_eigh_handler_cuda, 0},
{"_pjrt_impl_register_custom_call", (DL_FUNC) &_pjrt_impl_register_custom_call, 4},
{"_pjrt_get_print_handler", (DL_FUNC) &_pjrt_get_print_handler, 0},
{"_pjrt_get_print_handler_cuda", (DL_FUNC) &_pjrt_get_print_handler_cuda, 0},
{"_pjrt_test_get_extension", (DL_FUNC) &_pjrt_test_get_extension, 2},
{"_pjrt_format_raw_buffer_cpp", (DL_FUNC) &_pjrt_format_raw_buffer_cpp, 3},
{"_pjrt_get_lu_handler", (DL_FUNC) &_pjrt_get_lu_handler, 0},
{"_pjrt_get_lu_handler_cuda", (DL_FUNC) &_pjrt_get_lu_handler_cuda, 0},
{"_pjrt_impl_plugin_load", (DL_FUNC) &_pjrt_impl_plugin_load, 1},
{"_pjrt_impl_plugin_client_create", (DL_FUNC) &_pjrt_impl_plugin_client_create, 2},
{"_pjrt_impl_program_load", (DL_FUNC) &_pjrt_impl_program_load, 2},
Expand Down Expand Up @@ -592,6 +696,12 @@ static const R_CallMethodDef CallEntries[] = {
{"_pjrt_impl_client_buffer_from_integer64", (DL_FUNC) &_pjrt_impl_client_buffer_from_integer64, 4},
{"_pjrt_impl_client_buffer_from_logical", (DL_FUNC) &_pjrt_impl_client_buffer_from_logical, 5},
{"_pjrt_impl_client_buffer_from_double", (DL_FUNC) &_pjrt_impl_client_buffer_from_double, 5},
{"_pjrt_get_geqrf_handler", (DL_FUNC) &_pjrt_get_geqrf_handler, 0},
{"_pjrt_get_orgqr_handler", (DL_FUNC) &_pjrt_get_orgqr_handler, 0},
{"_pjrt_get_geqrf_handler_cuda", (DL_FUNC) &_pjrt_get_geqrf_handler_cuda, 0},
{"_pjrt_get_orgqr_handler_cuda", (DL_FUNC) &_pjrt_get_orgqr_handler_cuda, 0},
{"_pjrt_get_svd_handler", (DL_FUNC) &_pjrt_get_svd_handler, 0},
{"_pjrt_get_svd_handler_cuda", (DL_FUNC) &_pjrt_get_svd_handler_cuda, 0},
{NULL, NULL, 0}
};

Expand Down
92 changes: 92 additions & 0 deletions src/eigh.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Symmetric eigendecomposition via LAPACK syevd.
//
// Input is interpreted as symmetric using only its lower triangle (we pass
// uplo = 'L'). The factorisation produces:
// V : (n, n) eigenvectors as columns (orthonormal)
// W : (n,) eigenvalues in ascending order
// such that A = V diag(W) V^T (using the lower triangle of A).
//
// jobz = 'V' (compute eigenvectors). For values-only the user can drop V.
//
// syevd needs *two* workspaces (real `work` of size lwork, integer `iwork` of
// size liwork) -- both queried with the lwork = -1 idiom.
#include <Rcpp.h>

#include <cstddef>
#include <cstring>
#include <type_traits>
#include <vector>

#include "ffi_common.h"
#include "ffi_lapack.h"

using namespace xla::ffi;

namespace rpjrt {

// Symmetric eigendecomposition of a single square matrix via LAPACK syevd
// (jobz = 'V', uplo = 'L').
//
// Template parameter:
//   T      element type of all three buffers. Lapack<T>::S is the scalar type
//          the LAPACK routine is actually invoked with; the promote_* /
//          demote_* helpers bridge between T and S.
// Arguments:
//   input  (n, n) matrix; syevd is told to read only the lower triangle.
//   v_out  (n, n) result buffer, receives the eigenvectors.
//   w_out  (n,)   result buffer, receives the eigenvalues.
// Returns:
//   Error::Success() on success; Error::InvalidArgument when the input is not
//   a rank-2 square matrix or an output buffer has the wrong dtype/shape;
//   otherwise whatever error dim_to_int or lapack_check_info reports.
template <typename T>
static Error eigh_impl(AnyBuffer input, Result<AnyBuffer> v_out,
                       Result<AnyBuffer> w_out) {
  using S = typename Lapack<T>::S;

  auto dims = input.dimensions();
  // dims[0] / dims[1] are read below, so anything that is not exactly rank 2
  // must be rejected first. A lower rank would index past the end of the
  // span, and a higher rank would be silently misread as a plain matrix.
  if (dims.size() != 2) {
    return Error::InvalidArgument("eigh requires a rank-2 input");
  }
  int m = 0, n = 0;
  PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m));
  PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n));
  if (m != n) return Error::InvalidArgument("eigh requires a square matrix");

  // Both outputs are reinterpreted as T* and written with n*n / n elements,
  // so a mismatching dtype or shape would corrupt memory rather than fail.
  if ((*v_out).element_type() != input.element_type() ||
      (*w_out).element_type() != input.element_type()) {
    return Error::InvalidArgument(
        "eigh outputs must have the same element type as the input");
  }
  auto v_dims = (*v_out).dimensions();
  if (v_dims.size() != 2 || v_dims[0] != dims[0] || v_dims[1] != dims[1]) {
    return Error::InvalidArgument(
        "eigh eigenvector output must have the same shape as the input");
  }
  auto w_dims = (*w_out).dimensions();
  if (w_dims.size() != 1 || w_dims[0] != dims[0]) {
    return Error::InvalidArgument("eigh eigenvalue output must have shape (n,)");
  }

  // Nothing to compute for an empty matrix. Returning here also avoids
  // handing LAPACK lda = 0, which violates its lda >= max(1, n) requirement.
  if (n == 0) return Error::Success();

  const T *in = static_cast<const T *>(input.untyped_data());
  T *v_data = static_cast<T *>((*v_out).untyped_data());
  T *w_data = static_cast<T *>((*w_out).untyped_data());
  std::size_t total = static_cast<std::size_t>(n) * n;

  // syevd overwrites its A argument with the eigenvectors -- factor in place
  // in the V output buffer. The pointer-equality guard inside promote_inplace
  // covers the input-output aliasing case (see eigh_cuda.cpp:43-56 for the
  // rationale -- mirrors jaxlib's `CopyIfDiffBuffer`).
  std::vector<S> a_storage, w_storage;
  S *a = promote_inplace<T>(a_storage, v_data, total, in);
  S *w = promote_output<T>(w_storage, w_data, n);

  const char jobz = 'V';
  const char uplo = 'L';
  int info = 0;

  // Workspace query: with lwork = liwork = -1 syevd only reports the sizes it
  // needs, through the first element of `work` and `iwork`.
  int lwork = -1, liwork = -1;
  S work_size = S();
  int iwork_size = 0;
  Lapack<T>::syevd(&jobz, &uplo, &n, a, &n, w, &work_size, &lwork, &iwork_size,
                   &liwork, &info);
  PJRT_RETURN_IF_ERROR(lapack_check_info(info, "syevd workspace query"));

  // NOTE(review): the real workspace size comes back as a floating-point
  // value. If S is single precision, very large sizes may not be exactly
  // representable -- confirm the truncation below cannot under-allocate.
  lwork = static_cast<int>(work_size);
  liwork = iwork_size;
  std::vector<S> work(lwork);
  std::vector<int> iwork(liwork);
  Lapack<T>::syevd(&jobz, &uplo, &n, a, &n, w, work.data(), &lwork,
                   iwork.data(), &liwork, &info);
  PJRT_RETURN_IF_ERROR(lapack_check_info(info, "syevd"));

  // Copy results back into the T-typed output buffers where a promoted
  // staging vector was used.
  demote_output<T>(a_storage, v_data, total);
  demote_output<T>(w_storage, w_data, n);

  return Error::Success();
}

// Untyped entry point bound to the FFI handler. Hands the input's element
// type, the eigh_impl template and the three buffers to PJRT_DISPATCH_FLOAT.
// NOTE(review): the macro is defined outside this file; it is assumed to
// select eigh_impl<T> for the matching floating-point type and to return from
// this function on every path (including unsupported types) -- confirm.
static Error do_eigh(AnyBuffer input, Result<AnyBuffer> v_out,
Result<AnyBuffer> w_out) {
PJRT_DISPATCH_FLOAT(input.element_type(), eigh_impl, input, v_out, w_out);
}

// Defines the FFI handler symbol `eigh_handler` around do_eigh. The binding
// declares one buffer argument followed by two buffer results, in the same
// order as do_eigh's parameters.
XLA_FFI_DEFINE_HANDLER(eigh_handler, do_eigh,
Ffi::Bind()
.Arg<AnyBuffer>() // symmetric matrix
.Ret<AnyBuffer>() // eigenvectors (n, n)
.Ret<AnyBuffer>()); // eigenvalues (n,)

} // namespace rpjrt

// Hands the host eigh FFI handler to R as an external pointer. The pointer is
// created with R_NilValue for both its tag and its protected value.
// [[Rcpp::export]]
SEXP get_eigh_handler() {
  void *handler_addr = reinterpret_cast<void *>(rpjrt::eigh_handler);
  return R_MakeExternalPtr(handler_addr, R_NilValue, R_NilValue);
}
Loading
Loading