Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -8,14 +8,15 @@
registration mechanisms with coverage of both CUDA and CPU-specific
aspects.
* Added support for the `bit64` package to better support long integers.
* `pjrt_buffer()`, `pjrt_scalar()`, and `as_array()` gain a `scan_na`
argument (default `FALSE`). When `TRUE`, host → device transfers error if
the input contains any `NA` values; device → host transfers error if a
materialized `i32`, `ui32`, `i64`, or `ui64` buffer surfaces a value that
R cannot distinguish from `NA` (`INT_MIN` for the 32-bit dtypes,
`INT64_MIN` for the 64-bit dtypes — the latter via `bit64::integer64`).
Opt-in safety check for callers that want to fail loudly on
silent NA collisions.
* `pjrt_buffer()`, `pjrt_scalar()`, and `as_array()` gain a `check`
argument (default `FALSE`). When `TRUE`, the call errors instead of
silently losing information: on input if `data` contains `NA`s, on
output if the materialized R vector contains a value that's
indistinguishable from `NA` or that has wrapped through the integer
container.
* `as_array()` on a `ui32` buffer now returns a `bit64::integer64`
instead of a base `integer`, so values `>= 2^31` round-trip losslessly
rather than wrapping to negative.

# pjrt 0.3.0

Expand Down
2 changes: 1 addition & 1 deletion R/async.R
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ value.PJRTArrayPromise <- function(x, ...) {
if (is.null(x$materialized)) {
impl_host_data_await(x$data)
out <- impl_raw_to_array(x$data, x$dtype, x$shape)
if (x$dtype %in% c("i64", "ui64")) {
if (x$dtype %in% c("i64", "ui64", "ui32")) {
class(out) <- "integer64"
}
x$materialized <- out
Expand Down
84 changes: 52 additions & 32 deletions R/buffer.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,7 +61,7 @@ is_buffer <- function(x) {
#' case the first device for that platform is used.
#' The default is to use the CPU platform, but this can be configured via the `PJRT_PLATFORM`
#' environment variable.
#' @param scan_na (`logical(1)`)\cr
#' @param check (`logical(1)`)\cr
#' If `TRUE`, scan `data` for `NA` values before transferring to the device and
#' raise an error if any are present. R's `NA` markers have no representation
#' at the XLA level (e.g. `NA_integer_` is just the bit pattern `-2147483648`,
Expand Down Expand Up @@ -95,19 +95,19 @@ pjrt_buffer <- function(
dtype = NULL,
device = NULL,
shape = NULL,
scan_na = FALSE,
check = FALSE,
...
) {
UseMethod("pjrt_buffer")
}

check_scan_na <- function(data, scan_na) {
assert_flag(scan_na)
if (scan_na && anyNA(data)) {
check_input_na <- function(data, check) {
assert_flag(check)
if (check && anyNA(data)) {
n_na <- sum(is.na(data))
cli_abort(c(
"Input {.arg data} contains {n_na} {.val NA} value{?s}, which {?has/have} no representation at the XLA level.",
i = "Replace or drop missing values before transferring, or set {.code scan_na = FALSE} to skip this check."
i = "Replace or drop missing values before transferring, or set {.code check = FALSE} to skip this check."
))
}
invisible(NULL)
Expand Down Expand Up @@ -144,7 +144,7 @@ pjrt_buffer.PJRTBuffer <- buffer_identity
#' scalar <- pjrt_scalar(42, dtype = "f32")
#' scalar
#' @export
pjrt_scalar <- function(data, dtype = NULL, device = NULL, scan_na = FALSE, ...) {
pjrt_scalar <- function(data, dtype = NULL, device = NULL, check = FALSE, ...) {
UseMethod("pjrt_scalar")
}

Expand Down Expand Up @@ -244,10 +244,10 @@ pjrt_buffer.logical <- function(
dtype = NULL,
device = NULL,
shape = NULL,
scan_na = FALSE,
check = FALSE,
...
) {
check_scan_na(data, scan_na)
check_input_na(data, check)
args <- convert_buffer_args(data, dtype, device, shape, "pred", ...)
buffer <- do.call(impl_client_buffer_from_logical, args)
buffer
Expand All @@ -259,10 +259,10 @@ pjrt_buffer.integer <- function(
dtype = NULL,
device = NULL,
shape = NULL,
scan_na = FALSE,
check = FALSE,
...
) {
check_scan_na(data, scan_na)
check_input_na(data, check)
args <- convert_buffer_args(data, dtype, device, shape, "i32", ...)
buffer <- do.call(impl_client_buffer_from_integer, args)
buffer
Expand All @@ -274,10 +274,10 @@ pjrt_buffer.numeric <- function(
dtype = NULL,
device = NULL,
shape = NULL,
scan_na = FALSE,
check = FALSE,
...
) {
check_scan_na(data, scan_na)
check_input_na(data, check)
args <- convert_buffer_args(data, dtype, device, shape, "f32", ...)
buffer <- do.call(impl_client_buffer_from_double, args)
buffer
Expand Down Expand Up @@ -344,13 +344,13 @@ pjrt_scalar.logical <- function(
data,
dtype = NULL,
device = NULL,
scan_na = FALSE,
check = FALSE,
...
) {
if (length(data) != 1) {
cli_abort("data must have length 1")
}
check_scan_na(data, scan_na)
check_input_na(data, check)
args <- convert_buffer_args(data, dtype, device, integer(), "pred", ...)
buffer <- do.call(impl_client_buffer_from_logical, args)
buffer
Expand All @@ -361,13 +361,13 @@ pjrt_scalar.integer <- function(
data,
dtype = NULL,
device = NULL,
scan_na = FALSE,
check = FALSE,
...
) {
if (length(data) != 1) {
cli_abort("data must have length 1")
}
check_scan_na(data, scan_na)
check_input_na(data, check)
args <- convert_buffer_args(data, dtype, device, integer(), "i32", ...)
buffer <- do.call(impl_client_buffer_from_integer, args)
buffer
Expand All @@ -378,13 +378,13 @@ pjrt_scalar.numeric <- function(
data,
dtype = NULL,
device = NULL,
scan_na = FALSE,
check = FALSE,
...
) {
if (length(data) != 1) {
cli_abort("data must have length 1")
}
check_scan_na(data, scan_na)
check_input_na(data, check)
args <- convert_buffer_args(data, dtype, device, integer(), "f32", ...)
buffer <- do.call(impl_client_buffer_from_double, args)
buffer
Expand Down Expand Up @@ -440,26 +440,46 @@ elt_type <- function(x) {
#'
#' @param x ([`PJRTBuffer`][pjrt_buffer])\cr
#' Buffer to convert.
#' @param scan_na (`logical(1)`)\cr
#' If `TRUE` and the buffer dtype is one of the four integer dtypes that
#' round-trip through a signed R container (`i32` / `ui32` via `integer`,
#' `i64` / `ui64` via `bit64::integer64`), scan the materialized vector
#' for the reserved NA bit pattern (`INT_MIN` or `INT64_MIN`) and raise an
#' error if any are present. No-op for float, boolean, and small-integer
#' dtypes (which have no NA-collision risk).
#' @param check (`logical(1)`)\cr
#' If `TRUE`, sanity-check the materialized R vector against losing
#' information across the device-to-host boundary, and abort if any
#' problematic value is detected:
#' * **`i32` / `i64`**: any `NA` in the result. R's `NA_integer_` shares
#' the bit pattern `INT_MIN`; `bit64`'s `NA_integer64_` shares
#' `INT64_MIN`. A legitimate device value at those bit patterns is
#' indistinguishable from `NA` once materialized in R.
#' * **`ui64`**: any negative value in the result. `ui64` is stored as
#' `bit64::integer64` (signed 64-bit), which wraps values `>= 2^63`
#' to negative — exactly `2^63` becomes `NA_integer64_`, anything
#' above becomes a non-NA negative integer64.
#'
#' No-op for float, boolean, and small/unsigned-32 integer dtypes —
#' `ui32` is now stored as `integer64` and has full headroom, so it
#' cannot produce a wrapped or NA value.
#' @param ... Additional arguments (unused).
#' @return An R `array` (or `vector` for shape `integer()`).
#' @export
as_array.PJRTBuffer <- function(x, scan_na = FALSE, ...) {
as_array.PJRTBuffer <- function(x, check = FALSE, ...) {
result <- value(as_array_async(x))
assert_flag(scan_na)
if (scan_na) {
assert_flag(check)
if (check) {
dt <- as.character(elt_type(x))
if (dt %in% c("i32", "ui32", "i64", "ui64") && anyNA(result)) {
if (dt %in% c("i32", "i64") && anyNA(result)) {
cli_abort(c(
"Materialized {.cls {dt}} buffer contains a value that R cannot distinguish from {.val NA}.",
i = "{.val i32}/{.val ui32} reserve the bit pattern {.val -2147483648} ({.code INT_MIN}); {.val i64}/{.val ui64} reserve {.val -9223372036854775808} ({.code INT64_MIN}).",
i = "This collision is irrecoverable: the device value and {.val NA} are indistinguishable in R. Set {.code scan_na = FALSE} to skip this check."
i = "{.val i32} reserves the bit pattern {.val -2147483648} ({.code INT_MIN}); {.val i64} reserves {.val -9223372036854775808} ({.code INT64_MIN}).",
i = "Set {.code check = FALSE} to skip this check."
))
} else if (identical(dt, "ui64") && (anyNA(result) || any(result < 0, na.rm = TRUE))) {
# ui64 values >= 2^63 wrap when stored as signed int64 — exactly 2^63
# becomes NA_integer64_ (INT64_MIN); 2^63 + k becomes a non-NA negative
# int64. Either way, the unsigned magnitude was lost.
# (ui32 is now materialized as integer64 and has full headroom, so it
# cannot produce a negative value; no check needed.)
cli_abort(c(
"Materialized {.cls ui64} buffer contains a value `>= 2^63` that wrapped through R's signed {.cls integer64}.",
i = "Exactly {.code 2^63} becomes {.code NA_integer64_}; larger values become negative {.cls integer64}.",
i = "Set {.code check = FALSE} to skip this check."
))
}
}
Expand Down
27 changes: 19 additions & 8 deletions man/as_array.PJRTBuffer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

6 changes: 3 additions & 3 deletions man/pjrt_buffer.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

15 changes: 10 additions & 5 deletions src/pjrt.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -315,10 +315,12 @@ SEXP raw_to_array_impl(const uint8_t *raw_data,
void *out_data;

if (r_type == REALSXP) {
if constexpr (std::is_same_v<T, int64_t> || std::is_same_v<T, uint64_t>) {
// 64-bit integer -> "pseudo-double": REALSXP slots carry the int64 bit
// pattern (the storage layout used by bit64::integer64). The R caller
// attaches the integer64 class.
if constexpr (std::is_same_v<T, int64_t> || std::is_same_v<T, uint64_t> ||
std::is_same_v<T, uint32_t>) {
// Integer dtype that can't fit in R's signed int32 -> "pseudo-double":
// REALSXP slots carry the int64 bit pattern (the storage layout used by
// bit64::integer64). The R caller attaches the integer64 class. Widening
// uint32_t to int64_t is value-preserving (u32 max < 2^53).
static_assert(sizeof(double) == sizeof(int64_t),
"bit64::integer64 layout requires sizeof(double) == "
"sizeof(int64_t)");
Expand Down Expand Up @@ -689,7 +691,10 @@ SEXP impl_raw_to_array(Rcpp::XPtr<rpjrt::PJRTHostData> host_data,
} else if (dtype == "ui16") {
return raw_to_array_impl<uint16_t>(raw_data, dimensions, INTSXP);
} else if (dtype == "ui32") {
return raw_to_array_impl<uint32_t>(raw_data, dimensions, INTSXP);
// u32 -> integer64 storage: signed int32 has no headroom for ui32 values
// >= 2^31; widen to integer64 (53 bits of headroom) so the R caller gets
// the full unsigned magnitude. R-side classes the result as integer64.
return raw_to_array_impl<uint32_t>(raw_data, dimensions, REALSXP);
} else if (dtype == "ui64") {
return raw_to_array_impl<uint64_t>(raw_data, dimensions, REALSXP);
} else if (dtype == "pred") {
Expand Down
Loading
Loading