Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,14 @@
`rbind()` / `cbind()` generics.
* New API function `nv_flatten()` for flattening to 1-D.

### NA scanning

* `nv_array()`, `nv_scalar()`, `as_array()`, and the `as.integer()` /
`as.double()` / `as.logical()` / `as.vector()` methods for
`AnvlArray` gained a `check` argument that opts into scanning for
`NA` values during host -> device and device -> host transfers. See
the "Gotchas" vignette.

### Misc

* New `AnvlArray` -> R `vector` converters: `as.numeric()`,
Expand Down
2 changes: 1 addition & 1 deletion R/api.R
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,7 @@ nv_broadcast_scalars <- function(...) {

target_shape <- non_scalar_shapes[[1L]]
if (!all(vapply(non_scalar_shapes, identical, logical(1L), target_shape))) {
shapes <- paste0(sapply(shapes, shape2string), sep = ", ")
shapes <- paste0(sapply(shapes, shape2string), collapse = ", ")
cli_abort(
"All non-scalar arrays must have the same shape, but got {shapes}. Use {.fn nv_broadcast_arrays} for general broadcasting." # nolint
)
Expand Down
65 changes: 54 additions & 11 deletions R/array.R
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,12 @@
#' default column-major order, mirroring [`base::matrix()`]'s `byrow`.
#' Only allowed when `data` is an R object — passing an existing
#' `AnvlArray` together with `byrow = TRUE` is an error.
#' @param check (`logical(1)`)\cr
#' If `TRUE`, error when `data` contains any `NA` values. XLA has no
#' representation for missing values, so they are otherwise silently
#' coerced to the closest available value of the target dtype (e.g. `NaN`
#' for floats, the bit pattern `-2147483648` for `i32`, `TRUE` for
#' `bool`). Defaults to `FALSE`. See the "Gotchas" vignette.
#' @return ([`AnvlArray`])
#' @examplesIf pjrt::plugins_downloaded()
#' # A 1-d array (vector) with shape (4). Default type for integers is `i32`
Expand Down Expand Up @@ -91,8 +97,25 @@ NULL

#' @rdname AnvlArray
#' @export
nv_array <- function(data, dtype = NULL, device = NULL, shape = NULL, ambiguous = NULL, backend = NULL, byrow = FALSE) {
nv_array <- function(
data,
dtype = NULL,
device = NULL,
shape = NULL,
ambiguous = NULL,
backend = NULL,
byrow = FALSE,
check = FALSE
) {
assert_flag(byrow)
assert_flag(check)
if (check && !is_anvl_array(data) && anyNA(data)) {
n_na <- sum(is.na(data))
cli_abort(c(
"Input {.arg data} contains {n_na} {.val NA} value{?s}, which {?has/have} no representation at the XLA level.",
i = "Replace or drop missing values before transferring, or set {.code check = FALSE} to skip this check."
))
}
if (is_anvl_array(data)) {
if (byrow) {
cli_abort("{.arg byrow} only applies when constructing an {.cls AnvlArray} from an R object.")
Expand Down Expand Up @@ -262,8 +285,16 @@ unwrap_if_array <- function(x) {

#' @rdname AnvlArray
#' @export
nv_scalar <- function(data, dtype = NULL, device = NULL, ambiguous = NULL, backend = NULL) {
nv_array(data, dtype = dtype, device = device, shape = integer(), ambiguous = ambiguous, backend = backend)
nv_scalar <- function(data, dtype = NULL, device = NULL, ambiguous = NULL, backend = NULL, check = FALSE) {
nv_array(
data,
dtype = dtype,
device = device,
shape = integer(),
ambiguous = ambiguous,
backend = backend,
check = check
)
}

infer_matrix_dim <- function(n, other, given) {
Expand Down Expand Up @@ -393,9 +424,19 @@ shape.AnvlArray <- function(x, ...) {
globals$backends[[x$backend]]$shape(x)
}

#' @rdname as_array
#' @param check (`logical(1)`)\cr
#' If `TRUE`, sanity-check the materialized R vector against losing
#' information across the device-to-host boundary, and abort if any
#' problematic value is detected. Forwarded to the backend; for the
#' `xla` backend the relevant cases are `i32`/`i64` values colliding
#' with the `NA` bit pattern and `ui64` values `>= 2^63` wrapping
#' through `bit64::integer64`. See [`pjrt::as_array.PJRTBuffer()`] for
#' the full list. Defaults to `FALSE`. See the "Gotchas" vignette.
#' @export
as_array.AnvlArray <- function(x, ...) {
globals$backends[[x$backend]]$as_array(x)
as_array.AnvlArray <- function(x, check = FALSE, ...) {
assert_flag(check)
globals$backends[[x$backend]]$as_array(x, check = check)
}

#' @export
Expand Down Expand Up @@ -427,6 +468,8 @@ await.AnvlArray <- function(x, ...) {
#' @param mode (`character(1)`)\cr
#' For `as.vector()` only. See [base::as.vector()]. Defaults to `"any"`,
#' meaning the natural R type for the array's dtype.
#' @param check (`logical(1)`)\cr
#' Forwarded to [`as_array()`]; see there for details.
#' @param ... Unused.
#' @return An R vector of the corresponding type (`double`, `integer`, or `logical`).
#' @examplesIf pjrt::plugins_downloaded()
Expand All @@ -441,33 +484,33 @@ NULL
#' @rdname as-AnvlArray
#' @method as.double AnvlArray
#' @export
as.double.AnvlArray <- function(x, ...) {
as.double.AnvlArray <- function(x, check = FALSE, ...) {
dt <- dtype(x)
if (!(inherits(dt, "FloatType") || inherits(dt, "IntegerType") || inherits(dt, "UIntegerType"))) {
cli_abort("{.fn as.double} requires a float or integer dtype, but got {.val {as.character(dt)}}.")
}
as.double(as_array(x))
as.double(as_array(x, check = check))
}

#' @rdname as-AnvlArray
#' @method as.integer AnvlArray
#' @export
as.integer.AnvlArray <- function(x, ...) {
as.integer.AnvlArray <- function(x, check = FALSE, ...) {
dt <- dtype(x)
if (!(inherits(dt, "IntegerType") || inherits(dt, "UIntegerType"))) {
cli_abort("{.fn as.integer} requires a (signed or unsigned) integer dtype, but got {.val {as.character(dt)}}.")
}
as.integer(as_array(x))
as.integer(as_array(x, check = check))
}

#' @rdname as-AnvlArray
#' @method as.logical AnvlArray
#' @export
as.logical.AnvlArray <- function(x, ...) {
as.logical.AnvlArray <- function(x, check = FALSE, ...) {
if (!inherits(dtype(x), "BooleanType")) {
cli_abort("{.fn as.logical} requires a {.val bool} dtype, but got {.val {as.character(dtype(x))}}.")
}
as.logical(as_array(x))
as.logical(as_array(x, check = check))
}

#' @rdname as-AnvlArray
Expand Down
2 changes: 1 addition & 1 deletion R/backend-quickr.R
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ AnvlBackendQuickr <- function() {
dtype = function(x) x$dtype,
shape = function(x) x$shape,
ambiguous = function(x) x$ambiguous,
as_array = function(x) x$data,
as_array = function(x, check) x$data,
as_raw = function(x, row_major) as.raw(x$data),
platform = function(x) "cpu",
device = function(x) quickr_device("cpu"),
Expand Down
2 changes: 1 addition & 1 deletion R/backend-xla.R
Original file line number Diff line number Diff line change
Expand Up @@ -319,7 +319,7 @@ AnvlBackendXla <- function() {
dtype = function(x) tengen::dtype(x$data),
shape = function(x) tengen::shape(x$data),
ambiguous = function(x) x$ambiguous,
as_array = function(x) tengen::as_array(x$data),
as_array = function(x, check) tengen::as_array(x$data, check = check),
as_raw = function(x, row_major) tengen::as_raw(x$data, row_major = row_major),
platform = function(x) pjrt::platform(x$data),
device = function(x) device(x$data),
Expand Down
7 changes: 5 additions & 2 deletions R/backend.R
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,10 @@
#' @param dtype (`function`)\cr Extracts the dtype from an AnvlArray.
#' @param shape (`function`)\cr Extracts the shape from an AnvlArray.
#' @param ambiguous (`function`)\cr Extracts the ambiguous flag from an AnvlArray.
#' @param as_array (`function`)\cr Converts an AnvlArray to an R array.
#' @param as_array (`function(x, check)`)\cr Converts an AnvlArray to an R
#' array. The `check` flag is forwarded from [`as_array()`]; backends may use
#' it to abort when materialization would lose information (e.g. ui64 values
#' wrapping through `bit64::integer64`). See [`pjrt::as_array.PJRTBuffer()`].
#' @param as_raw (`function`)\cr Converts an AnvlArray to raw bytes.
#' @param platform (`function`)\cr Returns the platform name (e.g. `"cpu"`).
#' @param device (`function`)\cr Returns the device object for an AnvlArray.
Expand Down Expand Up @@ -141,7 +144,7 @@ register_backend(
dtype = function(x) x$dtype,
shape = function(x) x$shape,
ambiguous = function(x) x$ambiguous,
as_array = function(x) x$data,
as_array = function(x, check) x$data,
as_raw = function(x, row_major) cli_abort("as_raw not supported for plain backend"),
platform = function(x) "cpu",
device = function(x) PlainDeviceCpu(),
Expand Down
19 changes: 17 additions & 2 deletions man/AnvlArray.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 4 additions & 1 deletion man/AnvlBackend.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 6 additions & 3 deletions man/as-AnvlArray.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 14 additions & 2 deletions man/as_array.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions pkgdown/_pkgdown.yml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,8 @@ navbar:
href: articles/random-numbers.html
- text: Type Promotion
href: articles/type-promotion.html
- text: Gotchas
href: articles/gotchas.html
- text: Efficiency
href: articles/efficiency.html
- text: FAQ
Expand Down
Loading
Loading