Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ Collate:
'op-subtract.R'
'op-tan.R'
'op-tanh.R'
'op-top_k.R'
'op-transpose.R'
'op-triangular_solve.R'
'op-while.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,7 @@ export(hlo_subtract)
export(hlo_tan)
export(hlo_tanh)
export(hlo_tensor)
export(hlo_top_k)
export(hlo_transpose)
export(hlo_triangular_solve)
export(hlo_while)
Expand Down Expand Up @@ -372,6 +373,7 @@ export(infer_types_square)
export(infer_types_subtract)
export(infer_types_tan)
export(infer_types_tanh)
export(infer_types_top_k)
export(infer_types_transpose)
export(infer_types_triangular_solve)
export(infer_types_while)
Expand Down
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
* Error / Bessel / misc: `hlo_erf()`, `hlo_erfc()`, `hlo_erf_inv()`,
`hlo_bessel_i1e()`, `hlo_square()`.
* Float predicates: `hlo_is_inf()`, `hlo_is_pos_inf()`, `hlo_is_neg_inf()`.
* Selection: `hlo_top_k()` returning the top-k values and their indices
along the last dimension.
* `OpName()` and `new_Op()` gain a `dialect` argument (default `"stablehlo"`)
to support ops from other MLIR dialects.

Expand Down
81 changes: 81 additions & 0 deletions R/op-top_k.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
#' @include op.R hlo.R
NULL

OpTopK <- new_Op("OpTopK", "top_k", dialect = "chlo")

#' @rdname hlo_top_k
#' @export
infer_types_top_k <- function(operand, k) {
assert_vt_is_tensor(operand)
assert_vt_has_ttype(
operand,
"FloatType",
"IntegerType",
"UIntegerType"
)
assert_const(k, dtype = IntegerType(64L), shape = integer())
k <- k$data

operand_shape <- shape(operand)
rank <- length(operand_shape)

if (rank < 1L) {
cli_abort(c(
"{.arg operand} must have rank >= 1.",
x = "Got rank {.val {rank}}."
))
}
if (k < 1L) {
cli_abort(c(
"{.arg k} must be a positive integer.",
x = "Got {.val {k}}."
))
}
last_dim <- operand_shape[[rank]]
if (k > last_dim) {
cli_abort(c(
"{.arg k} must not exceed the size of the last dimension of {.arg operand}.",
x = "Got k = {.val {k}} and last dimension size {.val {last_dim}}."
))
}

result_shape <- operand_shape
result_shape[[rank]] <- as.integer(k)

values_type <- ValueType(
TensorType(dtype = operand$type$dtype, shape = Shape(result_shape))
)
indices_type <- ValueType(
TensorType(dtype = IntegerType(32L), shape = Shape(result_shape))
)

ValueTypes(list(values_type, indices_type))
}

hlo_top_k_impl <- hlo_fn(OpTopK, infer_types_top_k)

#' @templateVar mnemonic top_k
#' @templateVar not_func_variables k
#' @template op_chlo
#' @param operand ([`FuncValue`])\cr
#' Tensor of integer, unsigned integer, or floating-point type with rank >= 1.
#' @param k (`integer(1)`)\cr
#' Number of top elements to return along the last dimension. Must satisfy
#' `1 <= k <= dim(operand, -1)`.
#' @return A `list()` of two [`FuncValue`]s: the top-k values (same dtype as
#' `operand`) and their indices into the last dimension (dtype `i32`). Ties
#' are broken by lower index first.
#' @export
hlo_top_k <- function(operand, k) {
hlo_top_k_impl(
values = list(operand = operand),
attrs = list(
ScalarAttr(
name = "k",
value = as.integer(k),
dtype = IntegerType(64L)
)
),
simplify = FALSE
)
}
32 changes: 32 additions & 0 deletions man/hlo_top_k.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

57 changes: 57 additions & 0 deletions tests/testthat/_snaps/op-top_k.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
# basic tests

Code
repr(fv)
Output
[1] "func.func @main (%x: tensor<2x5xf32>) -> tensor<2x3xf32> {\n%0, %1 = \"chlo.top_k\" (%x) {\nk = 3 : i64\n}: (tensor<2x5xf32>) -> (tensor<2x3xf32>, tensor<2x3xi32>)\nreturn %0 : tensor<2x3xf32>\n}\n"

---

Code
repr(fi)
Output
[1] "func.func @main (%x: tensor<2x5xf32>) -> tensor<2x3xi32> {\n%0, %1 = \"chlo.top_k\" (%x) {\nk = 3 : i64\n}: (tensor<2x5xf32>) -> (tensor<2x3xf32>, tensor<2x3xi32>)\nreturn %1 : tensor<2x3xi32>\n}\n"

# works on rank-1 input

Code
repr(fv1)
Output
[1] "func.func @main (%x: tensor<5xf32>) -> tensor<3xf32> {\n%0, %1 = \"chlo.top_k\" (%x) {\nk = 3 : i64\n}: (tensor<5xf32>) -> (tensor<3xf32>, tensor<3xi32>)\nreturn %0 : tensor<3xf32>\n}\n"

# errors

Code
infer_types_top_k(vt("f32", integer()), k = scnst(1L, "i64"))
Condition
Error in `infer_types_top_k()`:
! `operand` must have rank >= 1.
x Got rank 0.

---

Code
infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(5L, "i64"))
Condition
Error in `infer_types_top_k()`:
! `k` must not exceed the size of the last dimension of `operand`.
x Got k = 5 and last dimension size 3.

---

Code
infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(0L, "i64"))
Condition
Error in `infer_types_top_k()`:
! `k` must be a positive integer.
x Got 0.

---

Code
infer_types_top_k(vt("pred", c(2L, 3L)), k = scnst(1L, "i64"))
Condition
Error in `infer_types_top_k()`:
! `operand` must have dtype FloatType, IntegerType, or UIntegerType.
x Got bool.

107 changes: 107 additions & 0 deletions tests/testthat/test-op-top_k.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
test_that("basic tests", {
func_values <- local_func()
xv <- hlo_input("x", "f32", shape = c(2L, 5L))
resv <- hlo_top_k(xv, k = 3L)
fv <- hlo_return(resv[[1L]], func = func_values)

func_indices <- local_func()
xi <- hlo_input("x", "f32", shape = c(2L, 5L))
resi <- hlo_top_k(xi, k = 3L)
fi <- hlo_return(resi[[2L]], func = func_indices)

expect_snapshot(repr(fv))
expect_snapshot(repr(fi))

skip_if_not_installed("pjrt")
exec_v <- pjrt::pjrt_compile(pjrt::pjrt_program(repr(fv)))
exec_i <- pjrt::pjrt_compile(pjrt::pjrt_program(repr(fi)))

data <- matrix(c(5, 1, 3, 2, 4, 9, 7, 8, 6, 10), nrow = 2L, byrow = TRUE)
buf <- pjrt::pjrt_buffer(data, dtype = "f32")

expect_equal(
pjrt::pjrt_execute(exec_v, buf),
pjrt::pjrt_buffer(
matrix(c(5, 4, 3, 10, 9, 8), nrow = 2L, byrow = TRUE),
dtype = "f32"
)
)
expect_equal(
pjrt::pjrt_execute(exec_i, buf),
pjrt::pjrt_buffer(
matrix(c(0L, 4L, 2L, 4L, 0L, 2L), nrow = 2L, byrow = TRUE),
dtype = "i32"
)
)
})

test_that("works on rank-1 input", {
func1 <- local_func()
x1 <- hlo_input("x", "f32", shape = 5L)
res1 <- hlo_top_k(x1, k = 3L)
fv1 <- hlo_return(res1[[1L]], func = func1)

func2 <- local_func()
x2 <- hlo_input("x", "f32", shape = 5L)
res2 <- hlo_top_k(x2, k = 3L)
fi1 <- hlo_return(res2[[2L]], func = func2)

expect_snapshot(repr(fv1))

skip_if_not_installed("pjrt")
exec_v <- pjrt::pjrt_compile(pjrt::pjrt_program(repr(fv1)))
exec_i <- pjrt::pjrt_compile(pjrt::pjrt_program(repr(fi1)))
buf <- pjrt::pjrt_buffer(c(5, 1, 3, 2, 4), dtype = "f32")

expect_equal(
pjrt::pjrt_execute(exec_v, buf),
pjrt::pjrt_buffer(c(5, 4, 3), dtype = "f32")
)
expect_equal(
pjrt::pjrt_execute(exec_i, buf),
pjrt::pjrt_buffer(c(0L, 4L, 2L), dtype = "i32")
)
})

test_that("output types and shapes", {
# 1-D input
vt_out <- infer_types_top_k(vt("f32", 8L), k = scnst(3L, "i64"))
expect_length(vt_out, 2L)
expect_equal(shape(vt_out[[1L]]), 3L)
expect_equal(shape(vt_out[[2L]]), 3L)
expect_equal(vt_out[[1L]]$type$dtype, FloatType(32L))
expect_equal(vt_out[[2L]]$type$dtype, IntegerType(32L))

# higher-rank input — only last dim changes
vt_out <- infer_types_top_k(vt("i64", c(2L, 3L, 7L)), k = scnst(2L, "i64"))
expect_equal(shape(vt_out[[1L]]), c(2L, 3L, 2L))
expect_equal(shape(vt_out[[2L]]), c(2L, 3L, 2L))
expect_equal(vt_out[[1L]]$type$dtype, IntegerType(64L))
expect_equal(vt_out[[2L]]$type$dtype, IntegerType(32L))
})

test_that("errors", {
# rank 0 operand
expect_snapshot(
infer_types_top_k(vt("f32", integer()), k = scnst(1L, "i64")),
error = TRUE
)

# k > last dim
expect_snapshot(
infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(5L, "i64")),
error = TRUE
)

# k = 0
expect_snapshot(
infer_types_top_k(vt("f32", c(2L, 3L)), k = scnst(0L, "i64")),
error = TRUE
)

# unsupported dtype (boolean)
expect_snapshot(
infer_types_top_k(vt("pred", c(2L, 3L)), k = scnst(1L, "i64")),
error = TRUE
)
})
Loading