Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions R/reverse.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,18 @@ compute_requirements <- function(graph, wrt) {
# `graph$inputs`; filter them out. `gradient()` already rejects static
# `wrt` entries, so this drop is safe.
is_static <- graph$is_static_flat
if (!is.null(is_static) && any(requires_grad_all & is_static)) {
flat_argnames <- rep(
graph$in_tree$names,
times = vapply(graph$in_tree$nodes, tree_size, integer(1L))
)
bad <- unique(flat_argnames[requires_grad_all & is_static])
cli_abort(c(
"Cannot compute gradient with respect to {.arg {bad}}.",
x = "{cli::qty(length(bad))}{?It was/They were} passed as {?a plain R value/plain R values}",
i = "{cli::qty(length(bad))}Pass {?it/them} as an {.cls AnvlArray}."
))
}
requires_grad <- if (is.null(is_static)) {
requires_grad_all
} else {
Expand Down
20 changes: 20 additions & 0 deletions tests/testthat/_snaps/reverse.md
Original file line number Diff line number Diff line change
Expand Up @@ -47,3 +47,23 @@
! Can only compute gradient with respect to float arrays.
x Got i32

# wrt arg passed as plain R literal errors clearly

Code
jit(function() gradient(nv_log, wrt = "operand")(1))()
Condition
Error in `compute_requirements()`:
! Cannot compute gradient with respect to `operand`.
x It was passed as a plain R value
i Pass it as an <AnvlArray>.

---

Code
jit(function() gradient(function(x, y) prim_add(x, y))(1, 2))()
Condition
Error in `compute_requirements()`:
! Cannot compute gradient with respect to `x` and `y`.
x They were passed as plain R values
i Pass them as an <AnvlArray>.

9 changes: 9 additions & 0 deletions tests/testthat/test-reverse.R
Original file line number Diff line number Diff line change
Expand Up @@ -266,6 +266,15 @@ test_that("can only compute gradient w.r.t. float arrays", {
})
})

test_that("wrt arg passed as plain R literal errors clearly", {
expect_snapshot(error = TRUE, {
jit(function() gradient(nv_log, wrt = "operand")(1))()
})
expect_snapshot(error = TRUE, {
jit(function() gradient(function(x, y) prim_add(x, y))(1, 2))()
})
})

test_that("can differentiate through integer/bool functions", {
f <- function(x) {
x1 <- nv_convert(x, "i32")
Expand Down
Loading