Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
32 changes: 32 additions & 0 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
@@ -1,6 +1,14 @@
# Generated by using Rcpp::compileAttributes() -> do not edit by hand
# Generator token: 10BE3573-1514-4C36-9D1C-5A225CD40393

get_eigh_handler <- function() {
.Call(`_pjrt_get_eigh_handler`)
}

get_eigh_handler_cuda <- function() {
.Call(`_pjrt_get_eigh_handler_cuda`)
}

impl_register_custom_call <- function(plugin, target_name, handler_ptr, platform_name) {
invisible(.Call(`_pjrt_impl_register_custom_call`, plugin, target_name, handler_ptr, platform_name))
}
Expand All @@ -21,6 +29,14 @@ format_raw_buffer_cpp <- function(data, dtype, shape) {
.Call(`_pjrt_format_raw_buffer_cpp`, data, dtype, shape)
}

get_lu_handler <- function() {
.Call(`_pjrt_get_lu_handler`)
}

get_lu_handler_cuda <- function() {
.Call(`_pjrt_get_lu_handler_cuda`)
}

impl_plugin_load <- function(path) {
.Call(`_pjrt_impl_plugin_load`, path)
}
Expand Down Expand Up @@ -177,3 +193,19 @@ impl_client_buffer_from_double <- function(client, device, data, dims, dtype) {
.Call(`_pjrt_impl_client_buffer_from_double`, client, device, data, dims, dtype)
}

get_qr_handler <- function() {
.Call(`_pjrt_get_qr_handler`)
}

get_qr_handler_cuda <- function() {
.Call(`_pjrt_get_qr_handler_cuda`)
}

get_svd_handler <- function() {
.Call(`_pjrt_get_svd_handler`)
}

get_svd_handler_cuda <- function() {
.Call(`_pjrt_get_svd_handler_cuda`)
}

14 changes: 14 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,20 @@ register_namespace_callback <- function(pkgname, namespace, callback) {
.package = pkgname
)

# Register the built-in LAPACK / cuSOLVER linear-algebra handlers. These
# are pjrt-owned custom calls that any downstream package (anvl, future
# bindings) can invoke via `stablehlo.custom_call @<target>(...)` without
# having to ship its own LAPACK linkage.
register_linalg_handler <- function(target, host, cuda) {
handlers <- list(host = host)
if (!is.null(cuda)) handlers$cuda <- cuda
pjrt_register_custom_call(target, handlers, .package = pkgname)
}
register_linalg_handler("qr", get_qr_handler(), get_qr_handler_cuda())
register_linalg_handler("lu", get_lu_handler(), get_lu_handler_cuda())
register_linalg_handler("svd", get_svd_handler(), get_svd_handler_cuda())
register_linalg_handler("eigh", get_eigh_handler(), get_eigh_handler_cuda())

register_namespace_callback(pkgname, "safetensors", function(...) {
frameworks <- utils::getFromNamespace(
"safetensors_frameworks",
Expand Down
2 changes: 1 addition & 1 deletion src/Makevars.in
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ CXX_STD=CXX20

PKG_CPPFLAGS=@cflags@ -Iproto -I$(R_PACKAGE_DIR)/include -I../inst/include
PKG_CXXFLAGS=$(C_VISIBILITY)
PKG_LIBS=@libs@
PKG_LIBS=@libs@ $(LAPACK_LIBS) $(BLAS_LIBS) $(FLIBS)

CPP_SOURCES=@cppsrc@
CC_SOURCES=@pbsrc@
Expand Down
88 changes: 88 additions & 0 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,26 @@ Rcpp::Rostream<true>& Rcpp::Rcout = Rcpp::Rcpp_cout_get();
Rcpp::Rostream<false>& Rcpp::Rcerr = Rcpp::Rcpp_cerr_get();
#endif

// get_eigh_handler
SEXP get_eigh_handler();
RcppExport SEXP _pjrt_get_eigh_handler() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_eigh_handler());
return rcpp_result_gen;
END_RCPP
}
// get_eigh_handler_cuda
SEXP get_eigh_handler_cuda();
RcppExport SEXP _pjrt_get_eigh_handler_cuda() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_eigh_handler_cuda());
return rcpp_result_gen;
END_RCPP
}
// impl_register_custom_call
void impl_register_custom_call(Rcpp::XPtr<rpjrt::PJRTPlugin> plugin, const std::string& target_name, SEXP handler_ptr, const std::string& platform_name);
RcppExport SEXP _pjrt_impl_register_custom_call(SEXP pluginSEXP, SEXP target_nameSEXP, SEXP handler_ptrSEXP, SEXP platform_nameSEXP) {
Expand Down Expand Up @@ -69,6 +89,26 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// get_lu_handler
SEXP get_lu_handler();
RcppExport SEXP _pjrt_get_lu_handler() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_lu_handler());
return rcpp_result_gen;
END_RCPP
}
// get_lu_handler_cuda
SEXP get_lu_handler_cuda();
RcppExport SEXP _pjrt_get_lu_handler_cuda() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_lu_handler_cuda());
return rcpp_result_gen;
END_RCPP
}
// impl_plugin_load
Rcpp::XPtr<rpjrt::PJRTPlugin> impl_plugin_load(const std::string& path);
RcppExport SEXP _pjrt_impl_plugin_load(SEXP pathSEXP) {
Expand Down Expand Up @@ -531,13 +571,57 @@ BEGIN_RCPP
return rcpp_result_gen;
END_RCPP
}
// get_qr_handler
SEXP get_qr_handler();
RcppExport SEXP _pjrt_get_qr_handler() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_qr_handler());
return rcpp_result_gen;
END_RCPP
}
// get_qr_handler_cuda
SEXP get_qr_handler_cuda();
RcppExport SEXP _pjrt_get_qr_handler_cuda() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_qr_handler_cuda());
return rcpp_result_gen;
END_RCPP
}
// get_svd_handler
SEXP get_svd_handler();
RcppExport SEXP _pjrt_get_svd_handler() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_svd_handler());
return rcpp_result_gen;
END_RCPP
}
// get_svd_handler_cuda
SEXP get_svd_handler_cuda();
RcppExport SEXP _pjrt_get_svd_handler_cuda() {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
rcpp_result_gen = Rcpp::wrap(get_svd_handler_cuda());
return rcpp_result_gen;
END_RCPP
}

static const R_CallMethodDef CallEntries[] = {
{"_pjrt_get_eigh_handler", (DL_FUNC) &_pjrt_get_eigh_handler, 0},
{"_pjrt_get_eigh_handler_cuda", (DL_FUNC) &_pjrt_get_eigh_handler_cuda, 0},
{"_pjrt_impl_register_custom_call", (DL_FUNC) &_pjrt_impl_register_custom_call, 4},
{"_pjrt_get_print_handler", (DL_FUNC) &_pjrt_get_print_handler, 0},
{"_pjrt_get_print_handler_cuda", (DL_FUNC) &_pjrt_get_print_handler_cuda, 0},
{"_pjrt_test_get_extension", (DL_FUNC) &_pjrt_test_get_extension, 2},
{"_pjrt_format_raw_buffer_cpp", (DL_FUNC) &_pjrt_format_raw_buffer_cpp, 3},
{"_pjrt_get_lu_handler", (DL_FUNC) &_pjrt_get_lu_handler, 0},
{"_pjrt_get_lu_handler_cuda", (DL_FUNC) &_pjrt_get_lu_handler_cuda, 0},
{"_pjrt_impl_plugin_load", (DL_FUNC) &_pjrt_impl_plugin_load, 1},
{"_pjrt_impl_plugin_client_create", (DL_FUNC) &_pjrt_impl_plugin_client_create, 2},
{"_pjrt_impl_program_load", (DL_FUNC) &_pjrt_impl_program_load, 2},
Expand Down Expand Up @@ -577,6 +661,10 @@ static const R_CallMethodDef CallEntries[] = {
{"_pjrt_impl_client_buffer_from_integer", (DL_FUNC) &_pjrt_impl_client_buffer_from_integer, 5},
{"_pjrt_impl_client_buffer_from_logical", (DL_FUNC) &_pjrt_impl_client_buffer_from_logical, 5},
{"_pjrt_impl_client_buffer_from_double", (DL_FUNC) &_pjrt_impl_client_buffer_from_double, 5},
{"_pjrt_get_qr_handler", (DL_FUNC) &_pjrt_get_qr_handler, 0},
{"_pjrt_get_qr_handler_cuda", (DL_FUNC) &_pjrt_get_qr_handler_cuda, 0},
{"_pjrt_get_svd_handler", (DL_FUNC) &_pjrt_get_svd_handler, 0},
{"_pjrt_get_svd_handler_cuda", (DL_FUNC) &_pjrt_get_svd_handler_cuda, 0},
{NULL, NULL, 0}
};

Expand Down
92 changes: 92 additions & 0 deletions src/eigh.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
// Symmetric eigendecomposition via LAPACK syevd.
//
// Input is interpreted as symmetric using only its lower triangle (we pass
// uplo = 'L'). The factorisation produces:
// V : (n, n) eigenvectors as columns (orthonormal)
// W : (n,) eigenvalues in ascending order
// such that A = V diag(W) V^T (using the lower triangle of A).
//
// jobz = 'V' (compute eigenvectors). For values-only the user can drop V.
//
// syevd needs *two* workspaces (real `work` of size lwork, integer `iwork` of
// size liwork) -- both queried with the lwork = -1 idiom.
#include <Rcpp.h>

#include "ffi_common.h"
#include "ffi_lapack.h"

#include <cstddef>
#include <vector>

using namespace xla::ffi;

namespace rpjrt {

template <typename T>
static Error eigh_impl(AnyBuffer input, Result<AnyBuffer> v_out,
Result<AnyBuffer> w_out) {
using S = typename Lapack<T>::S;

auto dims = input.dimensions();
int m, n;
PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m));
PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n));
if (m != n)
return Error::InvalidArgument("eigh requires a square matrix");

// syevd overwrites A with eigenvectors. Copy input -> V output buffer and
// factorise in place there.
std::vector<S> a(static_cast<std::size_t>(n) * n);
const T *in = static_cast<const T *>(input.untyped_data());
for (std::size_t i = 0; i < a.size(); i++)
a[i] = static_cast<S>(in[i]);

std::vector<S> w(n);

const char jobz = 'V';
const char uplo = 'L';
int info;

int lwork = -1, liwork = -1;
S work_size;
int iwork_size;
Lapack<T>::syevd(&jobz, &uplo, &n, a.data(), &n, w.data(), &work_size, &lwork,
&iwork_size, &liwork, &info);
PJRT_RETURN_IF_ERROR(lapack_check_info(info, "syevd workspace query"));

lwork = static_cast<int>(work_size);
liwork = iwork_size;
std::vector<S> work(lwork);
std::vector<int> iwork(liwork);
Lapack<T>::syevd(&jobz, &uplo, &n, a.data(), &n, w.data(), work.data(),
&lwork, iwork.data(), &liwork, &info);
PJRT_RETURN_IF_ERROR(lapack_check_info(info, "syevd"));

T *v_data = static_cast<T *>((*v_out).untyped_data());
for (std::size_t i = 0; i < a.size(); i++)
v_data[i] = static_cast<T>(a[i]);

T *w_data = static_cast<T *>((*w_out).untyped_data());
for (int i = 0; i < n; i++)
w_data[i] = static_cast<T>(w[i]);

return Error::Success();
}

static Error do_eigh(AnyBuffer input, Result<AnyBuffer> v_out,
Result<AnyBuffer> w_out) {
PJRT_DISPATCH_FLOAT(input.element_type(), eigh_impl, input, v_out, w_out);
}

XLA_FFI_DEFINE_HANDLER(eigh_handler, do_eigh,
Ffi::Bind()
.Arg<AnyBuffer>() // symmetric matrix
.Ret<AnyBuffer>() // eigenvectors (n, n)
.Ret<AnyBuffer>()); // eigenvalues (n,)

} // namespace rpjrt

// [[Rcpp::export]]
SEXP get_eigh_handler() {
return R_MakeExternalPtr((void *)rpjrt::eigh_handler, R_NilValue, R_NilValue);
}
99 changes: 99 additions & 0 deletions src/eigh_cuda.cpp
Original file line number Diff line number Diff line change
@@ -0,0 +1,99 @@
// CUDA symmetric eigendecomposition via cuSOLVER syevd.
//
// jobz / uplo are passed as int enum values:
// jobz = CUSOLVER_EIG_MODE_VECTOR (1)
// uplo = CUBLAS_FILL_MODE_LOWER (0)
// matching the LAPACK 'V' / 'L' choice in src/eigh.cpp.
#include <Rcpp.h>

#include "ffi_common.h"

#ifndef _WIN32
#include "ffi_cusolver.h"

#include <cstddef>
#endif

using namespace xla::ffi;

namespace rpjrt {

#ifndef _WIN32
template <typename T>
static Error eigh_cuda_impl(void *stream, ScratchAllocator &scratch,
AnyBuffer input, Result<AnyBuffer> v_out,
Result<AnyBuffer> w_out) {
Solver solver(get_gpu_libs());
PJRT_RETURN_IF_ERROR(solver.begin(scratch, stream));
auto &g = solver.g;

auto dims = input.dimensions();
int m, n;
PJRT_RETURN_IF_ERROR(dim_to_int(dims[0], "rows", m));
PJRT_RETURN_IF_ERROR(dim_to_int(dims[1], "cols", n));
if (m != n)
return Error::InvalidArgument("eigh requires a square matrix");

auto input_ptr = reinterpret_cast<CUdeviceptr>(input.untyped_data());
auto v_ptr = reinterpret_cast<CUdeviceptr>((*v_out).untyped_data());
auto w_ptr = reinterpret_cast<CUdeviceptr>((*w_out).untyped_data());

std::size_t a_bytes = static_cast<std::size_t>(n) * n * sizeof(T);

// syevd overwrites its A argument with eigenvectors -- copy input into V
// first, factorise in place there.
PJRT_RETURN_IF_GPU_ERROR(g.memcpy_dtod(v_ptr, input_ptr, a_bytes, stream),
"cuMemcpyDtoDAsync (input -> V)");

const int jobz = 1; // CUSOLVER_EIG_MODE_VECTOR
const int uplo = 0; // CUBLAS_FILL_MODE_LOWER

int lwork = 0;
PJRT_RETURN_IF_GPU_ERROR(
CuSolver<T>::syevd_bs(g)(solver.handle.get(), jobz, uplo, n,
reinterpret_cast<const T *>(v_ptr), n,
reinterpret_cast<const T *>(w_ptr), &lwork),
"cusolverDn?syevd_bufferSize");

T *d_work;
PJRT_RETURN_IF_ERROR(allocate_workspace<T>(
scratch, static_cast<std::size_t>(lwork), "syevd workspace", d_work));

PJRT_RETURN_IF_GPU_ERROR(
CuSolver<T>::syevd(g)(solver.handle.get(), jobz, uplo, n,
reinterpret_cast<T *>(v_ptr), n,
reinterpret_cast<T *>(w_ptr), d_work, lwork,
solver.info),
"cusolverDn?syevd");

return Error::Success();
}
#endif // _WIN32

static Error do_eigh_cuda(void *stream, ScratchAllocator scratch,
AnyBuffer input, Result<AnyBuffer> v_out,
Result<AnyBuffer> w_out) {
#ifdef _WIN32
return Error(ErrorCode::kUnimplemented,
"CUDA eigh is not supported on Windows");
#else
PJRT_DISPATCH_FLOAT(input.element_type(), eigh_cuda_impl, stream, scratch,
input, v_out, w_out);
#endif
}

XLA_FFI_DEFINE_HANDLER(eigh_handler_cuda, do_eigh_cuda,
Ffi::Bind()
.Ctx<PlatformStream<void *>>()
.Ctx<ScratchAllocator>()
.Arg<AnyBuffer>()
.Ret<AnyBuffer>()
.Ret<AnyBuffer>());

} // namespace rpjrt

// [[Rcpp::export]]
SEXP get_eigh_handler_cuda() {
return R_MakeExternalPtr((void *)rpjrt::eigh_handler_cuda, R_NilValue,
R_NilValue);
}
Loading
Loading