Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,13 @@ jobs:
- ubuntu-latest
arch:
- x64
include:
- os: macos-latest
arch: arm64
version: '1'
- os: windows-latest
arch: x64
version: '1'
steps:
- uses: actions/checkout@v6
- uses: julia-actions/setup-julia@v2
Expand Down
176 changes: 95 additions & 81 deletions src/algos.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,6 @@
Int(d)
end

@inline _conj(w::Complex, d::Direction) = ifelse(direction_sign(d) === 1, w, conj(w))

function fft!(
out::AbstractVector{T}, in::AbstractVector{T},
start_out::Int, start_in::Int,
Expand All @@ -19,15 +17,14 @@ function fft!(
s_in = root.s_in
s_out = root.s_out
N = root.sz
w = _conj(root.w, d)
if t === DFT
fft_dft!(out, in, N, start_out, s_out, start_in, s_in, w)
fft_dft!(out, in, N, start_out, s_out, start_in, s_in, d)
elseif t === POW2RADIX4_FFT
fft_pow2_radix4!(out, in, N, start_out, s_out, start_in, s_in, w)
fft_pow2_radix4!(out, in, N, start_out, s_out, start_in, s_in, d)
elseif t === POW3_FFT
_m_120 = cispi(T(2) / 3)
m_120 = d === FFT_FORWARD ? _m_120 : conj(_m_120)
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, w, m_120)
fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, m_120, d)
elseif t === BLUESTEIN
fft_bluestein!(out, in, d, N, start_out, s_out, start_in, s_in)
else
Expand Down Expand Up @@ -66,12 +63,8 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out
Rt = right.type
Lt = left.type

w1 = _conj(root.w, d)
Rtype = real(T)
# The composite twiddle at position (j1, k2) is `cispi(dir · 2 j1 k2 / N)`.
# Singleton's recurrence advances `wk2 = cispi(dir · 2 j1 k2 / N)` in k2
# for fixed j1; (α, β) depend on j1 so we reset them at each outer step.
dir = twiddle_direction(w1)
dir = direction_sign(d)
tmp = g.workspace[idx]

if Rt === BLUESTEIN
Expand All @@ -90,11 +83,14 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out
end

if j1 > 0
αi, βi = singleton_params(dir * Rtype(2 * j1) / Rtype(N))
ci, si = one(Rtype), zero(Rtype)
# The composite twiddle at position (j1, k2) is `cispi(dir · 2 j1 k2 / N)`.
# Singleton's recurrence advances `wk2 = cispi(dir · 2 j1 k2 / N)` in k2
# for fixed j1; (α, β) depend on j1 so we reset them at each outer step.
zj1 = singleton_params(dir * Rtype(j1) / Rtype(N))
wk2 = one(T)
@inbounds for k2 in 1:N2-1
ci, si = singleton_step(ci, si, αi, βi)
tmp[R_start_out + k2] *= Complex(ci, si)
wk2 = singleton_step(wk2, zj1)
tmp[R_start_out + k2] *= wk2
end
end
end
Expand Down Expand Up @@ -130,28 +126,40 @@ Discrete Fourier Transform, O(N^2) algorithm, in place.
- `w`: The value `cispi(direction_sign(d) * 2 / N)`

"""
function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T}
function fft_dft!(
out::AbstractVector{T}, in::AbstractVector{T},
N::Int,
start_out::Int, stride_out::Int,
start_in::Int, stride_in::Int,
d::Direction
) where {T<:Complex}
tmp = in[start_in]
@inbounds for j in 1:N-1
tmp += in[start_in + j*stride_in]
end
out[start_out] = tmp

Rtype = real(T)
dir = twiddle_direction(w)
@inbounds for d in 1:N-1
t = in[start_in]
αk, βk = singleton_params(dir * Rtype(2 * d) / Rtype(N))
ck, sk = one(Rtype), zero(Rtype)
dir = direction_sign(d)
@inbounds for j in 1:N-1
tmp = in[start_in]
zj = singleton_params(dir * Rtype(j) / Rtype(N))
wk = one(T)
@inbounds for k in 1:N-1
ck, sk = singleton_step(ck, sk, αk, βk)
t += Complex(ck, sk) * in[start_in + k*stride_in]
wk = singleton_step(wk, zj)
tmp += wk * in[start_in + k*stride_in]
end
out[start_out + d*stride_out] = t
out[start_out + j*stride_out] = tmp
end
end

function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::Complex{T}) where {T<:Real}
function fft_dft!(
out::AbstractVector{Complex{T}}, in::AbstractVector{T},
N::Int,
start_out::Int, stride_out::Int,
start_in::Int, stride_in::Int,
d::Direction
) where {T<:Real}
halfN = N÷2

tmp = Complex{T}(in[start_in])
Expand All @@ -160,16 +168,16 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int
end
out[start_out] = tmp

dir = twiddle_direction(w)
@inbounds for d in 1:halfN
t = Complex{T}(in[start_in])
αk, βk = singleton_params(dir * T(2 * d) / T(N))
ck, sk = one(T), zero(T)
dir = direction_sign(d)
@inbounds for j in 1:halfN
tmp = Complex{T}(in[start_in])
zj = singleton_params(dir * T(j) / T(N))
wk = one(complex(T))
@inbounds for k in 1:N-1
ck, sk = singleton_step(ck, sk, αk, βk)
t += Complex{T}(ck, sk) * in[start_in + k*stride_in]
wk = singleton_step(wk, zj)
tmp += wk * in[start_in + k*stride_in]
end
out[start_out + d*stride_out] = t
out[start_out + j*stride_out] = tmp
end
end

Expand All @@ -189,16 +197,24 @@ Radix-4 FFT for powers of 2, in place
- `w`: The value `cispi(direction_sign(d) * 2 / N)`

"""
function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U}
function fft_pow2_radix4!(
out::AbstractVector{T}, in::AbstractVector{U},
N::Int,
start_out::Int, stride_out::Int,
start_in::Int, stride_in::Int,
d::Direction
) where {T<:Complex, U}
# If N is 2, compute the size two DFT
@inbounds if N == 2
out[start_out] = in[start_in] + in[start_in + stride_in]
out[start_out + stride_out] = in[start_in] - in[start_in + stride_in]
return
end

dir = direction_sign(d)

# If N is 4, compute an unrolled radix-2 FFT and return
minusi = -sign(imag(w)) * im
minusi = -dir * im
@inbounds if N == 4
xee = in[start_in]
xoe = in[start_in + stride_in]
Expand All @@ -218,34 +234,30 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int,
# ...othersize split the problem in four and recur
m = N ÷ 4

Rtype = real(T)
dir = twiddle_direction(w)
# Recursive sub-problem step `cispi(dir · 2 / m) = w^4`; use `cispi`
# directly so the sub-tree gets a < 1 ULP starting phase.
w_sub = cispi(dir * Rtype(2) / Rtype(m))

fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w_sub)
fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w_sub)
fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w_sub)
fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w_sub)
fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, d)
fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, d)
fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, d)
fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, d)

Rtype = real(T)
# Singleton recurrence for the three running twiddles `w^k`, `w^2k`, `w^3k`.
α1, β1 = singleton_params(dir * Rtype(2) / Rtype(N))
α2, β2 = singleton_params(dir * Rtype(4) / Rtype(N))
α3, β3 = singleton_params(dir * Rtype(6) / Rtype(N))
c1, s1 = one(Rtype), zero(Rtype)
c2, s2 = one(Rtype), zero(Rtype)
c3, s3 = one(Rtype), zero(Rtype)
z1 = singleton_params(dir * Rtype(1) / Rtype(N))
z2 = singleton_params(dir * Rtype(2) / Rtype(N))
z3 = singleton_params(dir * Rtype(3) / Rtype(N))

wkoe = one(T)
wkeo = one(T)
wkoo = one(T)

@inbounds for k in 0:m-1
kee = start_out + k * stride_out
koe = start_out + (k + m) * stride_out
keo = start_out + (k + 2 * m) * stride_out
koo = start_out + (k + 3 * m) * stride_out
y_kee, y_koe, y_keo, y_koo = out[kee], out[koe], out[keo], out[koo]
t_keo = y_keo * Complex(c2, s2)
t_koe = y_koe * Complex(c1, s1)
t_koo = y_koo * Complex(c3, s3)
t_koe = y_koe * wkoe
t_keo = y_keo * wkeo
t_koo = y_koo * wkoo
y_kee_p_y_keo = y_kee + t_keo
y_kee_m_y_keo = y_kee - t_keo
t_koe_p_t_koo = t_koe + t_koo
Expand All @@ -254,9 +266,9 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int,
out[koe] = y_kee_m_y_keo + t_koe_m_t_koo
out[keo] = y_kee_p_y_keo - t_koe_p_t_koo
out[koo] = y_kee_m_y_keo - t_koe_m_t_koo
c1, s1 = singleton_step(c1, s1, α1, β1)
c2, s2 = singleton_step(c2, s2, α2, β2)
c3, s3 = singleton_step(c3, s3, α3, β3)
wkoe = singleton_step(wkoe, z1)
wkeo = singleton_step(wkeo, z2)
wkoo = singleton_step(wkoo, z3)
end
end

Expand All @@ -278,7 +290,14 @@ Power of 3 FFT, in place
- `minus120`: Depending on direction, perform either ∓120° rotation

"""
function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T, minus120::T) where {T, U}
function fft_pow3!(
out::AbstractVector{T}, in::AbstractVector{U},
N::Int,
start_out::Int, stride_out::Int,
start_in::Int, stride_in::Int,
minus120::T,
d::Direction
) where {T, U}
plus120 = conj(minus120)
if N == 3
@muladd out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in]
Expand All @@ -290,32 +309,28 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_
# Size of subproblem
Nprime = N ÷ 3

# Dividing into subproblems
fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, minus120, d)
fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, minus120, d)
fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, minus120, d)

Rtype = real(T)
dir = twiddle_direction(w)
# Recursive sub-problem step cispi(dir · 2 / Nprime) = w^3.
w_sub = cispi(dir * Rtype(2) / Rtype(Nprime))
dir = direction_sign(d)

# Dividing into subproblems
fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, w_sub, minus120)
fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, w_sub, minus120)
fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, w_sub, minus120)

α1, β1 = singleton_params(dir * Rtype(2) / Rtype(N))
α2, β2 = singleton_params(dir * Rtype(4) / Rtype(N))
c1, s1 = one(Rtype), zero(Rtype)
c2, s2 = one(Rtype), zero(Rtype)
z1 = singleton_params(dir * Rtype(1) / Rtype(N))
z2 = singleton_params(dir * Rtype(2) / Rtype(N))
wk1 = one(T)
wk2 = one(T)
for k in 0:Nprime-1
k0 = start_out + stride_out * k
k1 = start_out + stride_out * (k + Nprime)
k2 = start_out + stride_out * (k + 2 * Nprime)
y_k0, y_k1, y_k2 = out[k0], out[k1], out[k2]
wk1 = Complex(c1, s1)
wk2 = Complex(c2, s2)
@muladd out[k0] = y_k0 + y_k1*wk1 + y_k2*wk2
@muladd out[k1] = y_k0 + y_k1*wk1*plus120 + y_k2*wk2*minus120
@muladd out[k2] = y_k0 + y_k1*wk1*minus120 + y_k2*wk2*plus120
c1, s1 = singleton_step(c1, s1, α1, β1)
c2, s2 = singleton_step(c2, s2, α2, β2)
@muladd out[k0] = y_k0 + y_k1 * wk1 + y_k2 * wk2
@muladd out[k1] = y_k0 + y_k1 * wk1 * plus120 + y_k2 * wk2 * minus120
@muladd out[k2] = y_k0 + y_k1 * wk1 * minus120 + y_k2 * wk2 * plus120
wk1 = singleton_step(wk1, z1)
wk2 = singleton_step(wk2, z2)
end
end

Expand Down Expand Up @@ -383,14 +398,13 @@ function fft_bluestein!(
a_series[i] = in[start_in+(i-1)*stride_in] * conj(b_series[i])
end

w_pad = cispi(T(2) / pad_len)
# leave b_n vector alone for last step
fft_pow2_radix4!(tmp, a_series, pad_len, 1, 1, 1, 1, w_pad) # Fa
fft_pow2_radix4!(a_series, b_series, pad_len, 1, 1, 1, 1, w_pad) # Fb
fft_pow2_radix4!(tmp, a_series, pad_len, 1, 1, 1, 1, FFT_BACKWARD) # Fa
fft_pow2_radix4!(a_series, b_series, pad_len, 1, 1, 1, 1, FFT_BACKWARD) # Fb

tmp .*= a_series
# convolution theorem ifft
fft_pow2_radix4!(a_series, tmp, pad_len, 1, 1, 1, 1, conj(w_pad))
fft_pow2_radix4!(a_series, tmp, pad_len, 1, 1, 1, 1, FFT_FORWARD)
conv_a_b = a_series

Xk = tmp
Expand Down
19 changes: 12 additions & 7 deletions src/plan.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,14 +402,14 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R
n = p.flen
if iseven(n)
# For problems of even size, we solve the rfft problem by splitting the
# problem into the even and odd part and solving the simultanously as
# problem into the even and odd part and solving them simultaneously as
# a single (complex) fft of half the size, see equations (6)-(8) of
# Sorensen, H. V., D. Jones, Michael Heideman, and C. Burrus.
# "Real-valued fast Fourier transform algorithms."
# IEEE Transactions on acoustics, speech, and signal processing 35, no. 6 (2003): 849-863.
if x isa Vector && isbitstype(T)
# For a vector of bits, we can just reintepret the bits to get the
# approciate representation of even (zero based) elements as the real
# For a vector of bits, we can just reinterpret the bits to get the
# appropriate representation of even (zero based) elements as the real
# part and the odd as the complex part
x_c = reinterpret(Complex{T}, x)
else
Expand All @@ -425,7 +425,8 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R

# The w stored in the plan is for m, not n, so probably cheapest to
# just recompute it instead of taking a square root
wj = w = cispi(-T(2) / n)
z1 = singleton_params(-one(T) / n)
wj = cispi(-T(2) / n)

# Construct the result by first constructing the elements of the
# real and imaginary part, followed by the usual radix-2 assembly,
Expand All @@ -441,7 +442,7 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R
XY = T(0.5) * (-yj + conj(ymj)) * im
y[j] = XX + wj * XY
y[m-j+2] = conj(XX - wj * XY)
wj *= w
wj = singleton_step(wj, z1)
end
return y
else
Expand All @@ -467,7 +468,11 @@ function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Complex}
# See explanation of this approach in the method for the FORWARD transform
if iseven(n)
m = n >> 1
wj = w = cispi(T(2) / n)

R = real(T)
z1 = singleton_params(one(R) / n)
wj = cispi(R(2) / n)

x_tmp = similar(x, length(x) - 1)
x_tmp[1] = complex(
(real(x[1]) + real(x[end])),
Expand All @@ -478,7 +483,7 @@ function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Complex}
XY = wj * (x[j] - conj(x[m-j+2]))
x_tmp[j] = XX + im * XY
x_tmp[m-j+2] = conj(XX - im * XY)
wj *= w
wj = singleton_step(wj, z1)
end
y_c = complex(p) * x_tmp
if isbitstype(T)
Expand Down
Loading
Loading