Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@ Collate:
'op-erfc.R'
'op-exponential.R'
'op-exponential_minus_one.R'
'op-fft.R'
'op-floor.R'
'op-gather.R'
'op-if.R'
Expand Down
6 changes: 6 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ S3method(r_to_constant,logical)
S3method(repr,"NULL")
S3method(repr,BoolAttr)
S3method(repr,BooleanType)
S3method(repr,ComplexType)
S3method(repr,Constant)
S3method(repr,ConstantAttr)
S3method(repr,CustomOpBackendConfig)
Expand All @@ -84,6 +85,7 @@ S3method(repr,OpCompare)
S3method(repr,OpConstant)
S3method(repr,OpCustomCall)
S3method(repr,OpDotGeneral)
S3method(repr,OpFft)
S3method(repr,OpGather)
S3method(repr,OpInputAttrs)
S3method(repr,OpInputFunc)
Expand Down Expand Up @@ -132,6 +134,7 @@ export(.current_func)
export(.current_module)
export(BoolAttr)
export(BooleanType)
export(ComplexType)
export(Constant)
export(ConstantAttr)
export(CustomOpBackendConfig)
Expand Down Expand Up @@ -222,6 +225,7 @@ export(hlo_erf_inv)
export(hlo_erfc)
export(hlo_exponential)
export(hlo_exponential_minus_one)
export(hlo_fft)
export(hlo_floor)
export(hlo_func)
export(hlo_gather)
Expand Down Expand Up @@ -316,6 +320,7 @@ export(infer_types_erf_inv)
export(infer_types_erfc)
export(infer_types_exponential)
export(infer_types_exponential_minus_one)
export(infer_types_fft)
export(infer_types_float_biv)
export(infer_types_float_uni)
export(infer_types_floor)
Expand Down Expand Up @@ -391,6 +396,7 @@ importFrom(cli,format_error)
importFrom(methods,is)
importFrom(stats,setNames)
importFrom(tengen,BooleanType)
importFrom(tengen,ComplexType)
importFrom(tengen,FloatType)
importFrom(tengen,IntegerType)
importFrom(tengen,UIntegerType)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,10 @@

## Features

* Added `hlo_fft()` for forward and inverse Fourier transforms (`FFT`, `IFFT`,
`RFFT`, `IRFFT`).
* Re-export `ComplexType()` from tengen for representing `complex<f32>` /
`complex<f64>` (a.k.a. `c64` / `c128`) tensors.
* Added support for CHLO ops, a higher-level companion dialect to stableHLO
that is lowered to stableHLO during compilation. New ops:
* Inverse trig: `hlo_acos()`, `hlo_asin()`, `hlo_atan()`.
Expand Down
198 changes: 198 additions & 0 deletions R/op-fft.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,198 @@
#' @include op.R hlo.R
NULL

OpFft <- new_Op("OpFft", "fft")

# Map fft_type to (operand_kind, result_kind) where each is "real" or "complex"
# Used by both inference and the public function for clarity.
.fft_type_kinds <- list(
FFT = list(operand = "complex", result = "complex"),
IFFT = list(operand = "complex", result = "complex"),
RFFT = list(operand = "real", result = "complex"),
IRFFT = list(operand = "complex", result = "real")
)

.fft_real_to_complex <- function(real_dtype) {
ComplexType(real_dtype$value)
}

.fft_complex_to_real <- function(complex_dtype) {
FloatType(complex_dtype$value)
}

#' @rdname hlo_fft
#' @export
infer_types_fft <- function(operand, fft_type, fft_length) {
assert_vt_is_tensor(operand)

if (!test_choice(fft_type, names(.fft_type_kinds))) {
cli_abort(c(
"{.arg fft_type} must be one of {.val {names(.fft_type_kinds)}}.",
x = "Got {.val {fft_type}}."
))
}
kinds <- .fft_type_kinds[[fft_type]]

assert_const(fft_length, dtype = IntegerType(64L), ndims = 1L)
fft_len <- as.integer(fft_length$data)
L <- length(fft_len)

# (C3) 1 <= size(fft_length) <= 3.
if (L < 1L || L > 3L) {
cli_abort(c(
"{.arg fft_length} must have 1 to 3 elements.",
x = "Got {.val {L}}."
))
}

operand_dtype <- operand$type$dtype
operand_shape <- shape(operand)
rank <- length(operand_shape)

# (C2) operand element type constraint.
if (kinds$operand == "real") {
if (!test_class(operand_dtype, "FloatType")) {
cli_abort(c(
"{.arg operand} must have a floating-point dtype for {.val {fft_type}}.",
x = "Got {.val {operand_dtype}}."
))
}
result_dtype <- .fft_real_to_complex(operand_dtype)
} else {
if (!test_class(operand_dtype, "ComplexType")) {
cli_abort(c(
"{.arg operand} must have a complex dtype for {.val {fft_type}}.",
x = "Got {.val {operand_dtype}}."
))
}
result_dtype <- if (kinds$result == "complex") {
operand_dtype
} else {
.fft_complex_to_real(operand_dtype)
}
}

# (C1) size(fft_length) <= rank(operand).
if (L > rank) {
cli_abort(c(
"{.arg fft_length} length must not exceed rank of {.arg operand}.",
x = "Got fft_length of length {L} and operand of rank {rank}."
))
}

fft_axes <- seq.int(rank - L + 1L, rank)

# (C4) For the tensor of floating-point type among operand/result,
# shape(real)[-L:] must equal fft_length.
if (kinds$operand == "real") {
real_tail <- operand_shape[fft_axes]
if (!identical(real_tail, fft_len)) {
cli_abort(c(
"{.arg operand} trailing {L} dimension{?s} must equal {.arg fft_length}.",
x = "Got operand shape {shapevec_repr(operand_shape)} and fft_length = {vec_repr(fft_len)}."
))
}
}

# (C5) shape(result) = shape(operand) except for special cases.
result_shape <- operand_shape
if (fft_type == "RFFT") {
last <- operand_shape[[rank]]
result_shape[[rank]] <- if (last == 0L) 0L else as.integer(last %/% 2L + 1L)
} else if (fft_type == "IRFFT") {
# (C4) shape(result)[-L:] must equal fft_length except for the last dim,
# where dim(operand, -1) = fft_length[L] / 2 + 1.
expected_last_operand <- if (fft_len[[L]] == 0L) {
0L
} else {
as.integer(fft_len[[L]] %/% 2L + 1L)
}
operand_tail_lead <- operand_shape[fft_axes[seq_len(L - 1L)]]
fft_lead <- fft_len[seq_len(L - 1L)]
if (
!identical(operand_tail_lead, fft_lead) ||
operand_shape[[rank]] != expected_last_operand
) {
cli_abort(c(
"{.arg operand} trailing dimensions must match {.arg fft_length}.",
x = paste0(
"Got operand shape {shapevec_repr(operand_shape)} and fft_length = ",
"{vec_repr(fft_len)}; ",
"expected last operand dim {.val {expected_last_operand}} ",
"(= fft_length[{L}] / 2 + 1)."
)
))
}
result_shape[fft_axes] <- fft_len
}

ValueTypes(list(
ValueType(
TensorType(
dtype = result_dtype,
shape = Shape(result_shape)
)
)
))
}

hlo_fft_impl <- hlo_fn(OpFft, infer_types_fft)

#' @templateVar mnemonic fft
#' @template op
#' @param operand (`FuncValue`)\cr
#' Tensor of floating-point or complex type.
#' @param fft_type (`character(1)`)\cr
#' One of `"FFT"`, `"IFFT"`, `"RFFT"`, `"IRFFT"`.
#' @param fft_length (`integer()`)\cr
#' Length 1, 2, or 3 vector of `i64` values giving the FFT lengths along
#' the trailing dimensions.
#' @export
hlo_fft <- function(
operand,
fft_type = c("FFT", "IFFT", "RFFT", "IRFFT"),
fft_length
) {
fft_type <- match.arg(fft_type)
fft_length <- as.integer(fft_length)

hlo_fft_impl(
values = list(operand = operand),
attrs = list(
constant_attr(
"fft_length",
fft_length,
dtype = "i64",
shape = length(fft_length)
)
),
custom_attrs = list(fft_type = fft_type)
)
}

#' @export
repr.OpFft <- function(x, simplify_dense = TRUE, ...) {
attrs_repr <- vapply(
x$inputs$attrs,
function(a) repr(a, simplify_dense = simplify_dense),
character(1)
)
fft_type_repr <- paste0(
"fft_type = #stablehlo<fft_type ",
x$inputs$custom_attrs$fft_type,
">"
)
all_attrs <- paste(c(fft_type_repr, attrs_repr), collapse = ",\n")

paste0(
repr(x$outputs),
" = ",
repr(x$name),
" (",
repr(x$inputs$values),
") {\n",
all_attrs,
"\n}: ",
repr(x$signature)
)
}
8 changes: 8 additions & 0 deletions R/types.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@ tengen::UIntegerType
#' @export
tengen::FloatType

#' @export
tengen::ComplexType

#' @export
tengen::is_dtype

Expand All @@ -39,6 +42,11 @@ repr.FloatType <- function(x, ...) {
as.character(x)
}

#' @export
repr.ComplexType <- function(x, ...) {
paste0("complex<f", x$value, ">")
}

# Re-export assert_dtype from tengen
assert_dtype <- tengen::assert_dtype

Expand Down
30 changes: 30 additions & 0 deletions man/hlo_fft.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

3 changes: 2 additions & 1 deletion man/reexports.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading