From 4a96e58b999bc4764e8cbabd82cd87c9d5b9bcb2 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 29 Apr 2026 07:25:45 +0000 Subject: [PATCH 1/3] add top_k --- DESCRIPTION | 1 + NAMESPACE | 2 + NEWS.md | 2 + R/op-top_k.R | 81 +++++++++++++++++++++++++++++++ man/hlo_top_k.Rd | 32 ++++++++++++ tests/testthat/_snaps/op-top_k.md | 50 +++++++++++++++++++ tests/testthat/test-op-top_k.R | 80 ++++++++++++++++++++++++++++++ 7 files changed, 248 insertions(+) create mode 100644 R/op-top_k.R create mode 100644 man/hlo_top_k.Rd create mode 100644 tests/testthat/_snaps/op-top_k.md create mode 100644 tests/testthat/test-op-top_k.R diff --git a/DESCRIPTION b/DESCRIPTION index 4a0f0c1f..e1acab6c 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -138,6 +138,7 @@ Collate: 'op-subtract.R' 'op-tan.R' 'op-tanh.R' + 'op-top_k.R' 'op-transpose.R' 'op-triangular_solve.R' 'op-while.R' diff --git a/NAMESPACE b/NAMESPACE index 0901cede..210d6023 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -274,6 +274,7 @@ export(hlo_subtract) export(hlo_tan) export(hlo_tanh) export(hlo_tensor) +export(hlo_top_k) export(hlo_transpose) export(hlo_triangular_solve) export(hlo_while) @@ -372,6 +373,7 @@ export(infer_types_square) export(infer_types_subtract) export(infer_types_tan) export(infer_types_tanh) +export(infer_types_top_k) export(infer_types_transpose) export(infer_types_triangular_solve) export(infer_types_while) diff --git a/NEWS.md b/NEWS.md index 2e32e961..d5beb669 100644 --- a/NEWS.md +++ b/NEWS.md @@ -11,6 +11,8 @@ * Error / Bessel / misc: `hlo_erf()`, `hlo_erfc()`, `hlo_erf_inv()`, `hlo_bessel_i1e()`, `hlo_square()`. * Float predicates: `hlo_is_inf()`, `hlo_is_pos_inf()`, `hlo_is_neg_inf()`. + * Selection: `hlo_top_k()` returning the top-k values and their indices + along the last dimension. * `OpName()` and `new_Op()` gain a `dialect` argument (default `"stablehlo"`) to support ops from other MLIR dialects. 
diff --git a/R/op-top_k.R b/R/op-top_k.R new file mode 100644 index 00000000..d4aff605 --- /dev/null +++ b/R/op-top_k.R @@ -0,0 +1,81 @@ +#' @include op.R hlo.R +NULL + +OpTopK <- new_Op("OpTopK", "top_k", dialect = "chlo") + +#' @rdname hlo_top_k +#' @export +infer_types_top_k <- function(operand, k) { + assert_vt_is_tensor(operand) + assert_vt_has_ttype( + operand, + "FloatType", + "IntegerType", + "UIntegerType" + ) + assert_const(k, dtype = IntegerType(64L), shape = integer()) + k <- k$data + + operand_shape <- shape(operand) + rank <- length(operand_shape) + + if (rank < 1L) { + cli_abort(c( + "{.arg operand} must have rank >= 1.", + x = "Got rank {.val {rank}}." + )) + } + if (k < 0L) { + cli_abort(c( + "{.arg k} must be non-negative.", + x = "Got {.val {k}}." + )) + } + last_dim <- operand_shape[[rank]] + if (k > last_dim) { + cli_abort(c( + "{.arg k} must not exceed the size of the last dimension of {.arg operand}.", + x = "Got k = {.val {k}} and last dimension size {.val {last_dim}}." + )) + } + + result_shape <- operand_shape + result_shape[[rank]] <- as.integer(k) + + values_type <- ValueType( + TensorType(dtype = operand$type$dtype, shape = Shape(result_shape)) + ) + indices_type <- ValueType( + TensorType(dtype = IntegerType(32L), shape = Shape(result_shape)) + ) + + ValueTypes(list(values_type, indices_type)) +} + +hlo_top_k_impl <- hlo_fn(OpTopK, infer_types_top_k) + +#' @templateVar mnemonic top_k +#' @templateVar not_func_variables k +#' @template op_chlo +#' @param operand ([`FuncValue`])\cr +#' Tensor of integer, unsigned integer, or floating-point type with rank >= 1. +#' @param k (`integer(1)`)\cr +#' Number of top elements to return along the last dimension. Must satisfy +#' `0 <= k <= dim(operand, -1)`. +#' @return A `list()` of two [`FuncValue`]s: the top-k values (same dtype as +#' `operand`) and their indices into the last dimension (dtype `i32`). Ties +#' are broken by lower index first. 
+#' @export +hlo_top_k <- function(operand, k) { + hlo_top_k_impl( + values = list(operand = operand), + attrs = list( + ScalarAttr( + name = "k", + value = as.integer(k), + dtype = IntegerType(64L) + ) + ), + simplify = FALSE + ) +} diff --git a/man/hlo_top_k.Rd b/man/hlo_top_k.Rd new file mode 100644 index 00000000..bf984ecb --- /dev/null +++ b/man/hlo_top_k.Rd @@ -0,0 +1,32 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/op-top_k.R +\name{infer_types_top_k} +\alias{infer_types_top_k} +\alias{hlo_top_k} +\title{TopK Operator (CHLO)} +\usage{ +infer_types_top_k(operand, k) + +hlo_top_k(operand, k) +} +\arguments{ +\item{operand}{(\code{\link{FuncValue}})\cr +Tensor of integer, unsigned integer, or floating-point type with rank >= 1.} + +\item{k}{(\code{integer(1)})\cr +Number of top elements to return along the last dimension. Must satisfy +\verb{0 <= k <= dim(operand, -1)}.} +} +\value{ +\code{\link{FuncValue}}\cr + +A \code{list()} of two \code{\link{FuncValue}}s: the top-k values (same dtype as +\code{operand}) and their indices into the last dimension (dtype \code{i32}). Ties +are broken by lower index first. +} +\description{ +This op is from the CHLO dialect, a higher-level companion to stableHLO +that is lowered to stableHLO during compilation. See +\url{https://openxla.org/stablehlo/generated/chlo#chlotop_k_chlotop_kop} +for details. 
+} diff --git a/tests/testthat/_snaps/op-top_k.md b/tests/testthat/_snaps/op-top_k.md new file mode 100644 index 00000000..840d03a6 --- /dev/null +++ b/tests/testthat/_snaps/op-top_k.md @@ -0,0 +1,50 @@ +# basic tests + + Code + repr(fv) + Output + [1] "func.func @main (%x: tensor<2x5xf32>) -> tensor<2x3xf32> {\n%0, %1 = \"chlo.top_k\" (%x) {\nk = 3 : i64\n}: (tensor<2x5xf32>) -> (tensor<2x3xf32>, tensor<2x3xi32>)\nreturn %0 : tensor<2x3xf32>\n}\n" + +--- + + Code + repr(fi) + Output + [1] "func.func @main (%x: tensor<2x5xf32>) -> tensor<2x3xi32> {\n%0, %1 = \"chlo.top_k\" (%x) {\nk = 3 : i64\n}: (tensor<2x5xf32>) -> (tensor<2x3xf32>, tensor<2x3xi32>)\nreturn %1 : tensor<2x3xi32>\n}\n" + +# errors + + Code + infer_types_top_k(vt("f32", integer()), k = scnst(1L, "i64")) + Condition + Error in `infer_types_top_k()`: + ! `operand` must have rank >= 1. + x Got rank 0. + +--- + + Code + infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(5L, "i64")) + Condition + Error in `infer_types_top_k()`: + ! `k` must not exceed the size of the last dimension of `operand`. + x Got k = 5 and last dimension size 3. + +--- + + Code + infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(-1L, "i64")) + Condition + Error in `infer_types_top_k()`: + ! `k` must be non-negative. + x Got -1. + +--- + + Code + infer_types_top_k(vt("pred", c(2L, 3L)), k = scnst(1L, "i64")) + Condition + Error in `infer_types_top_k()`: + ! `operand` must have dtype FloatType, IntegerType, or UIntegerType. + x Got bool. 
+ diff --git a/tests/testthat/test-op-top_k.R b/tests/testthat/test-op-top_k.R new file mode 100644 index 00000000..003fd289 --- /dev/null +++ b/tests/testthat/test-op-top_k.R @@ -0,0 +1,80 @@ +test_that("basic tests", { + func_values <- local_func() + xv <- hlo_input("x", "f32", shape = c(2L, 5L)) + resv <- hlo_top_k(xv, k = 3L) + fv <- hlo_return(resv[[1L]], func = func_values) + + func_indices <- local_func() + xi <- hlo_input("x", "f32", shape = c(2L, 5L)) + resi <- hlo_top_k(xi, k = 3L) + fi <- hlo_return(resi[[2L]], func = func_indices) + + expect_snapshot(repr(fv)) + expect_snapshot(repr(fi)) + + skip_if_not_installed("pjrt") + exec_v <- pjrt::pjrt_compile(pjrt::pjrt_program(repr(fv))) + exec_i <- pjrt::pjrt_compile(pjrt::pjrt_program(repr(fi))) + + data <- matrix(c(5, 1, 3, 2, 4, 9, 7, 8, 6, 10), nrow = 2L, byrow = TRUE) + buf <- pjrt::pjrt_buffer(data, dtype = "f32") + + out_v <- pjrt::pjrt_execute(exec_v, buf) + expect_equal( + pjrt::as_array(out_v), + matrix(c(5, 4, 3, 10, 9, 8), nrow = 2L, byrow = TRUE) + ) + + out_i <- pjrt::pjrt_execute(exec_i, buf) + expect_equal( + pjrt::as_array(out_i), + matrix(c(0L, 4L, 2L, 4L, 0L, 2L), nrow = 2L, byrow = TRUE) + ) +}) + +test_that("output types and shapes", { + # 1-D input + vt_out <- infer_types_top_k(vt("f32", 8L), k = scnst(3L, "i64")) + expect_length(vt_out, 2L) + expect_equal(shape(vt_out[[1L]]), 3L) + expect_equal(shape(vt_out[[2L]]), 3L) + expect_equal(vt_out[[1L]]$type$dtype, FloatType(32L)) + expect_equal(vt_out[[2L]]$type$dtype, IntegerType(32L)) + + # higher-rank input — only last dim changes + vt_out <- infer_types_top_k(vt("i64", c(2L, 3L, 7L)), k = scnst(2L, "i64")) + expect_equal(shape(vt_out[[1L]]), c(2L, 3L, 2L)) + expect_equal(shape(vt_out[[2L]]), c(2L, 3L, 2L)) + expect_equal(vt_out[[1L]]$type$dtype, IntegerType(64L)) + expect_equal(vt_out[[2L]]$type$dtype, IntegerType(32L)) + + # k = 0 is allowed + vt_out <- infer_types_top_k(vt("f64", 5L), k = scnst(0L, "i64")) + 
expect_equal(shape(vt_out[[1L]]), 0L) +}) + +test_that("errors", { + # rank 0 operand + expect_snapshot( + infer_types_top_k(vt("f32", integer()), k = scnst(1L, "i64")), + error = TRUE + ) + + # k > last dim + expect_snapshot( + infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(5L, "i64")), + error = TRUE + ) + + # negative k + expect_snapshot( + infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(-1L, "i64")), + error = TRUE + ) + + # unsupported dtype (boolean) + expect_snapshot( + infer_types_top_k(vt("pred", c(2L, 3L)), k = scnst(1L, "i64")), + error = TRUE + ) +}) From 061cf62ff24b4b21b1cbcf5857fce3d6eb9caff0 Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 29 Apr 2026 07:49:19 +0000 Subject: [PATCH 2/3] add end-to-end tests for rank-1 top_k input --- tests/testthat/_snaps/op-top_k.md | 7 +++++++ tests/testthat/test-op-top_k.R | 28 ++++++++++++++++++++++++++++ 2 files changed, 35 insertions(+) diff --git a/tests/testthat/_snaps/op-top_k.md b/tests/testthat/_snaps/op-top_k.md index 840d03a6..05937553 100644 --- a/tests/testthat/_snaps/op-top_k.md +++ b/tests/testthat/_snaps/op-top_k.md @@ -12,6 +12,13 @@ Output [1] "func.func @main (%x: tensor<2x5xf32>) -> tensor<2x3xi32> {\n%0, %1 = \"chlo.top_k\" (%x) {\nk = 3 : i64\n}: (tensor<2x5xf32>) -> (tensor<2x3xf32>, tensor<2x3xi32>)\nreturn %1 : tensor<2x3xi32>\n}\n" +# works on rank-1 input + + Code + repr(fv1) + Output + [1] "func.func @main (%x: tensor<5xf32>) -> tensor<3xf32> {\n%0, %1 = \"chlo.top_k\" (%x) {\nk = 3 : i64\n}: (tensor<5xf32>) -> (tensor<3xf32>, tensor<3xi32>)\nreturn %0 : tensor<3xf32>\n}\n" + # errors Code diff --git a/tests/testthat/test-op-top_k.R b/tests/testthat/test-op-top_k.R index 003fd289..c27975c5 100644 --- a/tests/testthat/test-op-top_k.R +++ b/tests/testthat/test-op-top_k.R @@ -32,6 +32,34 @@ test_that("basic tests", { ) }) +test_that("works on rank-1 input", { + func1 <- local_func() + x1 <- hlo_input("x", "f32", shape = 5L) + res1 <- hlo_top_k(x1, k = 3L) + fv1 <- hlo_return(res1[[1L]], func = 
func1) + + func2 <- local_func() + x2 <- hlo_input("x", "f32", shape = 5L) + res2 <- hlo_top_k(x2, k = 3L) + fi1 <- hlo_return(res2[[2L]], func = func2) + + expect_snapshot(repr(fv1)) + + skip_if_not_installed("pjrt") + exec_v <- pjrt::pjrt_compile(pjrt::pjrt_program(repr(fv1))) + exec_i <- pjrt::pjrt_compile(pjrt::pjrt_program(repr(fi1))) + buf <- pjrt::pjrt_buffer(c(5, 1, 3, 2, 4), dtype = "f32") + + expect_equal( + pjrt::as_array(pjrt::pjrt_execute(exec_v, buf)), + array(c(5, 4, 3), dim = 3L) + ) + expect_equal( + pjrt::as_array(pjrt::pjrt_execute(exec_i, buf)), + array(c(0L, 4L, 2L), dim = 3L) + ) +}) + test_that("output types and shapes", { # 1-D input vt_out <- infer_types_top_k(vt("f32", 8L), k = scnst(3L, "i64")) From dad951b7b5dc83412f62056f8e11c72d2dab170a Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Wed, 29 Apr 2026 08:12:09 +0000 Subject: [PATCH 3/3] require k >= 1 in hlo_top_k A top-0 query is degenerate and almost always indicates a user bug, so reject it at inference time rather than producing empty tensors. Co-Authored-By: Claude Opus 4.7 (1M context) --- R/op-top_k.R | 6 +++--- man/hlo_top_k.Rd | 2 +- tests/testthat/_snaps/op-top_k.md | 6 +++--- tests/testthat/test-op-top_k.R | 33 +++++++++++++++---------------- 4 files changed, 23 insertions(+), 24 deletions(-) diff --git a/R/op-top_k.R b/R/op-top_k.R index d4aff605..349062bd 100644 --- a/R/op-top_k.R +++ b/R/op-top_k.R @@ -25,9 +25,9 @@ infer_types_top_k <- function(operand, k) { x = "Got rank {.val {rank}}." )) } - if (k < 0L) { + if (k < 1L) { cli_abort(c( - "{.arg k} must be non-negative.", + "{.arg k} must be a positive integer.", x = "Got {.val {k}}." )) } @@ -61,7 +61,7 @@ hlo_top_k_impl <- hlo_fn(OpTopK, infer_types_top_k) #' Tensor of integer, unsigned integer, or floating-point type with rank >= 1. #' @param k (`integer(1)`)\cr #' Number of top elements to return along the last dimension. Must satisfy -#' `0 <= k <= dim(operand, -1)`. +#' `1 <= k <= dim(operand, -1)`. 
#' @return A `list()` of two [`FuncValue`]s: the top-k values (same dtype as #' `operand`) and their indices into the last dimension (dtype `i32`). Ties #' are broken by lower index first. diff --git a/man/hlo_top_k.Rd b/man/hlo_top_k.Rd index bf984ecb..b46d2616 100644 --- a/man/hlo_top_k.Rd +++ b/man/hlo_top_k.Rd @@ -15,7 +15,7 @@ Tensor of integer, unsigned integer, or floating-point type with rank >= 1.} \item{k}{(\code{integer(1)})\cr Number of top elements to return along the last dimension. Must satisfy -\verb{0 <= k <= dim(operand, -1)}.} +\verb{1 <= k <= dim(operand, -1)}.} } \value{ \code{\link{FuncValue}}\cr diff --git a/tests/testthat/_snaps/op-top_k.md b/tests/testthat/_snaps/op-top_k.md index 05937553..a31b1049 100644 --- a/tests/testthat/_snaps/op-top_k.md +++ b/tests/testthat/_snaps/op-top_k.md @@ -40,11 +40,11 @@ --- Code - infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(-1L, "i64")) + infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(0L, "i64")) Condition Error in `infer_types_top_k()`: - ! `k` must be non-negative. - x Got -1. + ! `k` must be a positive integer. + x Got 0. 
--- diff --git a/tests/testthat/test-op-top_k.R b/tests/testthat/test-op-top_k.R index c27975c5..01714f18 100644 --- a/tests/testthat/test-op-top_k.R +++ b/tests/testthat/test-op-top_k.R @@ -19,16 +19,19 @@ test_that("basic tests", { data <- matrix(c(5, 1, 3, 2, 4, 9, 7, 8, 6, 10), nrow = 2L, byrow = TRUE) buf <- pjrt::pjrt_buffer(data, dtype = "f32") - out_v <- pjrt::pjrt_execute(exec_v, buf) expect_equal( - pjrt::as_array(out_v), - matrix(c(5, 4, 3, 10, 9, 8), nrow = 2L, byrow = TRUE) + pjrt::pjrt_execute(exec_v, buf), + pjrt::pjrt_buffer( + matrix(c(5, 4, 3, 10, 9, 8), nrow = 2L, byrow = TRUE), + dtype = "f32" + ) ) - - out_i <- pjrt::pjrt_execute(exec_i, buf) expect_equal( - pjrt::as_array(out_i), - matrix(c(0L, 4L, 2L, 4L, 0L, 2L), nrow = 2L, byrow = TRUE) + pjrt::pjrt_execute(exec_i, buf), + pjrt::pjrt_buffer( + matrix(c(0L, 4L, 2L, 4L, 0L, 2L), nrow = 2L, byrow = TRUE), + dtype = "i32" + ) ) }) @@ -51,12 +54,12 @@ test_that("works on rank-1 input", { buf <- pjrt::pjrt_buffer(c(5, 1, 3, 2, 4), dtype = "f32") expect_equal( - pjrt::as_array(pjrt::pjrt_execute(exec_v, buf)), - array(c(5, 4, 3), dim = 3L) + pjrt::pjrt_execute(exec_v, buf), + pjrt::pjrt_buffer(c(5, 4, 3), dtype = "f32") ) expect_equal( - pjrt::as_array(pjrt::pjrt_execute(exec_i, buf)), - array(c(0L, 4L, 2L), dim = 3L) + pjrt::pjrt_execute(exec_i, buf), + pjrt::pjrt_buffer(c(0L, 4L, 2L), dtype = "i32") ) }) @@ -75,10 +78,6 @@ test_that("output types and shapes", { expect_equal(shape(vt_out[[2L]]), c(2L, 3L, 2L)) expect_equal(vt_out[[1L]]$type$dtype, IntegerType(64L)) expect_equal(vt_out[[2L]]$type$dtype, IntegerType(32L)) - - # k = 0 is allowed - vt_out <- infer_types_top_k(vt("f64", 5L), k = scnst(0L, "i64")) - expect_equal(shape(vt_out[[1L]]), 0L) }) test_that("errors", { @@ -94,9 +93,9 @@ test_that("errors", { error = TRUE ) - # negative k + # k = 0 expect_snapshot( - infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(-1L, "i64")), + infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(0L, "i64")), 
error = TRUE )