From c917f9b278cc57198a541c34331b72ee0ba5a3ee Mon Sep 17 00:00:00 2001 From: Sebastian Fischer Date: Tue, 28 Apr 2026 14:34:58 +0000 Subject: [PATCH] feat(op): add hlo_fft for forward/inverse Fourier transforms Implements stablehlo.fft per SPEC.md, supporting all four variants (FFT, IFFT, RFFT, IRFFT) with full constraint checking (C1-C5). Re-exports ComplexType from tengen so users can refer to complex (c64) and complex (c128) tensors. Depends on the matching ComplexType addition in tengen. Co-Authored-By: Claude Opus 4.7 (1M context) --- DESCRIPTION | 1 + NAMESPACE | 6 + NEWS.md | 4 + R/op-fft.R | 198 ++++++++++++++++++++++++++++++++ R/types.R | 8 ++ man/hlo_fft.Rd | 30 +++++ man/reexports.Rd | 3 +- tests/testthat/_snaps/op-fft.md | 110 ++++++++++++++++++ tests/testthat/test-op-fft.R | 152 ++++++++++++++++++++++++ 9 files changed, 511 insertions(+), 1 deletion(-) create mode 100644 R/op-fft.R create mode 100644 man/hlo_fft.Rd create mode 100644 tests/testthat/_snaps/op-fft.md create mode 100644 tests/testthat/test-op-fft.R diff --git a/DESCRIPTION b/DESCRIPTION index 4a0f0c1f..f1244002 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -91,6 +91,7 @@ Collate: 'op-erfc.R' 'op-exponential.R' 'op-exponential_minus_one.R' + 'op-fft.R' 'op-floor.R' 'op-gather.R' 'op-if.R' diff --git a/NAMESPACE b/NAMESPACE index 0901cede..6d3fb1a6 100644 --- a/NAMESPACE +++ b/NAMESPACE @@ -63,6 +63,7 @@ S3method(r_to_constant,logical) S3method(repr,"NULL") S3method(repr,BoolAttr) S3method(repr,BooleanType) +S3method(repr,ComplexType) S3method(repr,Constant) S3method(repr,ConstantAttr) S3method(repr,CustomOpBackendConfig) @@ -84,6 +85,7 @@ S3method(repr,OpCompare) S3method(repr,OpConstant) S3method(repr,OpCustomCall) S3method(repr,OpDotGeneral) +S3method(repr,OpFft) S3method(repr,OpGather) S3method(repr,OpInputAttrs) S3method(repr,OpInputFunc) @@ -132,6 +134,7 @@ export(.current_func) export(.current_module) export(BoolAttr) export(BooleanType) +export(ComplexType) export(Constant) export(ConstantAttr) export(CustomOpBackendConfig) @@ -222,6 +225,7 @@ export(hlo_erf_inv) export(hlo_erfc) export(hlo_exponential) export(hlo_exponential_minus_one) +export(hlo_fft) export(hlo_floor) export(hlo_func) export(hlo_gather) @@ -316,6 +320,7 @@ export(infer_types_erf_inv) export(infer_types_erfc) export(infer_types_exponential) export(infer_types_exponential_minus_one) +export(infer_types_fft) export(infer_types_float_biv) export(infer_types_float_uni) export(infer_types_floor) @@ -391,6 +396,7 @@ importFrom(cli,format_error) importFrom(methods,is) importFrom(stats,setNames) importFrom(tengen,BooleanType) +importFrom(tengen,ComplexType) importFrom(tengen,FloatType) importFrom(tengen,IntegerType) importFrom(tengen,UIntegerType) diff --git a/NEWS.md b/NEWS.md index 2e32e961..4a554906 100644 --- a/NEWS.md +++ b/NEWS.md @@ -2,6 +2,10 @@ ## Features +* Added `hlo_fft()` for forward and inverse Fourier transforms (`FFT`, `IFFT`, + `RFFT`, `IRFFT`). +* Re-export `ComplexType()` from tengen for representing `complex` / + `complex` (a.k.a. `c64` / `c128`) tensors. * Added support for CHLO ops, a higher-level companion dialect to stableHLO that is lowered to stableHLO during compilation. New ops: * Inverse trig: `hlo_acos()`, `hlo_asin()`, `hlo_atan()`. diff --git a/R/op-fft.R b/R/op-fft.R new file mode 100644 index 00000000..8fd0d22e --- /dev/null +++ b/R/op-fft.R @@ -0,0 +1,198 @@ +#' @include op.R hlo.R +NULL + +OpFft <- new_Op("OpFft", "fft") + +# Map fft_type to (operand_kind, result_kind) where each is "real" or "complex" +# Used by both inference and the public function for clarity. +.fft_type_kinds <- list( + FFT = list(operand = "complex", result = "complex"), + IFFT = list(operand = "complex", result = "complex"), + RFFT = list(operand = "real", result = "complex"), + IRFFT = list(operand = "complex", result = "real") +) + +.fft_real_to_complex <- function(real_dtype) { + ComplexType(real_dtype$value) +} + +.fft_complex_to_real <- function(complex_dtype) { + FloatType(complex_dtype$value) +} + +#' @rdname hlo_fft +#' @export +infer_types_fft <- function(operand, fft_type, fft_length) { + assert_vt_is_tensor(operand) + + if (!test_choice(fft_type, names(.fft_type_kinds))) { + cli_abort(c( + "{.arg fft_type} must be one of {.val {names(.fft_type_kinds)}}.", + x = "Got {.val {fft_type}}." + )) + } + kinds <- .fft_type_kinds[[fft_type]] + + assert_const(fft_length, dtype = IntegerType(64L), ndims = 1L) + fft_len <- as.integer(fft_length$data) + L <- length(fft_len) + + # (C3) 1 <= size(fft_length) <= 3. + if (L < 1L || L > 3L) { + cli_abort(c( + "{.arg fft_length} must have 1 to 3 elements.", + x = "Got {.val {L}}." + )) + } + + operand_dtype <- operand$type$dtype + operand_shape <- shape(operand) + rank <- length(operand_shape) + + # (C2) operand element type constraint. + if (kinds$operand == "real") { + if (!test_class(operand_dtype, "FloatType")) { + cli_abort(c( + "{.arg operand} must have a floating-point dtype for {.val {fft_type}}.", + x = "Got {.val {operand_dtype}}." + )) + } + result_dtype <- .fft_real_to_complex(operand_dtype) + } else { + if (!test_class(operand_dtype, "ComplexType")) { + cli_abort(c( + "{.arg operand} must have a complex dtype for {.val {fft_type}}.", + x = "Got {.val {operand_dtype}}." + )) + } + result_dtype <- if (kinds$result == "complex") { + operand_dtype + } else { + .fft_complex_to_real(operand_dtype) + } + } + + # (C1) size(fft_length) <= rank(operand). + if (L > rank) { + cli_abort(c( + "{.arg fft_length} length must not exceed rank of {.arg operand}.", + x = "Got fft_length of length {L} and operand of rank {rank}." + )) + } + + fft_axes <- seq.int(rank - L + 1L, rank) + + # (C4) For the tensor of floating-point type among operand/result, + # shape(real)[-L:] must equal fft_length. + if (kinds$operand == "real") { + real_tail <- operand_shape[fft_axes] + if (!identical(real_tail, fft_len)) { + cli_abort(c( + "{.arg operand} trailing {L} dimension{?s} must equal {.arg fft_length}.", + x = "Got operand shape {shapevec_repr(operand_shape)} and fft_length = {vec_repr(fft_len)}." + )) + } + } + + # (C5) shape(result) = shape(operand) except for special cases. + result_shape <- operand_shape + if (fft_type == "RFFT") { + last <- operand_shape[[rank]] + result_shape[[rank]] <- if (last == 0L) 0L else as.integer(last %/% 2L + 1L) + } else if (fft_type == "IRFFT") { + # (C4) shape(result)[-L:] must equal fft_length except for the last dim, + # where dim(operand, -1) = fft_length[L] / 2 + 1. + expected_last_operand <- if (fft_len[[L]] == 0L) { + 0L + } else { + as.integer(fft_len[[L]] %/% 2L + 1L) + } + operand_tail_lead <- operand_shape[fft_axes[seq_len(L - 1L)]] + fft_lead <- fft_len[seq_len(L - 1L)] + if ( + !identical(operand_tail_lead, fft_lead) || + operand_shape[[rank]] != expected_last_operand + ) { + cli_abort(c( + "{.arg operand} trailing dimensions must match {.arg fft_length}.", + x = paste0( + "Got operand shape {shapevec_repr(operand_shape)} and fft_length = ", + "{vec_repr(fft_len)}; ", + "expected last operand dim {.val {expected_last_operand}} ", + "(= fft_length[{L}] / 2 + 1)." + ) + )) + } + result_shape[fft_axes] <- fft_len + } + + ValueTypes(list( + ValueType( + TensorType( + dtype = result_dtype, + shape = Shape(result_shape) + ) + ) + )) +} + +hlo_fft_impl <- hlo_fn(OpFft, infer_types_fft) + +#' @templateVar mnemonic fft +#' @template op +#' @param operand (`FuncValue`)\cr +#' Tensor of floating-point or complex type. +#' @param fft_type (`character(1)`)\cr +#' One of `"FFT"`, `"IFFT"`, `"RFFT"`, `"IRFFT"`. +#' @param fft_length (`integer()`)\cr +#' Length 1, 2, or 3 vector of `i64` values giving the FFT lengths along +#' the trailing dimensions. +#' @export +hlo_fft <- function( + operand, + fft_type = c("FFT", "IFFT", "RFFT", "IRFFT"), + fft_length +) { + fft_type <- match.arg(fft_type) + fft_length <- as.integer(fft_length) + + hlo_fft_impl( + values = list(operand = operand), + attrs = list( + constant_attr( + "fft_length", + fft_length, + dtype = "i64", + shape = length(fft_length) + ) + ), + custom_attrs = list(fft_type = fft_type) + ) +} + +#' @export +repr.OpFft <- function(x, simplify_dense = TRUE, ...) { + attrs_repr <- vapply( + x$inputs$attrs, + function(a) repr(a, simplify_dense = simplify_dense), + character(1) + ) + fft_type_repr <- paste0( + "fft_type = #stablehlo" + ) + all_attrs <- paste(c(fft_type_repr, attrs_repr), collapse = ",\n") + + paste0( + repr(x$outputs), + " = ", + repr(x$name), + " (", + repr(x$inputs$values), + ") {\n", + all_attrs, + "\n}: ", + repr(x$signature) + ) +} diff --git a/R/types.R b/R/types.R index 2d2371ca..41f3cba1 100644 --- a/R/types.R +++ b/R/types.R @@ -13,6 +13,9 @@ tengen::UIntegerType #' @export tengen::FloatType +#' @export +tengen::ComplexType + #' @export tengen::is_dtype @@ -39,6 +42,11 @@ repr.FloatType <- function(x, ...) { as.character(x) } +#' @export +repr.ComplexType <- function(x, ...) { + paste0("complex") +} + # Re-export assert_dtype from tengen assert_dtype <- tengen::assert_dtype diff --git a/man/hlo_fft.Rd b/man/hlo_fft.Rd new file mode 100644 index 00000000..40a1599c --- /dev/null +++ b/man/hlo_fft.Rd @@ -0,0 +1,30 @@ +% Generated by roxygen2: do not edit by hand +% Please edit documentation in R/op-fft.R +\name{infer_types_fft} +\alias{infer_types_fft} +\alias{hlo_fft} +\title{Fft Operator} +\usage{ +infer_types_fft(operand, fft_type, fft_length) + +hlo_fft(operand, fft_type = c("FFT", "IFFT", "RFFT", "IRFFT"), fft_length) +} +\arguments{ +\item{operand, fft_type, fft_length}{(\code{\link{FuncValue}})\cr} + +\item{fft_type}{(\code{character(1)})\cr +One of \code{"FFT"}, \code{"IFFT"}, \code{"RFFT"}, \code{"IRFFT"}.} + +\item{fft_length}{(\code{integer()})\cr +Length 1, 2, or 3 vector of \code{i64} values giving the FFT lengths along +the trailing dimensions.} + +\item{operand}{(\code{FuncValue})\cr +Tensor of floating-point or complex type.} +} +\value{ +\code{\link{FuncValue}}\cr +} +\description{ +See \url{https://openxla.org/stablehlo/spec#fft} for details. +} diff --git a/man/reexports.Rd b/man/reexports.Rd index 3b8c5290..ae7c17d0 100644 --- a/man/reexports.Rd +++ b/man/reexports.Rd @@ -7,6 +7,7 @@ \alias{IntegerType} \alias{UIntegerType} \alias{FloatType} +\alias{ComplexType} \alias{is_dtype} \alias{as_dtype} \alias{shape} @@ -18,6 +19,6 @@ These objects are imported from other packages. Follow the links below to see their documentation. \describe{ - \item{tengen}{\code{\link[tengen]{as_dtype}}, \code{\link[tengen]{BooleanType}}, \code{\link[tengen]{dtype}}, \code{\link[tengen]{FloatType}}, \code{\link[tengen]{IntegerType}}, \code{\link[tengen]{is_dtype}}, \code{\link[tengen]{shape}}, \code{\link[tengen]{UIntegerType}}} + \item{tengen}{\code{\link[tengen]{as_dtype}}, \code{\link[tengen]{BooleanType}}, \code{\link[tengen]{ComplexType}}, \code{\link[tengen]{dtype}}, \code{\link[tengen]{FloatType}}, \code{\link[tengen]{IntegerType}}, \code{\link[tengen]{is_dtype}}, \code{\link[tengen]{shape}}, \code{\link[tengen]{UIntegerType}}} }} diff --git a/tests/testthat/_snaps/op-fft.md b/tests/testthat/_snaps/op-fft.md new file mode 100644 index 00000000..b292c984 --- /dev/null +++ b/tests/testthat/_snaps/op-fft.md @@ -0,0 +1,110 @@ +# FFT (complex -> complex) repr + + Code + repr(f) + Output + [1] "func.func @main (%x: tensor<2x4xcomplex>) -> tensor<2x4xcomplex> {\n%0 = \"stablehlo.fft\" (%x) {\nfft_type = #stablehlo,\nfft_length = array\n}: (tensor<2x4xcomplex>) -> (tensor<2x4xcomplex>)\nreturn %0 : tensor<2x4xcomplex>\n}\n" + +# IFFT (complex -> complex) repr + + Code + repr(f) + Output + [1] "func.func @main (%x: tensor<3x4x8xcomplex>) -> tensor<3x4x8xcomplex> {\n%0 = \"stablehlo.fft\" (%x) {\nfft_type = #stablehlo,\nfft_length = array\n}: (tensor<3x4x8xcomplex>) -> (tensor<3x4x8xcomplex>)\nreturn %0 : tensor<3x4x8xcomplex>\n}\n" + +# RFFT (real -> complex) repr and shape + + Code + repr(f) + Output + [1] "func.func @main (%x: tensor<2x4xf32>) -> tensor<2x3xcomplex> {\n%0 = \"stablehlo.fft\" (%x) {\nfft_type = #stablehlo,\nfft_length = array\n}: (tensor<2x4xf32>) -> (tensor<2x3xcomplex>)\nreturn %0 : tensor<2x3xcomplex>\n}\n" + +# IRFFT (complex -> real) repr and shape + + Code + repr(f) + Output + [1] "func.func @main (%x: tensor<2x3xcomplex>) -> tensor<2x4xf32> {\n%0 = \"stablehlo.fft\" (%x) {\nfft_type = #stablehlo,\nfft_length = array\n}: (tensor<2x3xcomplex>) -> (tensor<2x4xf32>)\nreturn %0 : tensor<2x4xf32>\n}\n" + +# errors + + Code + infer_types_fft(vt("c64", 4L), "INVALID", cnst(4L, "i64", 1L)) + Condition + Error in `infer_types_fft()`: + ! `fft_type` must be one of "FFT", "IFFT", "RFFT", and "IRFFT". + x Got "INVALID". + +--- + + Code + infer_types_fft(vt("f32", 4L), "FFT", cnst(4L, "i64", 1L)) + Condition + Error in `infer_types_fft()`: + ! `operand` must have a complex dtype for "FFT". + x Got f32. + +--- + + Code + infer_types_fft(vt("c64", 4L), "RFFT", cnst(4L, "i64", 1L)) + Condition + Error in `infer_types_fft()`: + ! `operand` must have a floating-point dtype for "RFFT". + x Got c64. + +--- + + Code + infer_types_fft(vt("f32", 4L), "IRFFT", cnst(4L, "i64", 1L)) + Condition + Error in `infer_types_fft()`: + ! `operand` must have a complex dtype for "IRFFT". + x Got f32. + +--- + + Code + infer_types_fft(vt("c64", c(2L, 2L, 2L, 4L)), "FFT", cnst(c(2L, 2L, 2L, 4L), + "i64", 4L)) + Condition + Error in `infer_types_fft()`: + ! `fft_length` must have 1 to 3 elements. + x Got 4. + +--- + + Code + infer_types_fft(vt("c64", 4L), "FFT", cnst(integer(), "i64", 0L)) + Condition + Error in `infer_types_fft()`: + ! `fft_length` must have 1 to 3 elements. + x Got 0. + +--- + + Code + infer_types_fft(vt("c64", 4L), "FFT", cnst(c(2L, 4L), "i64", 2L)) + Condition + Error in `infer_types_fft()`: + ! `fft_length` length must not exceed rank of `operand`. + x Got fft_length of length 2 and operand of rank 1. + +--- + + Code + infer_types_fft(vt("f32", c(2L, 4L)), "RFFT", cnst(8L, "i64", 1L)) + Condition + Error in `infer_types_fft()`: + ! `operand` trailing 1 dimension must equal `fft_length`. + x Got operand shape (2x4) and fft_length = 8. + +--- + + Code + infer_types_fft(vt("c64", c(2L, 4L)), "IRFFT", cnst(8L, "i64", 1L)) + Condition + Error in `infer_types_fft()`: + ! `operand` trailing dimensions must match `fft_length`. + x Got operand shape (2x4) and fft_length = 8; expected last operand dim 5 (= fft_length[1] / 2 + 1). + diff --git a/tests/testthat/test-op-fft.R b/tests/testthat/test-op-fft.R new file mode 100644 index 00000000..8867ad9b --- /dev/null +++ b/tests/testthat/test-op-fft.R @@ -0,0 +1,152 @@ +test_that("FFT (complex -> complex) repr", { + local_func() + x <- hlo_input("x", "c64", shape = c(2L, 4L)) + y <- hlo_fft(x, "FFT", fft_length = 4L) + f <- hlo_return(y) + expect_snapshot(repr(f)) +}) + +test_that("IFFT (complex -> complex) repr", { + local_func() + x <- hlo_input("x", "c128", shape = c(3L, 4L, 8L)) + y <- hlo_fft(x, "IFFT", fft_length = c(4L, 8L)) + f <- hlo_return(y) + expect_snapshot(repr(f)) +}) + +test_that("RFFT (real -> complex) repr and shape", { + local_func() + x <- hlo_input("x", "f32", shape = c(2L, 4L)) + y <- hlo_fft(x, "RFFT", fft_length = 4L) + expect_equal(shape(y), c(2L, 3L)) + expect_equal(dtype(y), ComplexType(32L)) + f <- hlo_return(y) + expect_snapshot(repr(f)) +}) + +test_that("IRFFT (complex -> real) repr and shape", { + local_func() + x <- hlo_input("x", "c64", shape = c(2L, 3L)) + y <- hlo_fft(x, "IRFFT", fft_length = 4L) + expect_equal(shape(y), c(2L, 4L)) + expect_equal(dtype(y), FloatType(32L)) + f <- hlo_return(y) + expect_snapshot(repr(f)) +}) + +test_that("RFFT -> IRFFT round-trip executes", { + skip_if_not_installed("pjrt") + local_func() + x <- hlo_input("x", "f32", shape = 8L) + y <- hlo_fft(x, "RFFT", fft_length = 8L) + z <- hlo_fft(y, "IRFFT", fft_length = 8L) + f <- hlo_return(z) + + program <- pjrt::pjrt_program(repr(f)) + exe <- pjrt::pjrt_compile(program) + input <- array(as.numeric(seq_len(8)), dim = 8L) + inp_buf <- pjrt::pjrt_buffer(input, dtype = "f32") + out_buf <- pjrt::pjrt_execute(exe, inp_buf) + expect_equal(shape(out_buf), 8L) + expect_equal(pjrt::as_array(out_buf), input, tolerance = 1e-4) +}) + +test_that("multi-dim RFFT collapses last dim", { + local_func() + x <- hlo_input("x", "f64", shape = c(3L, 4L, 8L)) + y <- hlo_fft(x, "RFFT", fft_length = c(4L, 8L)) + expect_equal(shape(y), c(3L, 4L, 5L)) + expect_equal(dtype(y), ComplexType(64L)) +}) + +test_that("errors", { + # invalid fft_type (C2) + expect_snapshot( + infer_types_fft( + vt("c64", 4L), + "INVALID", + cnst(4L, "i64", 1L) + ), + error = TRUE + ) + + # FFT requires complex operand (C2) + expect_snapshot( + infer_types_fft( + vt("f32", 4L), + "FFT", + cnst(4L, "i64", 1L) + ), + error = TRUE + ) + + # RFFT requires float operand (C2) + expect_snapshot( + infer_types_fft( + vt("c64", 4L), + "RFFT", + cnst(4L, "i64", 1L) + ), + error = TRUE + ) + + # IRFFT requires complex operand (C2) + expect_snapshot( + infer_types_fft( + vt("f32", 4L), + "IRFFT", + cnst(4L, "i64", 1L) + ), + error = TRUE + ) + + # fft_length too long (C3) + expect_snapshot( + infer_types_fft( + vt("c64", c(2L, 2L, 2L, 4L)), + "FFT", + cnst(c(2L, 2L, 2L, 4L), "i64", 4L) + ), + error = TRUE + ) + + # fft_length empty (C3) + expect_snapshot( + infer_types_fft( + vt("c64", 4L), + "FFT", + cnst(integer(), "i64", 0L) + ), + error = TRUE + ) + + # fft_length longer than rank (C1) + expect_snapshot( + infer_types_fft( + vt("c64", 4L), + "FFT", + cnst(c(2L, 4L), "i64", 2L) + ), + error = TRUE + ) + + # RFFT operand trailing dims must equal fft_length (C4) + expect_snapshot( + infer_types_fft( + vt("f32", c(2L, 4L)), + "RFFT", + cnst(8L, "i64", 1L) + ), + error = TRUE + ) + + # IRFFT operand last dim must be fft_length / 2 + 1 (C4, C5) + expect_snapshot( + infer_types_fft( + vt("c64", c(2L, 4L)), + "IRFFT", + cnst(8L, "i64", 1L) + ), + error = TRUE + ) +})