diff --git a/R/reverse.R b/R/reverse.R index 924c5448..39b6c853 100644 --- a/R/reverse.R +++ b/R/reverse.R @@ -173,6 +173,18 @@ compute_requirements <- function(graph, wrt) { # `graph$inputs`; filter them out. `gradient()` already rejects static # `wrt` entries, so this drop is safe. is_static <- graph$is_static_flat + if (!is.null(is_static) && any(requires_grad_all & is_static)) { + flat_argnames <- rep( + graph$in_tree$names, + times = vapply(graph$in_tree$nodes, tree_size, integer(1L)) + ) + bad <- unique(flat_argnames[requires_grad_all & is_static]) + cli_abort(c( + "Cannot compute gradient with respect to {.arg {bad}}.", + x = "{cli::qty(length(bad))}{?It was/They were} passed as {?a plain R value/plain R values}", + i = "{cli::qty(length(bad))}Pass {?it/them} as an {.cls AnvlArray}." + )) + } requires_grad <- if (is.null(is_static)) { requires_grad_all } else { diff --git a/tests/testthat/_snaps/reverse.md b/tests/testthat/_snaps/reverse.md index 8a04fe0a..eae5584e 100644 --- a/tests/testthat/_snaps/reverse.md +++ b/tests/testthat/_snaps/reverse.md @@ -47,3 +47,23 @@ ! Can only compute gradient with respect to float arrays. x Got i32 +# wrt arg passed as plain R literal errors clearly + + Code + jit(function() gradient(nv_log, wrt = "operand")(1))() + Condition + Error in `compute_requirements()`: + ! Cannot compute gradient with respect to `operand`. + x It was passed as a plain R value + i Pass it as an . + +--- + + Code + jit(function() gradient(function(x, y) prim_add(x, y))(1, 2))() + Condition + Error in `compute_requirements()`: + ! Cannot compute gradient with respect to `x` and `y`. + x They were passed as plain R values + i Pass them as an . + diff --git a/tests/testthat/test-reverse.R b/tests/testthat/test-reverse.R index 481ad184..d9ad38a3 100644 --- a/tests/testthat/test-reverse.R +++ b/tests/testthat/test-reverse.R @@ -266,6 +266,15 @@ test_that("can only compute gradient w.r.t. float arrays", { }) }) +test_that("wrt arg passed as plain R literal errors clearly", { + expect_snapshot(error = TRUE, { + jit(function() gradient(nv_log, wrt = "operand")(1))() + }) + expect_snapshot(error = TRUE, { + jit(function() gradient(function(x, y) prim_add(x, y))(1, 2))() + }) +}) + test_that("can differentiate through integer/bool functions", { f <- function(x) { x1 <- nv_convert(x, "i32")