Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
206 changes: 99 additions & 107 deletions src/domain.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@

use ff::PrimeField;
use group::cofactor::CofactorCurve;
use rayon::join;

use super::SynthesisError;

Expand Down Expand Up @@ -259,116 +260,106 @@ impl<S: PrimeField> Group<S> for Scalar<S> {
}

fn best_fft<S: PrimeField, T: Group<S>>(a: &mut [T], worker: &Worker, omega: &S, log_n: u32) {
let log_cpus = worker.log_num_threads();

if log_n <= log_cpus {
serial_fft(a, omega, log_n);
} else {
parallel_fft(a, worker, omega, log_n, log_cpus);
let n = a.len();
assert_eq!(n, 1 << log_n);
if log_n == 0 {
return;
}
}
let log_cpus = worker.log_num_threads();

#[allow(clippy::many_single_char_names)]
fn serial_fft<S: PrimeField, T: Group<S>>(a: &mut [T], omega: &S, log_n: u32) {
fn bitreverse(mut n: u32, l: u32) -> u32 {
let mut r = 0;
for _ in 0..l {
r = (r << 1) | (n & 1);
n >>= 1;
let offset = 64 - log_n;
for i in 0..n as u64 {
let ri = i.reverse_bits() >> offset;
if i < ri {
a.swap(ri as usize, i as usize);
}
r
}

let n = a.len() as u32;
assert_eq!(n, 1 << log_n);
// precompute twiddle factors
let twiddles: Vec<_> = (0..(n / 2))
.scan(S::ONE, |w, _| {
let tw = *w;
*w *= omega;
Some(tw)
})
.collect();

for k in 0..n {
let rk = bitreverse(k, log_n);
if k < rk {
a.swap(rk as usize, k as usize);
}
if log_n <= log_cpus {
serial_fft(a, n, log_n, &twiddles);
} else {
parallel_fft(a, n, 1, &twiddles);
}
}

let mut m = 1;
#[allow(clippy::many_single_char_names)]
fn serial_fft<S: PrimeField, T: Group<S>>(a: &mut [T], n: usize, log_n: u32, twiddles: &[S]) {
let mut chunk = 2_usize;
let mut twiddle_chunk = n / 2;
for _ in 0..log_n {
let w_m = omega.pow_vartime(&[u64::from(n / (2 * m))]);

let mut k = 0;
while k < n {
let mut w = S::ONE;
for j in 0..m {
let mut t = a[(k + j + m) as usize];
t.group_mul_assign(&w);
let mut tmp = a[(k + j) as usize];
tmp.group_sub_assign(&t);
a[(k + j + m) as usize] = tmp;
a[(k + j) as usize].group_add_assign(&t);
w.mul_assign(&w_m);
}

k += 2 * m;
}

m *= 2;
a.chunks_mut(chunk).for_each(|coeffs| {
let (left, right) = coeffs.split_at_mut(chunk / 2);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0].group_add_assign(&t);
b[0].group_sub_assign(&t);

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t.group_mul_assign(&twiddles[(i + 1) * twiddle_chunk as usize]);
*b = *a;
a.group_add_assign(&t);
b.group_sub_assign(&t);
});
});
chunk *= 2;
twiddle_chunk /= 2;
}
}

fn parallel_fft<S: PrimeField, T: Group<S>>(
pub fn parallel_fft<S: PrimeField, T: Group<S>>(
a: &mut [T],
worker: &Worker,
omega: &S,
log_n: u32,
log_cpus: u32,
n: usize,
twiddle_chunk: usize,
twiddles: &[S],
) {
assert!(log_n >= log_cpus);

let num_cpus = 1 << log_cpus;
let log_new_n = log_n - log_cpus;
let mut tmp = vec![vec![T::group_zero(); 1 << log_new_n]; num_cpus];
let new_omega = omega.pow_vartime(&[num_cpus as u64]);

worker.scope(0, |scope, _| {
let a = &*a;

for (j, tmp) in tmp.iter_mut().enumerate() {
scope.spawn(move |_scope| {
// Shuffle into a sub-FFT
let omega_j = omega.pow_vartime(&[j as u64]);
let omega_step = omega.pow_vartime(&[(j as u64) << log_new_n]);

let mut elt = S::ONE;
for (i, tmp) in tmp.iter_mut().enumerate() {
for s in 0..num_cpus {
let idx = (i + (s << log_new_n)) % (1 << log_n);
let mut t = a[idx];
t.group_mul_assign(&elt);
tmp.group_add_assign(&t);
elt.mul_assign(&omega_step);
}
elt.mul_assign(&omega_j);
}

// Perform sub-FFT
serial_fft(tmp, &new_omega, log_new_n);
});
}
});

// TODO: does this hurt or help?
worker.scope(a.len(), |scope, chunk| {
let tmp = &tmp;

for (idx, a) in a.chunks_mut(chunk).enumerate() {
scope.spawn(move |_scope| {
let mut idx = idx * chunk;
let mask = (1 << log_cpus) - 1;
for a in a {
*a = tmp[idx & mask][idx >> log_cpus];
idx += 1;
}
if n == 2 {
let t = a[1];
a[1] = a[0];
a[0].group_add_assign(&t);
a[1].group_sub_assign(&t);
} else {
let (left, right) = a.split_at_mut(n / 2);
join(
|| parallel_fft(left, n / 2, twiddle_chunk * 2, twiddles),
|| parallel_fft(right, n / 2, twiddle_chunk * 2, twiddles),
);

// case when twiddle factor is one
let (a, left) = left.split_at_mut(1);
let (b, right) = right.split_at_mut(1);
let t = b[0];
b[0] = a[0];
a[0].group_add_assign(&t);
b[0].group_sub_assign(&t);

left.iter_mut()
.zip(right.iter_mut())
.enumerate()
.for_each(|(i, (a, b))| {
let mut t = *b;
t.group_mul_assign(&twiddles[(i + 1) * twiddle_chunk]);
*b = *a;
a.group_add_assign(&t);
b.group_sub_assign(&t);
});
}
});
}
}

// Test multiplying various (low degree) polynomials together and
Expand Down Expand Up @@ -467,27 +458,28 @@ fn fft_composition() {
fn parallel_fft_consistency() {
use bls12_381::Scalar as Fr;
use rand_core::RngCore;
use std::cmp::min;

fn test_consistency<S: PrimeField, R: RngCore>(mut rng: &mut R) {
let worker = Worker::new();

for _ in 0..5 {
for log_d in 0..10 {
let d = 1 << log_d;
for log_n in 1..10 {
let n = 1 << log_n;

let v1 = (0..d)
let v1 = (0..n)
.map(|_| Scalar::<S>(S::random(&mut rng)))
.collect::<Vec<_>>();
let mut v1 = EvaluationDomain::from_coeffs(v1).unwrap();
let mut v2 = EvaluationDomain::from_coeffs(v1.coeffs.clone()).unwrap();
let twiddles: Vec<_> = (0..(n / 2))
.scan(S::ONE, |w, _| {
let tw = *w;
*w *= v2.omega;
Some(tw)
})
.collect();

for log_cpus in log_d..min(log_d + 1, 3) {
parallel_fft(&mut v1.coeffs, &worker, &v1.omega, log_d, log_cpus);
serial_fft(&mut v2.coeffs, &v2.omega, log_d);

assert!(v1.coeffs == v2.coeffs);
}
serial_fft(&mut v1.coeffs, n, log_n, &twiddles);
parallel_fft(&mut v2.coeffs, n, 1, &twiddles);
assert!(v1.coeffs == v2.coeffs);
}
}
}
Expand Down