Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ export(shape)
export(value)
import(checkmate)
importFrom(Rcpp,sourceCpp)
importFrom(bit64,integer64)
importFrom(cli,cli_abort)
importFrom(safetensors,safe_tensor_buffer)
importFrom(safetensors,safe_tensor_meta)
Expand Down
4 changes: 2 additions & 2 deletions R/RcppExports.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,8 +169,8 @@ impl_client_buffer_from_integer <- function(client, device, data, dims, dtype) {
.Call(`_pjrt_impl_client_buffer_from_integer`, client, device, data, dims, dtype)
}

impl_client_buffer_from_integer64 <- function(client, device, data, dims) {
.Call(`_pjrt_impl_client_buffer_from_integer64`, client, device, data, dims)
impl_client_buffer_from_integer64 <- function(client, device, data, dims, dtype) {
.Call(`_pjrt_impl_client_buffer_from_integer64`, client, device, data, dims, dtype)
}

impl_client_buffer_from_logical <- function(client, device, data, dims, dtype) {
Expand Down
7 changes: 4 additions & 3 deletions R/buffer.R
Original file line number Diff line number Diff line change
Expand Up @@ -292,16 +292,17 @@ pjrt_buffer.integer64 <- function(
...
) {
args <- convert_buffer_args(data, dtype, device, shape, "i64", ...)
if (!identical(args$dtype, "i64")) {
if (!args$dtype %in% c("i64", "ui64")) {
cli_abort(
"{.cls integer64} input only supports {.val i64} dtype, got {.val {args$dtype}}."
"{.cls integer64} input only supports {.val i64} or {.val ui64} dtype, got {.val {args$dtype}}."
)
}
impl_client_buffer_from_integer64(
client = args$client,
device = args$device,
data = args$data,
dims = args$dims
dims = args$dims,
dtype = args$dtype
)
}

Expand Down
1 change: 1 addition & 0 deletions R/package.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#' @importFrom safetensors safe_tensor_buffer safe_tensor_meta
#' @importFrom utils hashtab
#' @importFrom cli cli_abort
#' @importFrom bit64 integer64
## usethis namespace: end
NULL

Expand Down
9 changes: 5 additions & 4 deletions src/RcppExports.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -502,16 +502,17 @@ BEGIN_RCPP
END_RCPP
}
// impl_client_buffer_from_integer64
Rcpp::XPtr<rpjrt::PJRTBuffer> impl_client_buffer_from_integer64(Rcpp::XPtr<rpjrt::PJRTClient> client, Rcpp::XPtr<rpjrt::PJRTDevice> device, SEXP data, std::vector<int64_t> dims);
RcppExport SEXP _pjrt_impl_client_buffer_from_integer64(SEXP clientSEXP, SEXP deviceSEXP, SEXP dataSEXP, SEXP dimsSEXP) {
Rcpp::XPtr<rpjrt::PJRTBuffer> impl_client_buffer_from_integer64(Rcpp::XPtr<rpjrt::PJRTClient> client, Rcpp::XPtr<rpjrt::PJRTDevice> device, SEXP data, std::vector<int64_t> dims, std::string dtype);
RcppExport SEXP _pjrt_impl_client_buffer_from_integer64(SEXP clientSEXP, SEXP deviceSEXP, SEXP dataSEXP, SEXP dimsSEXP, SEXP dtypeSEXP) {
BEGIN_RCPP
Rcpp::RObject rcpp_result_gen;
Rcpp::RNGScope rcpp_rngScope_gen;
Rcpp::traits::input_parameter< Rcpp::XPtr<rpjrt::PJRTClient> >::type client(clientSEXP);
Rcpp::traits::input_parameter< Rcpp::XPtr<rpjrt::PJRTDevice> >::type device(deviceSEXP);
Rcpp::traits::input_parameter< SEXP >::type data(dataSEXP);
Rcpp::traits::input_parameter< std::vector<int64_t> >::type dims(dimsSEXP);
rcpp_result_gen = Rcpp::wrap(impl_client_buffer_from_integer64(client, device, data, dims));
Rcpp::traits::input_parameter< std::string >::type dtype(dtypeSEXP);
rcpp_result_gen = Rcpp::wrap(impl_client_buffer_from_integer64(client, device, data, dims, dtype));
return rcpp_result_gen;
END_RCPP
}
Expand Down Expand Up @@ -589,7 +590,7 @@ static const R_CallMethodDef CallEntries[] = {
{"_pjrt_impl_buffer_to_host_async", (DL_FUNC) &_pjrt_impl_buffer_to_host_async, 1},
{"_pjrt_impl_loaded_executable_execute", (DL_FUNC) &_pjrt_impl_loaded_executable_execute, 3},
{"_pjrt_impl_client_buffer_from_integer", (DL_FUNC) &_pjrt_impl_client_buffer_from_integer, 5},
{"_pjrt_impl_client_buffer_from_integer64", (DL_FUNC) &_pjrt_impl_client_buffer_from_integer64, 4},
{"_pjrt_impl_client_buffer_from_integer64", (DL_FUNC) &_pjrt_impl_client_buffer_from_integer64, 5},
{"_pjrt_impl_client_buffer_from_logical", (DL_FUNC) &_pjrt_impl_client_buffer_from_logical, 5},
{"_pjrt_impl_client_buffer_from_double", (DL_FUNC) &_pjrt_impl_client_buffer_from_double, 5},
{NULL, NULL, 0}
Expand Down
18 changes: 14 additions & 4 deletions src/pjrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -843,16 +843,26 @@ Rcpp::XPtr<rpjrt::PJRTBuffer> impl_client_buffer_from_integer(

// bit64::integer64 stores int64 values inside a REALSXP (8 bytes per slot),
// so we can hand the underlying buffer to PJRT zero-copy as int64.
// The bit pattern is identical for signed/unsigned 64-bit ints, so the same
// data can be uploaded as either S64 or U64.
// [[Rcpp::export()]]
Rcpp::XPtr<rpjrt::PJRTBuffer> impl_client_buffer_from_integer64(
Rcpp::XPtr<rpjrt::PJRTClient> client, Rcpp::XPtr<rpjrt::PJRTDevice> device,
SEXP data, std::vector<int64_t> dims) {
SEXP data, std::vector<int64_t> dims, std::string dtype) {
static_assert(sizeof(double) == sizeof(int64_t),
"bit64::integer64 zero-copy requires sizeof(double) == "
"sizeof(int64_t)");
return create_buffer_from_array_async_zerocopy(
client, data, REAL(data), dims, PJRT_Buffer_Type_S64, sizeof(int64_t),
false, device->device);
PJRT_Buffer_Type buffer_type;
if (dtype == "i64") {
buffer_type = PJRT_Buffer_Type_S64;
} else if (dtype == "ui64") {
buffer_type = PJRT_Buffer_Type_U64;
} else {
Rcpp::stop("Unsupported type: %s", dtype.c_str());
}
return create_buffer_from_array_async_zerocopy(client, data, REAL(data), dims,
buffer_type, sizeof(int64_t),
false, device->device);
}

// [[Rcpp::export()]]
Expand Down
10 changes: 9 additions & 1 deletion tests/testthat/test-buffer.R
Original file line number Diff line number Diff line change
Expand Up @@ -301,7 +301,7 @@ test_that("pjrt_scalar.integer64 round-trips a single 64-bit value", {
expect_error(pjrt_scalar(bit64::as.integer64(c(1, 2))), "length 1")
})

test_that("pjrt_buffer.integer64 rejects non-i64 dtype", {
test_that("pjrt_buffer.integer64 rejects non-i64/ui64 dtype", {
expect_error(
pjrt_buffer(bit64::as.integer64(1), dtype = "i32"),
"only supports.*i64"
Expand All @@ -317,6 +317,14 @@ test_that("ui64 buffers also materialize as integer64", {
expect_equal(as.character(back), c("0", "1", "100"))
})

test_that("pjrt_buffer / as_array round-trip ui64 with full 64-bit range", {
x <- bit64::as.integer64(c(0, 1, 2^32, -2^40, 9223372036854775000))
dim(x) <- 5L
buf <- pjrt_buffer(x, dtype = "ui64")
expect_equal(as.character(elt_type(buf)), "ui64")
expect_equal(as_array(buf), x)
})

test_that("raw", {
sample_signed <- function(precision, shape) {
precision <- as.integer(precision)
Expand Down