From bd0c8ecb322ca97ede80ce7dff523badb44a3420 Mon Sep 17 00:00:00 2001 From: ashWhiteHat Date: Fri, 7 Jul 2023 17:14:25 +0900 Subject: [PATCH] domain: fft optimization --- src/domain.rs | 206 ++++++++++++++++++++++++-------------------------- 1 file changed, 99 insertions(+), 107 deletions(-) diff --git a/src/domain.rs b/src/domain.rs index 358a9a51..a230250a 100644 --- a/src/domain.rs +++ b/src/domain.rs @@ -13,6 +13,7 @@ use ff::PrimeField; use group::cofactor::CofactorCurve; +use rayon::join; use super::SynthesisError; @@ -259,116 +260,106 @@ impl Group for Scalar { } fn best_fft>(a: &mut [T], worker: &Worker, omega: &S, log_n: u32) { - let log_cpus = worker.log_num_threads(); - - if log_n <= log_cpus { - serial_fft(a, omega, log_n); - } else { - parallel_fft(a, worker, omega, log_n, log_cpus); + let n = a.len(); + assert_eq!(n, 1 << log_n); + if log_n == 0 { + return; } -} + let log_cpus = worker.log_num_threads(); -#[allow(clippy::many_single_char_names)] -fn serial_fft>(a: &mut [T], omega: &S, log_n: u32) { - fn bitreverse(mut n: u32, l: u32) -> u32 { - let mut r = 0; - for _ in 0..l { - r = (r << 1) | (n & 1); - n >>= 1; + let offset = 64 - log_n; + for i in 0..n as u64 { + let ri = i.reverse_bits() >> offset; + if i < ri { + a.swap(ri as usize, i as usize); } - r } - let n = a.len() as u32; - assert_eq!(n, 1 << log_n); + // precompute twiddle factors + let twiddles: Vec<_> = (0..(n / 2)) + .scan(S::ONE, |w, _| { + let tw = *w; + *w *= omega; + Some(tw) + }) + .collect(); - for k in 0..n { - let rk = bitreverse(k, log_n); - if k < rk { - a.swap(rk as usize, k as usize); - } + if log_n <= log_cpus { + serial_fft(a, n, log_n, &twiddles); + } else { + parallel_fft(a, n, 1, &twiddles); } +} - let mut m = 1; +#[allow(clippy::many_single_char_names)] +fn serial_fft>(a: &mut [T], n: usize, log_n: u32, twiddles: &[S]) { + let mut chunk = 2_usize; + let mut twiddle_chunk = n / 2; for _ in 0..log_n { - let w_m = omega.pow_vartime(&[u64::from(n / (2 * m))]); - - let mut k = 0; - while k < n { - let mut w = S::ONE; - for j in 0..m { - let mut t = a[(k + j + m) as usize]; - t.group_mul_assign(&w); - let mut tmp = a[(k + j) as usize]; - tmp.group_sub_assign(&t); - a[(k + j + m) as usize] = tmp; - a[(k + j) as usize].group_add_assign(&t); - w.mul_assign(&w_m); - } - - k += 2 * m; - } - - m *= 2; + a.chunks_mut(chunk).for_each(|coeffs| { + let (left, right) = coeffs.split_at_mut(chunk / 2); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0].group_add_assign(&t); + b[0].group_sub_assign(&t); + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t.group_mul_assign(&twiddles[(i + 1) * twiddle_chunk as usize]); + *b = *a; + a.group_add_assign(&t); + b.group_sub_assign(&t); + }); + }); + chunk *= 2; + twiddle_chunk /= 2; } } -fn parallel_fft>( +pub fn parallel_fft>( a: &mut [T], - worker: &Worker, - omega: &S, - log_n: u32, - log_cpus: u32, + n: usize, + twiddle_chunk: usize, + twiddles: &[S], ) { - assert!(log_n >= log_cpus); - - let num_cpus = 1 << log_cpus; - let log_new_n = log_n - log_cpus; - let mut tmp = vec![vec![T::group_zero(); 1 << log_new_n]; num_cpus]; - let new_omega = omega.pow_vartime(&[num_cpus as u64]); - - worker.scope(0, |scope, _| { - let a = &*a; - - for (j, tmp) in tmp.iter_mut().enumerate() { - scope.spawn(move |_scope| { - // Shuffle into a sub-FFT - let omega_j = omega.pow_vartime(&[j as u64]); - let omega_step = omega.pow_vartime(&[(j as u64) << log_new_n]); - - let mut elt = S::ONE; - for (i, tmp) in tmp.iter_mut().enumerate() { - for s in 0..num_cpus { - let idx = (i + (s << log_new_n)) % (1 << log_n); - let mut t = a[idx]; - t.group_mul_assign(&elt); - tmp.group_add_assign(&t); - elt.mul_assign(&omega_step); - } - elt.mul_assign(&omega_j); - } - - // Perform sub-FFT - serial_fft(tmp, &new_omega, log_new_n); - }); - } - }); - - // TODO: does this hurt or help? - worker.scope(a.len(), |scope, chunk| { - let tmp = &tmp; - - for (idx, a) in a.chunks_mut(chunk).enumerate() { - scope.spawn(move |_scope| { - let mut idx = idx * chunk; - let mask = (1 << log_cpus) - 1; - for a in a { - *a = tmp[idx & mask][idx >> log_cpus]; - idx += 1; - } + if n == 2 { + let t = a[1]; + a[1] = a[0]; + a[0].group_add_assign(&t); + a[1].group_sub_assign(&t); + } else { + let (left, right) = a.split_at_mut(n / 2); + join( + || parallel_fft(left, n / 2, twiddle_chunk * 2, twiddles), + || parallel_fft(right, n / 2, twiddle_chunk * 2, twiddles), + ); + + // case when twiddle factor is one + let (a, left) = left.split_at_mut(1); + let (b, right) = right.split_at_mut(1); + let t = b[0]; + b[0] = a[0]; + a[0].group_add_assign(&t); + b[0].group_sub_assign(&t); + + left.iter_mut() + .zip(right.iter_mut()) + .enumerate() + .for_each(|(i, (a, b))| { + let mut t = *b; + t.group_mul_assign(&twiddles[(i + 1) * twiddle_chunk]); + *b = *a; + a.group_add_assign(&t); + b.group_sub_assign(&t); }); - } - }); + } } // Test multiplying various (low degree) polynomials together and @@ -467,27 +458,28 @@ fn fft_composition() { fn parallel_fft_consistency() { use bls12_381::Scalar as Fr; use rand_core::RngCore; - use std::cmp::min; fn test_consistency(mut rng: &mut R) { - let worker = Worker::new(); - for _ in 0..5 { - for log_d in 0..10 { - let d = 1 << log_d; + for log_n in 1..10 { + let n = 1 << log_n; - let v1 = (0..d) + let v1 = (0..n) .map(|_| Scalar::(S::random(&mut rng))) .collect::>(); let mut v1 = EvaluationDomain::from_coeffs(v1).unwrap(); let mut v2 = EvaluationDomain::from_coeffs(v1.coeffs.clone()).unwrap(); + let twiddles: Vec<_> = (0..(n / 2)) + .scan(S::ONE, |w, _| { + let tw = *w; + *w *= v2.omega; + Some(tw) + }) + .collect(); - for log_cpus in log_d..min(log_d + 1, 3) { - parallel_fft(&mut v1.coeffs, &worker, &v1.omega, log_d, log_cpus); - serial_fft(&mut v2.coeffs, &v2.omega, log_d); - - assert!(v1.coeffs == v2.coeffs); - } + serial_fft(&mut v1.coeffs, n, log_n, &twiddles); + parallel_fft(&mut v2.coeffs, n, 1, &twiddles); + assert!(v1.coeffs == v2.coeffs); } } }