Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions benches/cellgrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ pub fn bench_cellgrid(c: &mut Criterion) {
let cutoff_squared = cutoff.powi(2);
b.iter(|| {
cg.particle_pairs()
.filter(|&((_i, p), (_j, q))| {
.filter(|&(&(_i, p), &(_j, q))| {
distance_squared(&(*p).into(), &(*q).into()) <= cutoff_squared
})
.for_each(|_| {});
Expand All @@ -97,7 +97,7 @@ pub fn bench_cellgrid(c: &mut Criterion) {
|b, cg| {
let cutoff_squared = cutoff.powi(2);
b.iter(|| {
cg.par_particle_pairs().for_each(|((_i, p), (_j, q))| {
cg.par_particle_pairs().for_each(|(&(_i, p), &(_j, q))| {
if distance_squared(&(*p).into(), &(*q).into()) <= cutoff_squared {
} else {
}
Expand Down
2 changes: 1 addition & 1 deletion benches/iters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub fn bench_iters(c: &mut Criterion) {
let cutoff_squared = cutoff.powi(2);
b.iter(|| {
pool.install(|| {
cg.par_particle_pairs().for_each(|((_i, p), (_j, q))| {
cg.par_particle_pairs().for_each(|(&(_i, p), &(_j, q))| {
if distance_squared(&(*p).into(), &(*q).into()) <= cutoff_squared {
} else {
}
Expand Down
4 changes: 2 additions & 2 deletions benches/lj.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ pub fn bench_lj(c: &mut Criterion) {
);
let potential_energy: F32or64 = cg
.particle_pairs()
.filter_map(|((_i, p), (_j, q))| {
.filter_map(|(&(_i, p), &(_j, q))| {
let dsq = distance_squared(&(*p).into(), &(*q).into());
if dsq < cutoff_squared {
Some(dsq)
Expand Down Expand Up @@ -109,7 +109,7 @@ pub fn bench_lj(c: &mut Criterion) {
);
let _potential_energy: F32or64 = cg
.particle_pairs()
.filter_map(|((_i, p), (_j, q))| {
.filter_map(|(&(_i, p), &(_j, q))| {
let dsq = distance_squared(&(*p).into(), &(*q).into());
if dsq < cutoff_squared {
Some(dsq)
Expand Down
4 changes: 2 additions & 2 deletions examples/minimal.rs
Original file line number Diff line number Diff line change
Expand Up @@ -53,15 +53,15 @@ fn main() {
#[cfg(not(feature = "rayon"))]
// let count = cg.point_pairs().count();
cg.particle_pairs()
.filter(|&((_i, p), (_j, q))| {
.filter(|&(&(_i, p), &(_j, q))| {
distance_squared(&(*p).into(), &(*q).into()) <= _cutoff_squared
})
.for_each(|_| black_box(()));
// cg.rebuild_mut(pointcloud.iter().rev().map(|p| p.coords), None);

#[cfg(feature = "rayon")]
cg.par_particle_pairs()
.filter(|&((_i, p), (_j, q))| {
.filter(|&(&(_i, p), &(_j, q))| {
distance_squared(&(*p).into(), &(*q).into()) <= _cutoff_squared
})
.for_each(|_| {
Expand Down
6 changes: 3 additions & 3 deletions python/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ impl PyCellGrid {
std::array::from_fn(|i| coordinates[i] - other[i]).map(|diff| diff * diff);
x + y + z <= cutoff_squared
});
Some(out.collect())
Some(out.copied().collect())
})
}

Expand Down Expand Up @@ -286,7 +286,7 @@ impl PyCellGridIter {
// let _owner = (&py_cellgrid)
// .into_py_any(py)
// .expect("could not store owner internally");
let iter = Box::new((&py_cellgrid).inner.particle_pairs());
let iter = Box::new((&py_cellgrid).inner.particle_pairs().map(|(&p, &q)| (p, q)));
// SAFETY: the idea is that `_keep_borrow` makes sure that `iter`s lifetime can be extended
// SAFETY: replicating some ideas from
// SAFETY: https://github.com/PyO3/pyo3/issues/1085 and
Expand Down Expand Up @@ -366,7 +366,7 @@ impl PyCellQueryIter {
coordinates: Borrowed<'_, '_, PyAny>,
) -> Option<Self> {
let coordinates = <[f64; 3] as FromPyObject>::extract(coordinates).ok()?;
let iter = Box::new((&py_cellgrid).inner.query_neighbors(coordinates)?);
let iter = Box::new((&py_cellgrid).inner.query_neighbors(coordinates)?.copied());
// SAFETY: see PyCellGridIter
let iter = unsafe {
std::mem::transmute::<
Expand Down
40 changes: 25 additions & 15 deletions src/cellgrid.rs
Original file line number Diff line number Diff line change
Expand Up @@ -140,8 +140,6 @@ where
+ ConstZero
+ AsPrimitive<i32>
+ SimdPartialOrd
+ Send
+ Sync
+ std::fmt::Debug
+ Default,
{
Expand Down Expand Up @@ -230,7 +228,7 @@ where
// FIXME: would just have to make sure that cell is always unique when operating on chunks
// FIXME: (pretty much the same issue as with counting cell sizes concurrently)
cell_lists.push(
particle,
particle.clone(),
cells
.get_mut(cell)
.expect("cell grid should contain every cell in the grid index"),
Expand Down Expand Up @@ -310,7 +308,7 @@ where
//TODO: see `::rebuild()`
.for_each(|(cell, particle)| {
self.cell_lists.push(
particle,
particle.clone(),
self.cells
.get_mut(cell)
.expect("cell grid should contain every cell in the grid index"),
Expand All @@ -322,8 +320,7 @@ where

impl<P: ParticleLike<[T; N]>, const N: usize, T> CellGrid<P, N, T>
where
T: Float + ConstOne + AsPrimitive<i32> + std::fmt::Debug + NumAssignOps + Send + Sync,
P: Send + Sync,
T: Float + ConstOne + AsPrimitive<i32> + std::fmt::Debug + NumAssignOps,
{
/// Returns an iterator over all relevant (i.e. within cutoff threshold + some extra) unique
/// pairs of particles in this `CellGrid`.
Expand All @@ -336,15 +333,15 @@ where
/// let cell_grid = CellGrid::new(data.iter().copied().enumerate(), 1.0);
/// cell_grid.particle_pairs()
/// // usually, .filter_map() is preferable (so distance computations can be re-used)
/// .filter(|&((_i, p), (_j, q))| {
/// .filter(|&(&(_i, p), &(_j, q))| {
/// distance_squared(&p.into(), &q.into()) <= 1.0
/// })
/// .for_each(|((_i, p), (_j, q))| {
/// /* do some work */
/// });
/// ```
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub fn particle_pairs(&self) -> impl Iterator<Item = (P, P)> + Clone {
pub fn particle_pairs(&self) -> impl Iterator<Item = (&P, &P)> + Clone {
self.iter().flat_map(|cell| cell.particle_pairs())
}

Expand Down Expand Up @@ -390,7 +387,7 @@ where
/// .expect("the queried particle should be within `cutoff` of this grid's shape")
/// // usually, .filter_map() is preferable (so distance computations can be re-used)
/// .filter(|&(_j, q)| {
/// distance_squared(&p.into(), &q.into()) <= 1.0
/// distance_squared(&p.into(), &(*q).into()) <= 1.0
/// })
/// .for_each(|(_j, q)| {
/// /* do some work */
Expand All @@ -400,20 +397,33 @@ where
pub fn query_neighbors<Q: ParticleLike<[T; N]>>(
&self,
particle: Q,
) -> Option<impl Iterator<Item = P> + Clone> {
) -> Option<impl Iterator<Item = &P> + Clone> {
self.query(particle).map(|this| {
this.iter().copied().chain(
this.iter().chain(
this.neighbors::<neighborhood::Full>()
.flat_map(|cell| cell.iter().copied()),
.flat_map(|cell| cell.iter()),
)
})
}

/// Returns a slice of the internal cell storage.
///
/// <div class="warning">
///
/// This is an experimental item.
/// It might get removed in the future or its usage might change.
///
/// </div>
#[doc(hidden)]
pub fn cell_storage(&self) -> &[P] {
&self.cell_lists.buffer
}
}

#[cfg(feature = "rayon")]
impl<P, const N: usize, T> CellGrid<P, N, T>
where
T: Float + NumAssignOps + ConstOne + AsPrimitive<i32> + Send + Sync + std::fmt::Debug,
T: Float + NumAssignOps + ConstOne + AsPrimitive<i32> + Sync + std::fmt::Debug,
P: ParticleLike<[T; N]> + Send + Sync,
{
/// Returns a parallel iterator over all relevant (i.e. within cutoff threshold + some extra)
Expand All @@ -434,13 +444,13 @@ where
/// cell_grid.par_particle_pairs()
// TODO: fact-check the statement below:
/// // Try to avoid filtering this ParallelIterator to avoid significant overhead:
/// .for_each(|((_i, p), (_j, q))| {
/// .for_each(|(&(_i, p), &(_j, q))| {
/// if distance_squared(&p.into(), &q.into()) <= 1.0 {
/// /* do some work */
/// }
/// });
/// ```
pub fn par_particle_pairs(&self) -> impl ParallelIterator<Item = (P, P)> {
pub fn par_particle_pairs(&self) -> impl ParallelIterator<Item = (&P, &P)> {
// TODO: ideally, we would schedule 2 threads for cell.particle_pairs() with the same CPU affinity
// TODO: so they can share their resources
self.par_iter().flat_map_iter(|cell| cell.particle_pairs())
Expand Down
33 changes: 19 additions & 14 deletions src/cellgrid/iters.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ pub mod neighborhood {
}

/// `GridCell` represents a possibly empty (by construction) cell of a [`CellGrid`].
#[derive(Debug, Clone, Copy)]
#[derive(Debug, Clone)]
pub struct GridCell<'g, P, const N: usize = 3, F: Float = f64>
where
F: NumAssignOps + ConstOne + AsPrimitive<i32> + std::fmt::Debug,
Expand All @@ -103,10 +103,17 @@ where
pub(crate) index: i32,
}

impl<'g, P, const N: usize, F> Copy for GridCell<'g, P, N, F>
where
F: Float + NumAssignOps + ConstOne + AsPrimitive<i32> + std::fmt::Debug,
P: ParticleLike<[F; N]>,
{
}

impl<'g, P, const N: usize, F> GridCell<'g, P, N, F>
where
F: Float + NumAssignOps + ConstOne + AsPrimitive<i32> + Send + Sync + std::fmt::Debug,
P: ParticleLike<[F; N]> + Send + Sync,
F: Float + NumAssignOps + ConstOne + AsPrimitive<i32> + std::fmt::Debug,
P: ParticleLike<[F; N]>,
{
/// Returns the (flat) cell index of this (possibly empty) `GridCell`.
pub(crate) fn index(&self) -> i32 {
Expand Down Expand Up @@ -182,23 +189,20 @@ where

/// Returns an iterator over all unique pairs of points in this `GridCell`.
#[inline]
fn intra_cell_pairs(self) -> impl FusedIterator<Item = (P, P)> + Clone {
fn intra_cell_pairs(self) -> impl FusedIterator<Item = (&'g P, &'g P)> + Clone {
// this is equivalent to
// self.iter().copied().tuple_combinations::<(P, P)>()
// but faster for our specific case (pairs from slice of `Copy` values)
self.iter()
.copied()
.enumerate()
.flat_map(move |(n, i)| self.iter().copied().skip(n + 1).map(move |j| (i, j)))
.flat_map(move |(n, i)| self.iter().skip(n + 1).map(move |j| (i, j)))
}

/// Returns an iterator over all unique pairs of points in this `GridCell` with points of the neighboring cells.
#[inline]
fn inter_cell_pairs(self) -> impl FusedIterator<Item = (P, P)> + Clone {
self.iter().copied().cartesian_product(
self.neighbors::<Half>()
.flat_map(|cell| cell.iter().copied()),
)
fn inter_cell_pairs(self) -> impl FusedIterator<Item = (&'g P, &'g P)> + Clone {
self.iter()
.cartesian_product(self.neighbors::<Half>().flat_map(|cell| cell.iter()))
}

/// Returns an iterator over all _relevant_ pairs of particles within in the neighborhood of this `GridCell`.
Expand All @@ -208,14 +212,14 @@ where
/// This method consumes `self` but `GridCell` implements [`Copy`].
//TODO: handle full-space as well
//TODO: document that we're relying on GridCell impl'ing Copy here (so we can safely consume `self`)
pub fn particle_pairs(self) -> impl FusedIterator<Item = (P, P)> + Clone + Send + Sync {
pub fn particle_pairs(self) -> impl FusedIterator<Item = (&'g P, &'g P)> + Clone {
self.intra_cell_pairs().chain(self.inter_cell_pairs())
}
}

impl<P, const N: usize, F> CellGrid<P, N, F>
where
F: Float + NumAssignOps + ConstOne + AsPrimitive<i32> + Send + Sync + std::fmt::Debug,
F: Float + NumAssignOps + ConstOne + AsPrimitive<i32> + std::fmt::Debug,
P: ParticleLike<[F; N]>,
{
/// Returns an iterator over all [`GridCell`]s in this `CellGrid`, excluding empty cells.
Expand Down Expand Up @@ -253,7 +257,8 @@ where
#[cfg(feature = "rayon")]
pub fn par_iter(&self) -> impl ParallelIterator<Item = GridCell<'_, P, N, F>>
where
P: Send + Sync,
P: Sync,
F: Sync,
{
self.cells
.par_keys()
Expand Down
2 changes: 1 addition & 1 deletion src/cellgrid/storage.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,7 +46,7 @@ use serde::{Deserialize, Serialize};
#[derive(Debug, Default, Clone)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub(crate) struct CellStorage<T> {
buffer: Vec<T>,
pub(crate) buffer: Vec<T>,
}

impl<T> CellStorage<T> {
Expand Down
14 changes: 11 additions & 3 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -80,9 +80,17 @@ pub use crate::cellgrid::CellGrid;
///
/// This trait is required for types used with [`CellGrid`]
/// which needs to know how to get coordinate data.\
/// Only [`Copy`] types can be used.
///
/// <div class="warning">
///
/// `ParticleLike` is a subtrait of [`Clone`].
/// This allows to use the [_interior mutability_](https://doc.rust-lang.org/stable/std/cell/index.html#when-to-choose-interior-mutability) pattern.
///
/// Usually, [`Copy`] types are preferable (they tend to implement `Clone` by copying).
/// In general, the smaller the type, the better (for the CPU cache).
///
/// </div>
///
/// Note that [`CellGrid`] is more specific than this trait and requires implementing `ParticleLike<[{float}; N]>`.
///
/// We do not provide a blanket implementation for types implementing `Into<[T; N]> + Copy` but
Expand Down Expand Up @@ -121,7 +129,7 @@ pub use crate::cellgrid::CellGrid;
/// }
/// }
/// ```
pub trait ParticleLike<T = [f64; 3]>: Copy {
pub trait ParticleLike<T = [f64; 3]>: Clone {
/// Returns a copy of this particle's coordinates.
fn coords(&self) -> T;
}
Expand Down Expand Up @@ -216,7 +224,7 @@ impl<P> From<P> for Particle<P> {
/// ```
impl<L, P, T, const N: usize> ParticleLike<[T; N]> for (L, P)
where
L: Copy,
L: Clone,
P: ParticleLike<[T; N]>,
{
#[inline]
Expand Down
Loading