From dfd02467b323edacd05eb4578ce31d59d7e96724 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Fri, 15 May 2026 16:18:17 +0800 Subject: [PATCH 1/6] cleanup: singleton --- src/algos.jl | 86 ++++++++++++++++++++-------------------- src/singleton_twiddle.jl | 17 ++++---- test/onedim/accuracy.jl | 6 +-- 3 files changed, 52 insertions(+), 57 deletions(-) diff --git a/src/algos.jl b/src/algos.jl index 1ab65a4..6ab1d84 100644 --- a/src/algos.jl +++ b/src/algos.jl @@ -66,12 +66,8 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out Rt = right.type Lt = left.type - w1 = _conj(root.w, d) Rtype = real(T) - # The composite twiddle at position (j1, k2) is `cispi(dir · 2 j1 k2 / N)`. - # Singleton's recurrence advances `wk2 = cispi(dir · 2 j1 k2 / N)` in k2 - # for fixed j1; (α, β) depend on j1 so we reset them at each outer step. - dir = twiddle_direction(w1) + dir = Int(d) tmp = g.workspace[idx] if Rt === BLUESTEIN @@ -90,11 +86,14 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out end if j1 > 0 - αi, βi = singleton_params(dir * Rtype(2 * j1) / Rtype(N)) - ci, si = one(Rtype), zero(Rtype) + # The composite twiddle at position (j1, k2) is `cispi(dir · 2 j1 k2 / N)`. + # Singleton's recurrence advances `wk2 = cispi(dir · 2 j1 k2 / N)` in k2 + # for fixed j1; (α, β) depend on j1 so we reset them at each outer step. + zj1 = singleton_params(dir * Rtype(j1) / Rtype(N)) + wk2 = one(T) @inbounds for k2 in 1:N2-1 - ci, si = singleton_step(ci, si, αi, βi) - tmp[R_start_out + k2] *= Complex(ci, si) + wk2 = singleton_step(wk2, zj1) + tmp[R_start_out + k2] *= wk2 end end end @@ -140,14 +139,14 @@ function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_o Rtype = real(T) dir = twiddle_direction(w) @inbounds for d in 1:N-1 - t = in[start_in] - αk, βk = singleton_params(dir * Rtype(2 * d) / Rtype(N)) - ck, sk = one(Rtype), zero(Rtype) + tmp = in[start_in] + zd = singleton_params(dir * Rtype(d) / Rtype(N)) + wk = one(T) @inbounds for k in 1:N-1 - ck, sk = singleton_step(ck, sk, αk, βk) - t += Complex(ck, sk) * in[start_in + k*stride_in] + wk = singleton_step(wk, zd) + tmp += wk * in[start_in + k*stride_in] end - out[start_out + d*stride_out] = t + out[start_out + d*stride_out] = tmp end end @@ -162,14 +161,14 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int dir = twiddle_direction(w) @inbounds for d in 1:halfN - t = Complex{T}(in[start_in]) - αk, βk = singleton_params(dir * T(2 * d) / T(N)) - ck, sk = one(T), zero(T) + tmp = Complex{T}(in[start_in]) + zd = singleton_params(dir * T(d) / T(N)) + wk = one(complex(T)) @inbounds for k in 1:N-1 - ck, sk = singleton_step(ck, sk, αk, βk) - t += Complex{T}(ck, sk) * in[start_in + k*stride_in] + wk = singleton_step(wk, zd) + tmp += wk * in[start_in + k*stride_in] end - out[start_out + d*stride_out] = t + out[start_out + d*stride_out] = tmp end end @@ -230,12 +229,13 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w_sub) # Singleton recurrence for the three running twiddles `w^k`, `w^2k`, `w^3k`. - α1, β1 = singleton_params(dir * Rtype(2) / Rtype(N)) - α2, β2 = singleton_params(dir * Rtype(4) / Rtype(N)) - α3, β3 = singleton_params(dir * Rtype(6) / Rtype(N)) - c1, s1 = one(Rtype), zero(Rtype) - c2, s2 = one(Rtype), zero(Rtype) - c3, s3 = one(Rtype), zero(Rtype) + z1 = singleton_params(dir * Rtype(1) / Rtype(N)) + z2 = singleton_params(dir * Rtype(2) / Rtype(N)) + z3 = singleton_params(dir * Rtype(3) / Rtype(N)) + + wkoe = one(T) + wkeo = one(T) + wkoo = one(T) @inbounds for k in 0:m-1 kee = start_out + k * stride_out @@ -243,9 +243,9 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, keo = start_out + (k + 2 * m) * stride_out koo = start_out + (k + 3 * m) * stride_out y_kee, y_koe, y_keo, y_koo = out[kee], out[koe], out[keo], out[koo] - t_keo = y_keo * Complex(c2, s2) - t_koe = y_koe * Complex(c1, s1) - t_koo = y_koo * Complex(c3, s3) + t_koe = y_koe * wkoe + t_keo = y_keo * wkeo + t_koo = y_koo * wkoo y_kee_p_y_keo = y_kee + t_keo y_kee_m_y_keo = y_kee - t_keo t_koe_p_t_koo = t_koe + t_koo @@ -254,9 +254,9 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, out[koe] = y_kee_m_y_keo + t_koe_m_t_koo out[keo] = y_kee_p_y_keo - t_koe_p_t_koo out[koo] = y_kee_m_y_keo - t_koe_m_t_koo - c1, s1 = singleton_step(c1, s1, α1, β1) - c2, s2 = singleton_step(c2, s2, α2, β2) - c3, s3 = singleton_step(c3, s3, α3, β3) + wkoe = singleton_step(wkoe, z1) + wkeo = singleton_step(wkeo, z2) + wkoo = singleton_step(wkoo, z3) end end @@ -300,22 +300,20 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_ fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, w_sub, minus120) fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, w_sub, minus120) - α1, β1 = singleton_params(dir * Rtype(2) / Rtype(N)) - α2, β2 = singleton_params(dir * Rtype(4) / Rtype(N)) - c1, s1 = one(Rtype), zero(Rtype) - c2, s2 = one(Rtype), zero(Rtype) + z1 = singleton_params(dir * Rtype(1) / Rtype(N)) + z2 = singleton_params(dir * Rtype(2) / Rtype(N)) + wk1 = one(T) + wk2 = one(T) for k in 0:Nprime-1 k0 = start_out + stride_out * k k1 = start_out + stride_out * (k + Nprime) k2 = start_out + stride_out * (k + 2 * Nprime) y_k0, y_k1, y_k2 = out[k0], out[k1], out[k2] - wk1 = Complex(c1, s1) - wk2 = Complex(c2, s2) - @muladd out[k0] = y_k0 + y_k1*wk1 + y_k2*wk2 - @muladd out[k1] = y_k0 + y_k1*wk1*plus120 + y_k2*wk2*minus120 - @muladd out[k2] = y_k0 + y_k1*wk1*minus120 + y_k2*wk2*plus120 - c1, s1 = singleton_step(c1, s1, α1, β1) - c2, s2 = singleton_step(c2, s2, α2, β2) + @muladd out[k0] = y_k0 + y_k1 * wk1 + y_k2 * wk2 + @muladd out[k1] = y_k0 + y_k1 * wk1 * plus120 + y_k2 * wk2 * minus120 + @muladd out[k2] = y_k0 + y_k1 * wk1 * minus120 + y_k2 * wk2 * plus120 + wk1 = singleton_step(wk1, z1) + wk2 = singleton_step(wk2, z2) end end diff --git a/src/singleton_twiddle.jl b/src/singleton_twiddle.jl index 416ac63..7b9c3ca 100644 --- a/src/singleton_twiddle.jl +++ b/src/singleton_twiddle.jl @@ -23,26 +23,25 @@ # to the same twiddle set so we pick +1. @inline function twiddle_direction(w::Complex{T}) where {T<:Real} s = imag(w) - s > 0 ? one(T) : (s < 0 ? -one(T) : one(T)) + copysign(one(T), s) end # Recurrence coefficients for stepping by `cispi(freq) = e^(iπ·freq)`. -# Uses `sincospi(freq/2)` so that `α` and `β` are exact-to-ULP even +# Uses `sincospi(hfreq)` so that `α` and `β` are exact-to-ULP even # for very small frequencies — writing `1 - cos(θ)` directly suffers # catastrophic cancellation there. -@inline function singleton_params(freq::T) where {T<:Real} - s_h, c_h = sincospi(freq / 2) +@inline function singleton_params(hfreq::Real) + s_h, c_h = sincospi(-hfreq) α = 2 * s_h * s_h β = 2 * s_h * c_h - (α, β) + Complex(α, β) end # Advance `(c, s) = (cos(kθ), sin(kθ))` to `(cos((k+1)θ), sin((k+1)θ))`. # Computed as `c - (αc + βs)` rather than `(1-α)c - βs` on purpose: # the correction is small so subtracting it from `c` preserves the # high-order bits and the recurrence self-heals. -@inline function singleton_step(c::T, s::T, α::T, β::T) where {T<:Real} - c_new = c - muladd(α, c, β * s) - s_new = s - muladd(α, s, -(β * c)) - (c_new, s_new) +@inline function singleton_step(w::T, z::T) where {T<:Complex} + # muladd only reduces instructions, doesn't help precision much + w - @fastmath(z * w) end diff --git a/test/onedim/accuracy.jl b/test/onedim/accuracy.jl index 7a15165..6bd261e 100644 --- a/test/onedim/accuracy.jl +++ b/test/onedim/accuracy.jl @@ -8,8 +8,6 @@ using FFTA, Test, Random, LinearAlgebra # still fail comfortably against the pre-fix naive `w *= step` # recurrence, which ballooned past ~4000 ULP at N = 16384. -Random.seed!(42) - # (N, max eps ratio) across the power-of-2 ladder. Covers both even # powers (= powers of 4, recursion bottoms at N = 4) and odd powers # (recursion bottoms at N = 2), which hit different base cases in @@ -47,8 +45,8 @@ const POWERS_OF_3 = ( function _worst_relerr(N::Int) worst = 0.0 for seed in 1:5 - Random.seed!(seed) - x64 = randn(ComplexF64, N) + rng = Xoshiro(seed) + x64 = randn(rng, ComplexF64, N) x32 = ComplexF32.(x64) y32 = fft(x32) y_ref = ComplexF32.(fft(x64)) From baca4e16264db021ae02739dfa70ce9986a857b9 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Fri, 15 May 2026 16:18:17 +0800 Subject: [PATCH 2/6] better accuracy tests --- test/onedim/accuracy.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/test/onedim/accuracy.jl b/test/onedim/accuracy.jl index 6bd261e..b4f357f 100644 --- a/test/onedim/accuracy.jl +++ b/test/onedim/accuracy.jl @@ -45,12 +45,12 @@ const POWERS_OF_3 = ( function _worst_relerr(N::Int) worst = 0.0 for seed in 1:5 - rng = Xoshiro(seed) - x64 = randn(rng, ComplexF64, N) - x32 = ComplexF32.(x64) + rng = @isdefined(Xoshiro) ? Xoshiro(seed) : MersenneTwister(seed) + x32 = randn(rng, ComplexF32, N) + x64 = ComplexF64.(x32) y32 = fft(x32) y_ref = ComplexF32.(fft(x64)) - relerr = norm(y32 .- y_ref) / norm(y_ref) + relerr = norm(y32 - y_ref) / norm(y_ref) worst = max(worst, relerr / eps(Float32)) end return worst From d88b7de12e0738a91673e39c37520070255411e6 Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Fri, 15 May 2026 16:18:17 +0800 Subject: [PATCH 3/6] use new recurrence in plan.jl --- src/plan.jl | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index ef8e979..423a65e 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -425,7 +425,8 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R # The w stored in the plan is for m, not n, so probably cheapest to # just recompute it instead of taking a square root - wj = w = cispi(-T(2) / n) + z1 = singleton_params(-one(T) / n) + wj = cispi(-T(2) / n) # Construct the result by first constructing the elements of the # real and imaginary part, followed by the usual radix-2 assembly, @@ -441,7 +442,7 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R XY = T(0.5) * (-yj + conj(ymj)) * im y[j] = XX + wj * XY y[m-j+2] = conj(XX - wj * XY) - wj *= w + wj = singleton_step(wj, z1) end return y else @@ -467,7 +468,11 @@ function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Complex} # See explanation of this approach in the method for the FORWARD transform if iseven(n) m = n >> 1 - wj = w = cispi(T(2) / n) + + R = real(T) + z1 = singleton_params(one(R) / n) + wj = cispi(R(2) / n) + x_tmp = similar(x, length(x) - 1) x_tmp[1] = complex( (real(x[1]) + real(x[end])), @@ -478,7 +483,7 @@ function Base.:*(p::FFTAPlan_re{T,1}, x::AbstractVector{T}) where {T<:Complex} XY = wj * (x[j] - conj(x[m-j+2])) x_tmp[j] = XX + im * XY x_tmp[m-j+2] = conj(XX - im * XY) - wj *= w + wj = singleton_step(wj, z1) end y_c = complex(p) * x_tmp if isbitstype(T) From afb344861a104bd532a63bbb3b4b2cb303a95c6d Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Fri, 15 May 2026 16:18:17 +0800 Subject: [PATCH 4/6] include macos-latest, windows-latest --- .github/workflows/ci.yml | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 0527323..6cfe048 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -22,6 +22,13 @@ jobs: - ubuntu-latest arch: - x64 + include: + - os: macos-latest + arch: arm64 + version: '1' + - os: windows-latest + arch: x64 + version: '1' steps: - uses: actions/checkout@v6 - uses: julia-actions/setup-julia@v2 From a3035d3dace7c407a3d9808530d4b3cff0592fea Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Fri, 15 May 2026 16:18:17 +0800 Subject: [PATCH 5/6] fix typos --- src/plan.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/plan.jl b/src/plan.jl index 423a65e..0928680 100644 --- a/src/plan.jl +++ b/src/plan.jl @@ -402,14 +402,14 @@ function Base.:*(p::FFTAPlan_re{Complex{T},1}, x::AbstractVector{T}) where {T<:R n = p.flen if iseven(n) # For problems of even size, we solve the rfft problem by splitting the - # problem into the even and odd part and solving the simultanously as + # problem into the even and odd part and solving them simultaneously as # a single (complex) fft of half the size, see equations (6)-(8) of # Sorensen, H. V., D. Jones, Michael Heideman, and C. Burrus. # "Real-valued fast Fourier transform algorithms." # IEEE Transactions on acoustics, speech, and signal processing 35, no. 6 (2003): 849-863. if x isa Vector && isbitstype(T) - # For a vector of bits, we can just reintepret the bits to get the - # approciate representation of even (zero based) elements as the real + # For a vector of bits, we can just reinterpret the bits to get the + # appropriate representation of even (zero based) elements as the real # part and the odd as the complex part x_c = reinterpret(Complex{T}, x) else From 367947f91a30361e988f26926320b7d92f99cf2d Mon Sep 17 00:00:00 2001 From: wheeheee <104880306+wheeheee@users.noreply.github.com> Date: Fri, 15 May 2026 16:18:17 +0800 Subject: [PATCH 6/6] fix types, pass Direction directly --- src/algos.jl | 104 +++++++++++++++++++++--------------- src/singleton_twiddle.jl | 8 --- test/onedim/real_forward.jl | 2 +- 3 files changed, 61 insertions(+), 53 deletions(-) diff --git a/src/algos.jl b/src/algos.jl index 6ab1d84..d272345 100644 --- a/src/algos.jl +++ b/src/algos.jl @@ -2,8 +2,6 @@ Int(d) end -@inline _conj(w::Complex, d::Direction) = ifelse(direction_sign(d) === 1, w, conj(w)) - function fft!( out::AbstractVector{T}, in::AbstractVector{T}, start_out::Int, start_in::Int, @@ -19,15 +17,14 @@ function fft!( s_in = root.s_in s_out = root.s_out N = root.sz - w = _conj(root.w, d) if t === DFT - fft_dft!(out, in, N, start_out, s_out, start_in, s_in, w) + fft_dft!(out, in, N, start_out, s_out, start_in, s_in, d) elseif t === POW2RADIX4_FFT - fft_pow2_radix4!(out, in, N, start_out, s_out, start_in, s_in, w) + fft_pow2_radix4!(out, in, N, start_out, s_out, start_in, s_in, d) elseif t === POW3_FFT _m_120 = cispi(T(2) / 3) m_120 = d === FFT_FORWARD ? _m_120 : conj(_m_120) - fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, w, m_120) + fft_pow3!(out, in, N, start_out, s_out, start_in, s_in, m_120, d) elseif t === BLUESTEIN fft_bluestein!(out, in, d, N, start_out, s_out, start_in, s_in) else @@ -67,7 +64,7 @@ function fft_composite!(out::AbstractVector{T}, in::AbstractVector{U}, start_out Lt = left.type Rtype = real(T) - dir = Int(d) + dir = direction_sign(d) tmp = g.workspace[idx] if Rt === BLUESTEIN @@ -129,7 +126,13 @@ Discrete Fourier Transform, O(N^2) algorithm, in place. - `w`: The value `cispi(direction_sign(d) * 2 / N)` """ -function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T} +function fft_dft!( + out::AbstractVector{T}, in::AbstractVector{T}, + N::Int, + start_out::Int, stride_out::Int, + start_in::Int, stride_in::Int, + d::Direction +) where {T<:Complex} tmp = in[start_in] @inbounds for j in 1:N-1 tmp += in[start_in + j*stride_in] @@ -137,20 +140,26 @@ function fft_dft!(out::AbstractVector{T}, in::AbstractVector{T}, N::Int, start_o out[start_out] = tmp Rtype = real(T) - dir = twiddle_direction(w) - @inbounds for d in 1:N-1 + dir = direction_sign(d) + @inbounds for j in 1:N-1 tmp = in[start_in] - zd = singleton_params(dir * Rtype(d) / Rtype(N)) + zj = singleton_params(dir * Rtype(j) / Rtype(N)) wk = one(T) @inbounds for k in 1:N-1 - wk = singleton_step(wk, zd) + wk = singleton_step(wk, zj) tmp += wk * in[start_in + k*stride_in] end - out[start_out + d*stride_out] = tmp + out[start_out + j*stride_out] = tmp end end -function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::Complex{T}) where {T<:Real} +function fft_dft!( + out::AbstractVector{Complex{T}}, in::AbstractVector{T}, + N::Int, + start_out::Int, stride_out::Int, + start_in::Int, stride_in::Int, + d::Direction +) where {T<:Real} halfN = N÷2 tmp = Complex{T}(in[start_in]) @@ -159,16 +168,16 @@ function fft_dft!(out::AbstractVector{Complex{T}}, in::AbstractVector{T}, N::Int end out[start_out] = tmp - dir = twiddle_direction(w) - @inbounds for d in 1:halfN + dir = direction_sign(d) + @inbounds for j in 1:halfN tmp = Complex{T}(in[start_in]) - zd = singleton_params(dir * T(d) / T(N)) + zj = singleton_params(dir * T(j) / T(N)) wk = one(complex(T)) @inbounds for k in 1:N-1 - wk = singleton_step(wk, zd) + wk = singleton_step(wk, zj) tmp += wk * in[start_in + k*stride_in] end - out[start_out + d*stride_out] = tmp + out[start_out + j*stride_out] = tmp end end @@ -188,7 +197,13 @@ Radix-4 FFT for powers of 2, in place - `w`: The value `cispi(direction_sign(d) * 2 / N)` """ -function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T) where {T, U} +function fft_pow2_radix4!( + out::AbstractVector{T}, in::AbstractVector{U}, + N::Int, + start_out::Int, stride_out::Int, + start_in::Int, stride_in::Int, + d::Direction +) where {T<:Complex, U} # If N is 2, compute the size two DFT @inbounds if N == 2 out[start_out] = in[start_in] + in[start_in + stride_in] @@ -196,8 +211,10 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, return end + dir = direction_sign(d) + # If N is 4, compute an unrolled radix-2 FFT and return - minusi = -sign(imag(w)) * im + minusi = -dir * im @inbounds if N == 4 xee = in[start_in] xoe = in[start_in + stride_in] @@ -217,17 +234,12 @@ function fft_pow2_radix4!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, # ...othersize split the problem in four and recur m = N ÷ 4 - Rtype = real(T) - dir = twiddle_direction(w) - # Recursive sub-problem step `cispi(dir · 2 / m) = w^4`; use `cispi` - # directly so the sub-tree gets a < 1 ULP starting phase. - w_sub = cispi(dir * Rtype(2) / Rtype(m)) - - fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, w_sub) - fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, w_sub) - fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, w_sub) - fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, w_sub) + fft_pow2_radix4!(out, in, m, start_out , stride_out, start_in , stride_in*4, d) + fft_pow2_radix4!(out, in, m, start_out + m*stride_out, stride_out, start_in + stride_in, stride_in*4, d) + fft_pow2_radix4!(out, in, m, start_out + 2*m*stride_out, stride_out, start_in + 2*stride_in, stride_in*4, d) + fft_pow2_radix4!(out, in, m, start_out + 3*m*stride_out, stride_out, start_in + 3*stride_in, stride_in*4, d) + Rtype = real(T) # Singleton recurrence for the three running twiddles `w^k`, `w^2k`, `w^3k`. z1 = singleton_params(dir * Rtype(1) / Rtype(N)) z2 = singleton_params(dir * Rtype(2) / Rtype(N)) @@ -278,7 +290,14 @@ Power of 3 FFT, in place - `minus120`: Depending on direction, perform either ∓120° rotation """ -function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_out::Int, stride_out::Int, start_in::Int, stride_in::Int, w::T, minus120::T) where {T, U} +function fft_pow3!( + out::AbstractVector{T}, in::AbstractVector{U}, + N::Int, + start_out::Int, stride_out::Int, + start_in::Int, stride_in::Int, + minus120::T, + d::Direction +) where {T, U} plus120 = conj(minus120) if N == 3 @muladd out[start_out + 0] = in[start_in] + in[start_in + stride_in] + in[start_in + 2*stride_in] @@ -290,15 +309,13 @@ function fft_pow3!(out::AbstractVector{T}, in::AbstractVector{U}, N::Int, start_ # Size of subproblem Nprime = N ÷ 3 - Rtype = real(T) - dir = twiddle_direction(w) - # Recursive sub-problem step cispi(dir · 2 / Nprime) = w^3. - w_sub = cispi(dir * Rtype(2) / Rtype(Nprime)) - # Dividing into subproblems - fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, w_sub, minus120) - fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, w_sub, minus120) - fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, w_sub, minus120) + fft_pow3!(out, in, Nprime, start_out, stride_out, start_in, stride_in*3, minus120, d) + fft_pow3!(out, in, Nprime, start_out + Nprime*stride_out, stride_out, start_in + stride_in, stride_in*3, minus120, d) + fft_pow3!(out, in, Nprime, start_out + 2*Nprime*stride_out, stride_out, start_in + 2*stride_in, stride_in*3, minus120, d) + + Rtype = real(T) + dir = direction_sign(d) z1 = singleton_params(dir * Rtype(1) / Rtype(N)) z2 = singleton_params(dir * Rtype(2) / Rtype(N)) @@ -381,14 +398,13 @@ function fft_bluestein!( a_series[i] = in[start_in+(i-1)*stride_in] * conj(b_series[i]) end - w_pad = cispi(T(2) / pad_len) # leave b_n vector alone for last step - fft_pow2_radix4!(tmp, a_series, pad_len, 1, 1, 1, 1, w_pad) # Fa - fft_pow2_radix4!(a_series, b_series, pad_len, 1, 1, 1, 1, w_pad) # Fb + fft_pow2_radix4!(tmp, a_series, pad_len, 1, 1, 1, 1, FFT_BACKWARD) # Fa + fft_pow2_radix4!(a_series, b_series, pad_len, 1, 1, 1, 1, FFT_BACKWARD) # Fb tmp .*= a_series # convolution theorem ifft - fft_pow2_radix4!(a_series, tmp, pad_len, 1, 1, 1, 1, conj(w_pad)) + fft_pow2_radix4!(a_series, tmp, pad_len, 1, 1, 1, 1, FFT_FORWARD) conv_a_b = a_series Xk = tmp diff --git a/src/singleton_twiddle.jl b/src/singleton_twiddle.jl index 7b9c3ca..f0f5783 100644 --- a/src/singleton_twiddle.jl +++ b/src/singleton_twiddle.jl @@ -18,14 +18,6 @@ # order of magnitude faster than a fresh `cispi`), and the extra trig # (`sincospi(θ/2)`) happens once per kernel call. -# Direction lives in `sign(imag(w))`; when `w` is real (N = 2 or any -# degenerate case where `imag` rounds to zero) both directions collapse -# to the same twiddle set so we pick +1. -@inline function twiddle_direction(w::Complex{T}) where {T<:Real} - s = imag(w) - copysign(one(T), s) -end - # Recurrence coefficients for stepping by `cispi(freq) = e^(iπ·freq)`. # Uses `sincospi(hfreq)` so that `α` and `β` are exact-to-ULP even # for very small frequencies — writing `1 - cos(θ)` directly suffers diff --git a/test/onedim/real_forward.jl b/test/onedim/real_forward.jl index 3c0d8c1..4350756 100644 --- a/test/onedim/real_forward.jl +++ b/test/onedim/real_forward.jl @@ -20,7 +20,7 @@ end @testset "temporarily test real dft separately until used by rfft" begin y_dft = similar(y) - FFTA.fft_dft!(y_dft, x, n, 1, 1, 1, 1, cispi(-2/n)) + FFTA.fft_dft!(y_dft, x, n, 1, 1, 1, 1, FFTA.FFT_FORWARD) @test y ≈ y_dft end