diff --git a/NAMESPACE b/NAMESPACE index fd539fe3..1c8625b0 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -82,6 +82,7 @@ export(shape) export(value) import(checkmate) importFrom(Rcpp,sourceCpp) +importFrom(bit64,integer64) importFrom(cli,cli_abort) importFrom(safetensors,safe_tensor_buffer) importFrom(safetensors,safe_tensor_meta) diff --git a/R/RcppExports.R b/R/RcppExports.R index 0ad3dfe3..91c9033c 100644 --- a/R/RcppExports.R +++ b/R/RcppExports.R @@ -169,8 +169,8 @@ impl_client_buffer_from_integer <- function(client, device, data, dims, dtype) { .Call(`_pjrt_impl_client_buffer_from_integer`, client, device, data, dims, dtype) } -impl_client_buffer_from_integer64 <- function(client, device, data, dims) { - .Call(`_pjrt_impl_client_buffer_from_integer64`, client, device, data, dims) +impl_client_buffer_from_integer64 <- function(client, device, data, dims, dtype) { + .Call(`_pjrt_impl_client_buffer_from_integer64`, client, device, data, dims, dtype) } impl_client_buffer_from_logical <- function(client, device, data, dims, dtype) { diff --git a/R/buffer.R b/R/buffer.R index 878718bd..31650272 100644 --- a/R/buffer.R +++ b/R/buffer.R @@ -292,16 +292,17 @@ pjrt_buffer.integer64 <- function( ... ) { args <- convert_buffer_args(data, dtype, device, shape, "i64", ...) - if (!identical(args$dtype, "i64")) { + if (!args$dtype %in% c("i64", "ui64")) { cli_abort( - "{.cls integer64} input only supports {.val i64} dtype, got {.val {args$dtype}}." + "{.cls integer64} input only supports {.val i64} or {.val ui64} dtype, got {.val {args$dtype}}." ) } impl_client_buffer_from_integer64( client = args$client, device = args$device, data = args$data, - dims = args$dims + dims = args$dims, + dtype = args$dtype ) } diff --git a/R/package.R b/R/package.R index 6e9cd98d..b3ae2394 100644 --- a/R/package.R +++ b/R/package.R @@ -6,6 +6,7 @@ #' @importFrom safetensors safe_tensor_buffer safe_tensor_meta #' @importFrom utils hashtab #' @importFrom cli cli_abort +#' @importFrom bit64 integer64 ## usethis namespace: end NULL diff --git a/src/RcppExports.cpp b/src/RcppExports.cpp index cc608e13..649344b7 100644 --- a/src/RcppExports.cpp +++ b/src/RcppExports.cpp @@ -502,8 +502,8 @@ BEGIN_RCPP END_RCPP } // impl_client_buffer_from_integer64 -Rcpp::XPtr impl_client_buffer_from_integer64(Rcpp::XPtr client, Rcpp::XPtr device, SEXP data, std::vector dims); -RcppExport SEXP _pjrt_impl_client_buffer_from_integer64(SEXP clientSEXP, SEXP deviceSEXP, SEXP dataSEXP, SEXP dimsSEXP) { +Rcpp::XPtr impl_client_buffer_from_integer64(Rcpp::XPtr client, Rcpp::XPtr device, SEXP data, std::vector dims, std::string dtype); +RcppExport SEXP _pjrt_impl_client_buffer_from_integer64(SEXP clientSEXP, SEXP deviceSEXP, SEXP dataSEXP, SEXP dimsSEXP, SEXP dtypeSEXP) { BEGIN_RCPP Rcpp::RObject rcpp_result_gen; Rcpp::RNGScope rcpp_rngScope_gen; @@ -511,7 +511,8 @@ BEGIN_RCPP Rcpp::traits::input_parameter< Rcpp::XPtr >::type device(deviceSEXP); Rcpp::traits::input_parameter< SEXP >::type data(dataSEXP); Rcpp::traits::input_parameter< std::vector >::type dims(dimsSEXP); - rcpp_result_gen = Rcpp::wrap(impl_client_buffer_from_integer64(client, device, data, dims)); + Rcpp::traits::input_parameter< std::string >::type dtype(dtypeSEXP); + rcpp_result_gen = Rcpp::wrap(impl_client_buffer_from_integer64(client, device, data, dims, dtype)); return rcpp_result_gen; END_RCPP } @@ -589,7 +590,7 @@ static const R_CallMethodDef CallEntries[] = { {"_pjrt_impl_buffer_to_host_async", (DL_FUNC) &_pjrt_impl_buffer_to_host_async, 1}, {"_pjrt_impl_loaded_executable_execute", (DL_FUNC) &_pjrt_impl_loaded_executable_execute, 3}, {"_pjrt_impl_client_buffer_from_integer", (DL_FUNC) &_pjrt_impl_client_buffer_from_integer, 5}, - {"_pjrt_impl_client_buffer_from_integer64", (DL_FUNC) &_pjrt_impl_client_buffer_from_integer64, 4}, + {"_pjrt_impl_client_buffer_from_integer64", (DL_FUNC) &_pjrt_impl_client_buffer_from_integer64, 5}, {"_pjrt_impl_client_buffer_from_logical", (DL_FUNC) &_pjrt_impl_client_buffer_from_logical, 5}, {"_pjrt_impl_client_buffer_from_double", (DL_FUNC) &_pjrt_impl_client_buffer_from_double, 5}, {NULL, NULL, 0} diff --git a/src/pjrt.cpp b/src/pjrt.cpp index 40fcdcce..302b4249 100644 --- a/src/pjrt.cpp +++ b/src/pjrt.cpp @@ -843,16 +843,26 @@ Rcpp::XPtr impl_client_buffer_from_integer( // bit64::integer64 stores int64 values inside a REALSXP (8 bytes per slot), // so we can hand the underlying buffer to PJRT zero-copy as int64. +// The bit pattern is identical for signed/unsigned 64-bit ints, so the same +// data can be uploaded as either S64 or U64. // [[Rcpp::export()]] Rcpp::XPtr impl_client_buffer_from_integer64( Rcpp::XPtr client, Rcpp::XPtr device, - SEXP data, std::vector dims) { + SEXP data, std::vector dims, std::string dtype) { static_assert(sizeof(double) == sizeof(int64_t), "bit64::integer64 zero-copy requires sizeof(double) == " "sizeof(int64_t)"); - return create_buffer_from_array_async_zerocopy( - client, data, REAL(data), dims, PJRT_Buffer_Type_S64, sizeof(int64_t), - false, device->device); + PJRT_Buffer_Type buffer_type; + if (dtype == "i64") { + buffer_type = PJRT_Buffer_Type_S64; + } else if (dtype == "ui64") { + buffer_type = PJRT_Buffer_Type_U64; + } else { + Rcpp::stop("Unsupported type: %s", dtype.c_str()); + } + return create_buffer_from_array_async_zerocopy(client, data, REAL(data), dims, + buffer_type, sizeof(int64_t), + false, device->device); } // [[Rcpp::export()]] diff --git a/tests/testthat/test-buffer.R b/tests/testthat/test-buffer.R index 1807c536..a5b41579 100644 --- a/tests/testthat/test-buffer.R +++ b/tests/testthat/test-buffer.R @@ -301,7 +301,7 @@ test_that("pjrt_scalar.integer64 round-trips a single 64-bit value", { expect_error(pjrt_scalar(bit64::as.integer64(c(1, 2))), "length 1") }) -test_that("pjrt_buffer.integer64 rejects non-i64 dtype", { +test_that("pjrt_buffer.integer64 rejects non-i64/ui64 dtype", { expect_error( pjrt_buffer(bit64::as.integer64(1), dtype = "i32"), "only supports.*i64" @@ -317,6 +317,14 @@ test_that("ui64 buffers also materialize as integer64", { expect_equal(as.character(back), c("0", "1", "100")) }) +test_that("pjrt_buffer / as_array round-trip ui64 with full 64-bit range", { + x <- bit64::as.integer64(c(0, 1, 2^32, -2^40, 9223372036854775000)) + dim(x) <- 5L + buf <- pjrt_buffer(x, dtype = "ui64") + expect_equal(as.character(elt_type(buf)), "ui64") + expect_equal(as_array(buf), x) +}) + test_that("raw", { sample_signed <- function(precision, shape) { precision <- as.integer(precision)