From 8e9561835e11c779a0bef99beafa6f349dcb3327 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Fri, 1 May 2026 07:12:12 +0000 Subject: [PATCH 1/3] fix: clearer error when `wrt` arg is a plain R literal When a `wrt` argument is passed as a plain R numeric (e.g. `gradient(nv_log, wrt = "operand")(1)`), it is inlined as a constant during tracing rather than registered as a graph input. The gradient transform then produced zero outputs while `out_tree` still expected one leaf, surfacing as a cryptic `subscript out of bounds` from `unflatten.LeafNode`. Detect this in `compute_requirements()` and abort with a message naming the offending argument(s) and pointing to `nv_array()`. Co-Authored-By: Claude Opus 4.7 (1M context) --- R/reverse.R | 12 ++++++++++++ tests/testthat/_snaps/reverse.md | 20 ++++++++++++++++++++ tests/testthat/test-reverse.R | 9 +++++++++ 3 files changed, 41 insertions(+) diff --git a/R/reverse.R b/R/reverse.R index cb2075e21..404241f70 100644 --- a/R/reverse.R +++ b/R/reverse.R @@ -173,6 +173,18 @@ compute_requirements <- function(graph, wrt) { # `graph$inputs`; filter them out. `gradient()` already rejects static # `wrt` entries, so this drop is safe. is_static <- graph$is_static_flat + if (!is.null(is_static) && any(requires_grad_all & is_static)) { + flat_argnames <- rep( + graph$in_tree$names, + times = vapply(graph$in_tree$nodes, tree_size, integer(1L)) + ) + bad <- unique(flat_argnames[requires_grad_all & is_static]) + cli_abort(c( + "Cannot compute gradient with respect to {.arg {bad}}.", + x = "{cli::qty(length(bad))}{?It was/They were} passed as {?a plain R value/plain R values} and embedded as {?a constant/constants} in the traced graph, so {?it has/they have} no graph input to differentiate.", + i = "{cli::qty(length(bad))}Pass {?it/them} as an {.cls AnvlArray}, e.g. {.code nv_array(1, dtype = \"f32\")}." 
+ )) + } requires_grad <- if (is.null(is_static)) { requires_grad_all } else { diff --git a/tests/testthat/_snaps/reverse.md b/tests/testthat/_snaps/reverse.md index 8a04fe0aa..8f2a7d39a 100644 --- a/tests/testthat/_snaps/reverse.md +++ b/tests/testthat/_snaps/reverse.md @@ -47,3 +47,23 @@ ! Can only compute gradient with respect to float arrays. x Got i32 +# wrt arg passed as plain R literal errors clearly + + Code + jit(function() gradient(nv_log, wrt = "operand")(1))() + Condition + Error in `compute_requirements()`: + ! Cannot compute gradient with respect to `operand`. + x It was passed as a plain R value and embedded as a constant in the traced graph, so it has no graph input to differentiate. + i Pass it as an <AnvlArray>, e.g. `nv_array(1, dtype = "f32")`. + +--- + + Code + jit(function() gradient(function(x, y) prim_add(x, y))(1, 2))() + Condition + Error in `compute_requirements()`: + ! Cannot compute gradient with respect to `x` and `y`. + x They were passed as plain R values and embedded as constants in the traced graph, so they have no graph input to differentiate. + i Pass them as an <AnvlArray>, e.g. `nv_array(1, dtype = "f32")`. + diff --git a/tests/testthat/test-reverse.R b/tests/testthat/test-reverse.R index 481ad1845..d9ad38a34 100644 --- a/tests/testthat/test-reverse.R +++ b/tests/testthat/test-reverse.R @@ -266,6 +266,15 @@ test_that("can only compute gradient w.r.t. 
float arrays", { }) }) +test_that("wrt arg passed as plain R literal errors clearly", { + expect_snapshot(error = TRUE, { + jit(function() gradient(nv_log, wrt = "operand")(1))() + }) + expect_snapshot(error = TRUE, { + jit(function() gradient(function(x, y) prim_add(x, y))(1, 2))() + }) +}) + test_that("can differentiate through integer/bool functions", { f <- function(x) { x1 <- nv_convert(x, "i32") From 8f3e649733db0095ee8190623bdec5aa3c0aff21 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 18 May 2026 19:27:41 +0200 Subject: [PATCH 2/3] better error message --- R/reverse.R | 4 +- tests/testthat/_snaps/reverse.new.md | 69 ++++++++++++++++++++++++++++ 2 files changed, 71 insertions(+), 2 deletions(-) create mode 100644 tests/testthat/_snaps/reverse.new.md diff --git a/R/reverse.R b/R/reverse.R index dbfd905e2..39b6c8536 100644 --- a/R/reverse.R +++ b/R/reverse.R @@ -181,8 +181,8 @@ compute_requirements <- function(graph, wrt) { bad <- unique(flat_argnames[requires_grad_all & is_static]) cli_abort(c( "Cannot compute gradient with respect to {.arg {bad}}.", - x = "{cli::qty(length(bad))}{?It was/They were} passed as {?a plain R value/plain R values} and embedded as {?a constant/constants} in the traced graph, so {?it has/they have} no graph input to differentiate.", - i = "{cli::qty(length(bad))}Pass {?it/them} as an {.cls AnvlArray}, e.g. {.code nv_array(1, dtype = \"f32\")}." + x = "{cli::qty(length(bad))}{?It was/They were} passed as {?a plain R value/plain R values}", + i = "{cli::qty(length(bad))}Pass {?it/them} as an {.cls AnvlArray}." 
)) } requires_grad <- if (is.null(is_static)) { diff --git a/tests/testthat/_snaps/reverse.new.md b/tests/testthat/_snaps/reverse.new.md new file mode 100644 index 000000000..eae5584e2 --- /dev/null +++ b/tests/testthat/_snaps/reverse.new.md @@ -0,0 +1,69 @@ +# wrt for non-array input: gradient + + Code + g <- gradient(nv_round, wrt = "method") + g(nv_scalar(1), method = "nearest_even") + Condition + Error in `check_wrt_arrayish()`: + ! Cannot compute gradient with respect to non-array argument. + x Got + +# wrt for non-array input: value_and_gradient + + Code + g <- value_and_gradient(nv_round, wrt = "method") + g(nv_scalar(1), method = "nearest_even") + Condition + Error in `check_wrt_arrayish()`: + ! Cannot compute gradient with respect to non-array argument. + x Got + +# wrt for nested non-array input: gradient + + Code + g <- gradient(f, wrt = "x") + g(x = list(nv_scalar(1), 2L)) + Condition + Error in `check_wrt_arrayish()`: + ! Can only compute gradient with respect to float arrays. + x Got i32 + +# wrt for nested non-array input: value_and_gradient + + Code + g <- value_and_gradient(f, wrt = "x") + g(x = list(nv_scalar(1), 2L)) + Condition + Error in `check_wrt_arrayish()`: + ! Can only compute gradient with respect to float arrays. + x Got i32 + +# can only compute gradient w.r.t. float arrays + + Code + gradient(nv_floor, wrt = "operand")(nv_scalar(1L)) + Condition + Error in `check_wrt_arrayish()`: + ! Can only compute gradient with respect to float arrays. + x Got i32 + +# wrt arg passed as plain R literal errors clearly + + Code + jit(function() gradient(nv_log, wrt = "operand")(1))() + Condition + Error in `compute_requirements()`: + ! Cannot compute gradient with respect to `operand`. + x It was passed as a plain R value + i Pass it as an <AnvlArray>. + +--- + + Code + jit(function() gradient(function(x, y) prim_add(x, y))(1, 2))() + Condition + Error in `compute_requirements()`: + ! Cannot compute gradient with respect to `x` and `y`. 
+ x They were passed as plain R values + i Pass them as an <AnvlArray>. + From 54bb962b8e4487ddd160b63ed1df1bf8c88c01a0 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Mon, 18 May 2026 19:46:25 +0200 Subject: [PATCH 3/3] snapshots accept --- tests/testthat/_snaps/reverse.md | 8 ++-- tests/testthat/_snaps/reverse.new.md | 69 ---------------------------- 2 files changed, 4 insertions(+), 73 deletions(-) delete mode 100644 tests/testthat/_snaps/reverse.new.md diff --git a/tests/testthat/_snaps/reverse.md b/tests/testthat/_snaps/reverse.md index 8f2a7d39a..eae5584e2 100644 --- a/tests/testthat/_snaps/reverse.md +++ b/tests/testthat/_snaps/reverse.md @@ -54,8 +54,8 @@ Condition Error in `compute_requirements()`: ! Cannot compute gradient with respect to `operand`. - x It was passed as a plain R value and embedded as a constant in the traced graph, so it has no graph input to differentiate. - i Pass it as an <AnvlArray>, e.g. `nv_array(1, dtype = "f32")`. + x It was passed as a plain R value + i Pass it as an <AnvlArray>. --- @@ -64,6 +64,6 @@ Condition Error in `compute_requirements()`: ! Cannot compute gradient with respect to `x` and `y`. - x They were passed as plain R values and embedded as constants in the traced graph, so they have no graph input to differentiate. - i Pass them as an <AnvlArray>, e.g. `nv_array(1, dtype = "f32")`. + x They were passed as plain R values + i Pass them as an <AnvlArray>. diff --git a/tests/testthat/_snaps/reverse.new.md b/tests/testthat/_snaps/reverse.new.md deleted file mode 100644 index eae5584e2..000000000 --- a/tests/testthat/_snaps/reverse.new.md +++ /dev/null @@ -1,69 +0,0 @@ -# wrt for non-array input: gradient - - Code - g <- gradient(nv_round, wrt = "method") - g(nv_scalar(1), method = "nearest_even") - Condition - Error in `check_wrt_arrayish()`: - ! Cannot compute gradient with respect to non-array argument. 
- x Got - -# wrt for non-array input: value_and_gradient - - Code - g <- value_and_gradient(nv_round, wrt = "method") - g(nv_scalar(1), method = "nearest_even") - Condition - Error in `check_wrt_arrayish()`: - ! Cannot compute gradient with respect to non-array argument. - x Got - -# wrt for nested non-array input: gradient - - Code - g <- gradient(f, wrt = "x") - g(x = list(nv_scalar(1), 2L)) - Condition - Error in `check_wrt_arrayish()`: - ! Can only compute gradient with respect to float arrays. - x Got i32 - -# wrt for nested non-array input: value_and_gradient - - Code - g <- value_and_gradient(f, wrt = "x") - g(x = list(nv_scalar(1), 2L)) - Condition - Error in `check_wrt_arrayish()`: - ! Can only compute gradient with respect to float arrays. - x Got i32 - -# can only compute gradient w.r.t. float arrays - - Code - gradient(nv_floor, wrt = "operand")(nv_scalar(1L)) - Condition - Error in `check_wrt_arrayish()`: - ! Can only compute gradient with respect to float arrays. - x Got i32 - -# wrt arg passed as plain R literal errors clearly - - Code - jit(function() gradient(nv_log, wrt = "operand")(1))() - Condition - Error in `compute_requirements()`: - ! Cannot compute gradient with respect to `operand`. - x It was passed as a plain R value - i Pass it as an <AnvlArray>. - ---- - - Code - jit(function() gradient(function(x, y) prim_add(x, y))(1, 2))() - Condition - Error in `compute_requirements()`: - ! Cannot compute gradient with respect to `x` and `y`. - x They were passed as plain R values - i Pass them as an <AnvlArray>. -