diff --git a/Cargo.lock b/Cargo.lock index e7c24bd1..a130007c 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -170,6 +170,7 @@ checksum = "379026ff283facf611b0ea629334361c4211d1b12ee01024eec1591133b04120" dependencies = [ "anstyle", "clap_lex", + "strsim", ] [[package]] @@ -178,6 +179,17 @@ version = "0.7.4" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" +[[package]] +name = "codespan-reporting" +version = "0.12.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "fe6d2e5af09e8c8ad56c969f2157a3d4238cebc7c55f0a517728c38f7b200f81" +dependencies = [ + "serde", + "termcolor", + "unicode-width", +] + [[package]] name = "criterion" version = "0.5.1" @@ -245,11 +257,75 @@ version = "0.2.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "43da5946c66ffcc7745f48db692ffbb10a83bfe0afd96235c5c2a4fb23994929" +[[package]] +name = "cxx" +version = "1.0.177" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "83debbf9cba437faafaf79670021b1b39052a11ce7d5940a1b6befe8a12ba6e9" +dependencies = [ + "cc", + "cxxbridge-cmd", + "cxxbridge-flags", + "cxxbridge-macro", + "foldhash", + "link-cplusplus", +] + +[[package]] +name = "cxx-build" +version = "1.0.177" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "12353e0a5cd7ecf2d8edd0613a7ef1c7d87a7c60e72b336fac160e81bed78e9c" +dependencies = [ + "cc", + "codespan-reporting", + "indexmap", + "proc-macro2", + "quote", + "scratch", + "syn", +] + +[[package]] +name = "cxxbridge-cmd" +version = "1.0.177" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6e3b69decbbc971cb79175299aebdc39d2db33dcb13624134a5abb6a4b08e411" +dependencies = [ + "clap", + "codespan-reporting", + "indexmap", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "cxxbridge-flags" +version = "1.0.177" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69af04f14e7748460723ed28f0c4eed95a123e5e9ad2b46624ef03e1cfd280d6" + +[[package]] +name = "cxxbridge-macro" +version = "1.0.177" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "831941e2914c1ce51336c45e489ca514c0fc3ad175287619bdd9fe7e8a63c891" +dependencies = [ + "indexmap", + "proc-macro2", + "quote", + "rustversion", + "syn", +] + [[package]] name = "differt-core" version = "0.8.0" dependencies = [ "criterion", + "cxx", + "cxx-build", "indexmap", "log", "nalgebra", @@ -258,6 +334,7 @@ dependencies = [ "obj-rs", "ply-rs", "pyo3", + "pyo3-build-config", "pyo3-log", "quick-xml", "rstest", @@ -302,6 +379,12 @@ version = "2.3.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "37909eebbb50d72f9059c3b6d82c0463f2ff062c9e95845c43a6c9c0355411be" +[[package]] +name = "foldhash" +version = "0.2.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "77ce24cb58228fbb8aa041425bb1050850ac19177686ea6e0f41a70416f56fdb" + [[package]] name = "futures" version = "0.3.31" @@ -502,6 +585,15 @@ version = "0.2.172" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "d750af042f7ef4f724306de029d18836c26c1765a54a6a3f094cbd23a7267ffa" +[[package]] +name = "link-cplusplus" +version = "1.0.12" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7f78c730aaa7d0b9336a299029ea49f9ee53b0ed06e9202e8cb7db9bae7b8c82" +dependencies = [ + "cc", +] + [[package]] name = "linked-hash-map" version = "0.5.6" @@ -1040,6 +1132,12 @@ dependencies = [ "winapi-util", ] +[[package]] +name = "scratch" +version = "1.0.9" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d68f2ec51b097e4c1a75b681a8bec621909b5e91f15bb7b840c4f2f7b01148b2" + [[package]] name = "semver" version = "1.0.26" @@ -1124,6 +1222,12 @@ dependencies = [ "autocfg", ] +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + [[package]] name = "syn" version = "2.0.101" @@ -1154,6 +1258,15 @@ dependencies = [ "windows-sys", ] +[[package]] +name = "termcolor" +version = "1.4.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06794f8f6c5c898b3275aebefa6b8a1cb24cd2c6c79397ab15774837a0bc5755" +dependencies = [ + "winapi-util", +] + [[package]] name = "testing_logger" version = "0.1.1" @@ -1191,6 +1304,12 @@ version = "1.0.18" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "5a5f39404a5da50712a4c1eecf25e90dd62b613502b7e925fd4e4d19b5c96512" +[[package]] +name = "unicode-width" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b4ac048d71ede7ee76d585517add45da530660ef4390e49b098733c6e897f254" + [[package]] name = "unindent" version = "0.2.4" diff --git a/differt-core/Cargo.toml b/differt-core/Cargo.toml index 580ce1c4..f4875c02 100644 --- a/differt-core/Cargo.toml +++ b/differt-core/Cargo.toml @@ -2,7 +2,12 @@ harness = false name = "bench_main" +[build-dependencies] +cxx-build = {version = ">=1.0,<1.0.178"} +pyo3-build-config = {version = "0.25"} + [dependencies] +cxx = {version = ">=1.0,<1.0.178"} indexmap = {version = "2.5.0", features = ["serde"]} log = "0.4" nalgebra = "0.32.3" diff --git a/differt-core/build.rs b/differt-core/build.rs new file mode 100644 index 00000000..8437c8ad --- /dev/null +++ b/differt-core/build.rs @@ -0,0 +1,60 @@ +/// Build script for differt-core. +/// +/// 1. Queries JAX for XLA FFI header locations +/// 2. Compiles the C++ FFI shim via cxx-build +use std::process::exit; + +fn main() { + // Only build FFI when the feature is enabled + use std::{env, path::PathBuf, str::from_utf8}; + + // Find the Python interpreter: + // 1. PYTHON env var (covers VIRTUAL_ENV and explicit overrides) + // 2. pyo3_build_config: the interpreter pyo3 itself was built against + // 3. Fall back to "python3" + let python = env::var("PYTHON") + .ok() + .or_else(|| pyo3_build_config::get().executable.clone()) + .unwrap_or_else(|| "python3".to_owned()); + + // Query JAX for its XLA FFI include directory + let output = std::process::Command::new(&python) + .args([ + "-c", + "from jax.ffi import include_dir; print(include_dir())", + ]) + .output(); + + let include_path = match output { + Ok(ref out) if out.status.success() => { + let path = from_utf8(&out.stdout) + .expect("Invalid UTF-8 from JAX include_dir()") + .trim() + .to_string(); + if path.is_empty() { + None + } else { + Some(PathBuf::from(path)) + } + }, + _ => None, + }; + + if let Some(include_path) = include_path { + println!("cargo:rerun-if-changed=src/ffi.cc"); + println!("cargo:rerun-if-changed=include/ffi.h"); + + cxx_build::bridge("src/accel/ffi.rs") + .file("src/ffi.cc") + .std("c++17") + .include(&include_path) + .include("include") + .compile("differt-ffi"); + } else { + println!( + "cargo:error=JAX not found or missing jax.ffi.include_dir(). Python interpreter used: \ + {python}" + ); + exit(1); + } +} diff --git a/differt-core/include/ffi.h b/differt-core/include/ffi.h new file mode 100644 index 00000000..8ac67c09 --- /dev/null +++ b/differt-core/include/ffi.h @@ -0,0 +1,17 @@ +#pragma once + +#include "xla/ffi/api/ffi.h" + +// BVH nearest-hit: for each ray, find the closest active triangle. +// Inputs: ray_origins [num_rays, 3], ray_directions [num_rays, 3], +// active_mask [num_triangles] (PRED, or [0] for no mask) +// Attrs: bvh_id (u64) +// Outputs: hit_indices [num_rays] (i32), hit_t [num_rays] (f32) +extern "C" XLA_FFI_Error *BvhNearestHit(XLA_FFI_CallFrame *call_frame); + +// BVH get-candidates: for each ray, find candidate triangles with expanded AABBs. +// Inputs: ray_origins [num_rays, 3], ray_directions [num_rays, 3] +// Attrs: bvh_id (u64), expansion (f32), max_candidates (i32) +// Outputs: candidate_indices [num_rays, max_candidates] (i32), +// candidate_counts [num_rays] (i32) +extern "C" XLA_FFI_Error *BvhGetCandidates(XLA_FFI_CallFrame *call_frame); diff --git a/differt-core/pyproject.toml b/differt-core/pyproject.toml index 2cb68cc9..6a1c88ca 100644 --- a/differt-core/pyproject.toml +++ b/differt-core/pyproject.toml @@ -1,6 +1,6 @@ [build-system] build-backend = "maturin" -requires = ["maturin>=1.9.6,<2"] +requires = ["maturin>=1.9.6,<2", "jax>=0.8.1"] [project] authors = [ @@ -31,6 +31,8 @@ bindings = "pyo3" features = ["pyo3/extension-module"] include = [ {path = "src/**/*", format = "sdist"}, + {path = "include/**/*", format = "sdist"}, + {path = "build.rs", format = "sdist"}, {path = "LICENSE.md", format = "sdist"}, {path = "README.md", format = "sdist"}, ] diff --git a/differt-core/python/differt_core/_differt_core/__init__.pyi b/differt-core/python/differt_core/_differt_core/__init__.pyi index 41cf5fc2..e9c5a7cc 100644 --- a/differt-core/python/differt_core/_differt_core/__init__.pyi +++ b/differt-core/python/differt_core/_differt_core/__init__.pyi @@ -3,6 +3,7 @@ from types import ModuleType __version__: str __version_info__: tuple[int, int, int] +accel: ModuleType geometry: ModuleType rt: ModuleType scene: ModuleType diff --git a/differt-core/python/differt_core/_differt_core/accel/__init__.pyi b/differt-core/python/differt_core/_differt_core/accel/__init__.pyi new file mode 100644 index 00000000..e69de29b diff --git a/differt-core/python/differt_core/_differt_core/accel/bvh.pyi b/differt-core/python/differt_core/_differt_core/accel/bvh.pyi new file mode 100644 index 00000000..8fedc016 --- /dev/null +++ b/differt-core/python/differt_core/_differt_core/accel/bvh.pyi @@ -0,0 +1,32 @@ +import numpy as np +from jaxtyping import Float, Int + +class TriangleBvh: + def __init__( + self, triangle_vertices: Float[np.ndarray, "num_triangles 9"] + ) -> None: ... + @property + def num_triangles(self) -> int: ... + @property + def num_nodes(self) -> int: ... + def register(self) -> int: ... + def unregister(self) -> None: ... + def nearest_hit( + self, + ray_origins: Float[np.ndarray, "num_rays 3"], + ray_directions: Float[np.ndarray, "num_rays 3"], + active_mask: np.ndarray | None = None, + ) -> tuple[Int[np.ndarray, " num_rays"], Float[np.ndarray, " num_rays"]]: ... + def get_candidates( + self, + ray_origins: Float[np.ndarray, "num_rays 3"], + ray_directions: Float[np.ndarray, "num_rays 3"], + expansion: float = 0.0, + max_candidates: int = 256, + ) -> tuple[ + Int[np.ndarray, "num_rays max_candidates"], + Int[np.ndarray, " num_rays"], + ]: ... + +def bvh_nearest_hit_capsule() -> object: ... +def bvh_get_candidates_capsule() -> object: ... diff --git a/differt-core/python/differt_core/accel/__init__.py b/differt-core/python/differt_core/accel/__init__.py new file mode 100644 index 00000000..101b70bb --- /dev/null +++ b/differt-core/python/differt_core/accel/__init__.py @@ -0,0 +1 @@ +"""Acceleration structures for ray tracing.""" diff --git a/differt-core/python/differt_core/accel/bvh.py b/differt-core/python/differt_core/accel/bvh.py new file mode 100644 index 00000000..104aa3e0 --- /dev/null +++ b/differt-core/python/differt_core/accel/bvh.py @@ -0,0 +1,7 @@ +__all__ = ("TriangleBvh",) + +from differt_core import _differt_core + +TriangleBvh = _differt_core.accel.bvh.TriangleBvh +bvh_nearest_hit_capsule = _differt_core.accel.bvh.bvh_nearest_hit_capsule +bvh_get_candidates_capsule = _differt_core.accel.bvh.bvh_get_candidates_capsule diff --git a/differt-core/src/accel/bvh.rs b/differt-core/src/accel/bvh.rs new file mode 100644 index 00000000..157ec581 --- /dev/null +++ b/differt-core/src/accel/bvh.rs @@ -0,0 +1,1349 @@ +//! Bounding Volume Hierarchy (BVH) acceleration structure for triangle meshes. +//! +//! Provides surface-area heuristic (SAH)-based BVH construction and +//! two query types: +//! - Nearest-hit: find the closest triangle intersected by each ray +//! (e.g., for shooting and bouncing rays) +//! - Candidate selection: find all triangles whose expanded bounding boxes +//! intersect each ray (for differentiable mode) + +use std::{ + collections::HashMap, + sync::{ + Mutex, + atomic::{AtomicU64, Ordering}, + }, +}; + +use numpy::{PyArray1, PyArray2, PyReadonlyArray1, PyReadonlyArray2, PyUntypedArrayMethods}; +use pyo3::prelude::*; + +// --------------------------------------------------------------------------- +// Geometry primitives +// --------------------------------------------------------------------------- + +#[derive(Clone, Copy, Debug)] +pub(crate) struct Vec3 { + x: f32, + y: f32, + z: f32, +} + +impl Vec3 { + fn new(x: f32, y: f32, z: f32) -> Self { + Self { x, y, z } + } + + pub(crate) fn from_slice(s: &[f32]) -> Self { + Self { + x: s[0], + y: s[1], + z: s[2], + } + } + + fn sub(self, other: Self) -> Self { + Self::new(self.x - other.x, self.y - other.y, self.z - other.z) + } + + fn cross(self, other: Self) -> Self { + Self::new( + self.y * other.z - self.z * other.y, + self.z * other.x - self.x * other.z, + self.x * other.y - self.y * other.x, + ) + } + + fn dot(self, other: Self) -> f32 { + self.x * other.x + self.y * other.y + self.z * other.z + } + + fn min_comp(self, other: Self) -> Self { + Self::new( + self.x.min(other.x), + self.y.min(other.y), + self.z.min(other.z), + ) + } + + fn max_comp(self, other: Self) -> Self { + Self::new( + self.x.max(other.x), + self.y.max(other.y), + self.z.max(other.z), + ) + } +} + +#[derive(Clone, Copy, Debug)] +struct Aabb { + min: Vec3, + max: Vec3, +} + +impl Aabb { + fn empty() -> Self { + Self { + min: Vec3::new(f32::INFINITY, f32::INFINITY, f32::INFINITY), + max: Vec3::new(f32::NEG_INFINITY, f32::NEG_INFINITY, f32::NEG_INFINITY), + } + } + + fn grow_point(&mut self, p: Vec3) { + self.min = self.min.min_comp(p); + self.max = self.max.max_comp(p); + } + + fn grow_aabb(&mut self, other: &Aabb) { + self.min = self.min.min_comp(other.min); + self.max = self.max.max_comp(other.max); + } + + fn expand(&self, amount: f32) -> Aabb { + Aabb { + min: Vec3::new( + self.min.x - amount, + self.min.y - amount, + self.min.z - amount, + ), + max: Vec3::new( + self.max.x + amount, + self.max.y + amount, + self.max.z + amount, + ), + } + } + + fn surface_area(&self) -> f32 { + let d = self.max.sub(self.min); + 2.0 * (d.x * d.y + d.y * d.z + d.z * d.x) + } + + fn centroid(&self) -> Vec3 { + Vec3::new( + 0.5 * (self.min.x + self.max.x), + 0.5 * (self.min.y + self.max.y), + 0.5 * (self.min.z + self.max.z), + ) + } + + /// Ray-AABB intersection test (slab method). + /// Returns true if the ray intersects the box at any t >= 0. + fn intersects_ray(&self, origin: Vec3, inv_dir: Vec3) -> bool { + let t1x = (self.min.x - origin.x) * inv_dir.x; + let t2x = (self.max.x - origin.x) * inv_dir.x; + let t1y = (self.min.y - origin.y) * inv_dir.y; + let t2y = (self.max.y - origin.y) * inv_dir.y; + let t1z = (self.min.z - origin.z) * inv_dir.z; + let t2z = (self.max.z - origin.z) * inv_dir.z; + + let tmin = t1x.min(t2x).max(t1y.min(t2y)).max(t1z.min(t2z)); + let tmax = t1x.max(t2x).min(t1y.max(t2y)).min(t1z.max(t2z)); + + tmax >= tmin.max(0.0) + } +} + +fn axis_component(v: Vec3, axis: usize) -> f32 { + match axis { + 0 => v.x, + 1 => v.y, + _ => v.z, + } +} + +// --------------------------------------------------------------------------- +// Moller-Trumbore ray-triangle intersection (hard boolean, for Rust-side queries) +// --------------------------------------------------------------------------- + +const MT_EPSILON: f32 = 1e-8; + +/// Returns (t, hit) where t is parametric distance, hit indicates valid intersection. +fn ray_triangle_intersect( + origin: Vec3, + direction: Vec3, + v0: Vec3, + v1: Vec3, + v2: Vec3, +) -> (f32, bool) { + let edge1 = v1.sub(v0); + let edge2 = v2.sub(v0); + let h = direction.cross(edge2); + let a = edge1.dot(h); + + if a.abs() < MT_EPSILON { + return (f32::INFINITY, false); + } + + let f = 1.0 / a; + let s = origin.sub(v0); + let u = f * s.dot(h); + + if !(0.0..=1.0).contains(&u) { + return (f32::INFINITY, false); + } + + let q = s.cross(edge1); + let v = f * direction.dot(q); + + if v < 0.0 || u + v > 1.0 { + return (f32::INFINITY, false); + } + + let t = f * edge2.dot(q); + + if t > MT_EPSILON { + (t, true) + } else { + (f32::INFINITY, false) + } +} + +// --------------------------------------------------------------------------- +// BVH node +// --------------------------------------------------------------------------- + +#[derive(Clone, Debug)] +struct BvhNode { + bounds: Aabb, + /// For leaves: index of first triangle in the reordered tri_indices array. + /// For internal nodes: index of the left child (right = left + 1). + left_or_first: u32, + /// For leaves: number of triangles. For internal nodes: 0. + count: u32, +} + +impl BvhNode { + fn is_leaf(&self) -> bool { + self.count > 0 + } +} + +// --------------------------------------------------------------------------- +// BVH +// --------------------------------------------------------------------------- + +const NUM_SAH_BINS: usize = 12; +const MAX_LEAF_SIZE: u32 = 4; + +pub(crate) struct Bvh { + nodes: Vec, + tri_indices: Vec, + /// Triangle vertices: [num_triangles, 3 vertices, 3 coords] flattened + tri_verts: Vec<[Vec3; 3]>, + /// Per-triangle bounding boxes (precomputed) + tri_bounds: Vec, + /// Per-triangle centroids (precomputed) + tri_centroids: Vec, + nodes_used: u32, +} + +impl Bvh { + fn new(vertices: &[[f32; 9]]) -> Self { + let n = vertices.len(); + let mut tri_verts = Vec::with_capacity(n); + let mut tri_bounds = Vec::with_capacity(n); + let mut tri_centroids = Vec::with_capacity(n); + let tri_indices: Vec = (0..n as u32).collect(); + + for verts in vertices { + let v0 = Vec3::from_slice(&verts[0..3]); + let v1 = Vec3::from_slice(&verts[3..6]); + let v2 = Vec3::from_slice(&verts[6..9]); + tri_verts.push([v0, v1, v2]); + + let mut bb = Aabb::empty(); + bb.grow_point(v0); + bb.grow_point(v1); + bb.grow_point(v2); + tri_bounds.push(bb); + tri_centroids.push(bb.centroid()); + } + + // Allocate worst-case node count (2*n - 1 for binary tree) + let max_nodes = if n > 0 { 2 * n - 1 } else { 1 }; + let mut bvh = Bvh { + nodes: vec![ + BvhNode { + bounds: Aabb::empty(), + left_or_first: 0, + count: 0, + }; + max_nodes + ], + tri_indices, + tri_verts, + tri_bounds, + tri_centroids, + nodes_used: 1, + }; + + // Initialize root + bvh.nodes[0].left_or_first = 0; + bvh.nodes[0].count = n as u32; + bvh.update_node_bounds(0); + bvh.subdivide(0); + + bvh + } + + fn update_node_bounds(&mut self, node_idx: usize) { + let node = &self.nodes[node_idx]; + let first = node.left_or_first as usize; + let count = node.count as usize; + let mut bounds = Aabb::empty(); + for i in first..first + count { + let ti = self.tri_indices[i] as usize; + bounds.grow_aabb(&self.tri_bounds[ti]); + } + self.nodes[node_idx].bounds = bounds; + } + + fn find_best_split(&self, node_idx: usize) -> (usize, f32, f32) { + let node = &self.nodes[node_idx]; + let first = node.left_or_first as usize; + let count = node.count as usize; + + // Compute centroid bounds for binning + let mut centroid_bounds = Aabb::empty(); + for i in first..first + count { + let ti = self.tri_indices[i] as usize; + centroid_bounds.grow_point(self.tri_centroids[ti]); + } + + let mut best_axis = 0; + let mut best_pos = 0.0f32; + let mut best_cost = f32::INFINITY; + + for axis in 0..3 { + let lo = axis_component(centroid_bounds.min, axis); + let hi = axis_component(centroid_bounds.max, axis); + if (hi - lo).abs() < 1e-10 { + continue; + } + + // Binned SAH + let mut bins = [Aabb::empty(); NUM_SAH_BINS]; + let mut bin_counts = [0u32; NUM_SAH_BINS]; + + let scale = NUM_SAH_BINS as f32 / (hi - lo); + + for i in first..first + count { + let ti = self.tri_indices[i] as usize; + let c = axis_component(self.tri_centroids[ti], axis); + let bin = ((c - lo) * scale).min(NUM_SAH_BINS as f32 - 1.0) as usize; + bin_counts[bin] += 1; + bins[bin].grow_aabb(&self.tri_bounds[ti]); + } + + // Sweep from left + let mut left_area = [0.0f32; NUM_SAH_BINS - 1]; + let mut left_count_arr = [0u32; NUM_SAH_BINS - 1]; + let mut left_box = Aabb::empty(); + let mut left_sum = 0u32; + for i in 0..NUM_SAH_BINS - 1 { + left_box.grow_aabb(&bins[i]); + left_sum += bin_counts[i]; + left_area[i] = left_box.surface_area(); + left_count_arr[i] = left_sum; + } + + // Sweep from right + let mut right_box = Aabb::empty(); + let mut right_sum = 0u32; + for i in (1..NUM_SAH_BINS).rev() { + right_box.grow_aabb(&bins[i]); + right_sum += bin_counts[i]; + let cost = left_count_arr[i - 1] as f32 * left_area[i - 1] + + right_sum as f32 * right_box.surface_area(); + if cost < best_cost { + best_cost = cost; + best_axis = axis; + best_pos = lo + i as f32 / scale; + } + } + } + + (best_axis, best_pos, best_cost) + } + + fn subdivide(&mut self, node_idx: usize) { + let count = self.nodes[node_idx].count; + if count <= MAX_LEAF_SIZE { + return; + } + + let (axis, split_pos, split_cost) = self.find_best_split(node_idx); + + // Compare SAH cost to not-split cost + let no_split_cost = count as f32 * self.nodes[node_idx].bounds.surface_area(); + if split_cost >= no_split_cost { + return; + } + + // Partition triangles + let first = self.nodes[node_idx].left_or_first as usize; + let last = first + count as usize; + let mut i = first; + let mut j = last; + + while i < j { + let ti = self.tri_indices[i] as usize; + if axis_component(self.tri_centroids[ti], axis) < split_pos { + i += 1; + } else { + j -= 1; + self.tri_indices.swap(i, j); + } + } + + let left_count = (i - first) as u32; + if left_count == 0 || left_count == count { + return; // degenerate split + } + + let left_child = self.nodes_used as usize; + self.nodes_used += 2; + + self.nodes[left_child].left_or_first = first as u32; + self.nodes[left_child].count = left_count; + + self.nodes[left_child + 1].left_or_first = i as u32; + self.nodes[left_child + 1].count = count - left_count; + + // Convert current node to internal + self.nodes[node_idx].left_or_first = left_child as u32; + self.nodes[node_idx].count = 0; + + self.update_node_bounds(left_child); + self.update_node_bounds(left_child + 1); + self.subdivide(left_child); + self.subdivide(left_child + 1); + } + + /// Find the nearest active triangle hit by a ray. Returns (triangle_index, t) or (-1, inf). + /// + /// When `active_mask` is provided, only triangles where `active_mask[i]` is true + /// are considered. This correctly finds the nearest *active* hit, skipping any + /// inactive triangles that may be closer. + pub(crate) fn nearest_hit( + &self, + origin: Vec3, + direction: Vec3, + active_mask: Option<&[bool]>, + ) -> (i32, f32) { + if self.tri_verts.is_empty() { + return (-1, f32::INFINITY); + } + let inv_dir = Vec3::new(1.0 / direction.x, 1.0 / direction.y, 1.0 / direction.z); + let mut stack = Vec::with_capacity(64); + stack.push(0usize); + + let mut best_t = f32::INFINITY; + let mut best_idx: i32 = -1; + + while let Some(node_idx) = stack.pop() { + let node = &self.nodes[node_idx]; + + if !node.bounds.intersects_ray(origin, inv_dir) { + continue; + } + + if node.is_leaf() { + let first = node.left_or_first as usize; + let count = node.count as usize; + for i in first..first + count { + let ti = self.tri_indices[i] as usize; + if let Some(mask) = active_mask { + if !mask[ti] { + continue; + } + } + let [v0, v1, v2] = self.tri_verts[ti]; + let (t, hit) = ray_triangle_intersect(origin, direction, v0, v1, v2); + if hit && t < best_t { + best_t = t; + best_idx = ti as i32; + } + } + } else { + let left = node.left_or_first as usize; + stack.push(left); + stack.push(left + 1); + } + } + + (best_idx, best_t) + } + + /// Find all candidate triangles whose expanded bounding box intersects a ray. + pub(crate) fn get_candidates( + &self, + origin: Vec3, + direction: Vec3, + expansion: f32, + max_candidates: usize, + ) -> (Vec, u32) { + if self.tri_verts.is_empty() { + return (Vec::new(), 0); + } + let inv_dir = Vec3::new(1.0 / direction.x, 1.0 / direction.y, 1.0 / direction.z); + let mut stack = Vec::with_capacity(64); + stack.push(0usize); + + let mut candidates = Vec::with_capacity(max_candidates.min(256)); + let mut count = 0u32; + + while let Some(node_idx) = stack.pop() { + let node = &self.nodes[node_idx]; + let expanded = node.bounds.expand(expansion); + + if !expanded.intersects_ray(origin, inv_dir) { + continue; + } + + if node.is_leaf() { + let first = node.left_or_first as usize; + let leaf_count = node.count as usize; + for i in first..first + leaf_count { + let ti = self.tri_indices[i] as i32; + if (count as usize) < max_candidates { + candidates.push(ti); + } + count += 1; + } + } else { + let left = node.left_or_first as usize; + stack.push(left); + stack.push(left + 1); + } + } + + (candidates, count) + } +} + +// --------------------------------------------------------------------------- +// PyO3 wrapper +// --------------------------------------------------------------------------- + +// --------------------------------------------------------------------------- +// Global BVH registry for XLA FFI +// --------------------------------------------------------------------------- + +/// Global registry mapping integer IDs to BVH instances. +/// XLA FFI handlers receive the ID as an attribute and look up the BVH here. +static BVH_REGISTRY: Mutex>>> = Mutex::new(None); +static BVH_NEXT_ID: AtomicU64 = AtomicU64::new(1); + +fn registry_init() -> HashMap> { + HashMap::new() +} + +/// Register a BVH and return its ID. +fn registry_insert(bvh: std::sync::Arc) -> u64 { + let id = BVH_NEXT_ID.fetch_add(1, Ordering::Relaxed); + let mut guard = BVH_REGISTRY.lock().unwrap(); + let map = guard.get_or_insert_with(registry_init); + map.insert(id, bvh); + id +} + +/// Remove a BVH from the registry. +fn registry_remove(id: u64) { + let mut guard = BVH_REGISTRY.lock().unwrap(); + if let Some(map) = guard.as_mut() { + map.remove(&id); + } +} + +/// Look up a BVH by ID. Returns None if not found. +#[allow(dead_code)] +pub(crate) fn registry_get(id: u64) -> Option> { + let guard = BVH_REGISTRY.lock().unwrap(); + guard.as_ref().and_then(|map| map.get(&id).cloned()) +} + +// --------------------------------------------------------------------------- +// PyO3 wrapper +// --------------------------------------------------------------------------- + +/// BVH acceleration structure for triangle meshes. +/// +/// Builds a Bounding Volume Hierarchy using the Surface Area Heuristic (SAH) +/// for fast ray-triangle intersection queries. +/// +/// Args: +/// triangle_vertices: Triangle vertices with shape ``(num_triangles, 3, 3)``. +/// +/// Examples: +/// >>> import numpy as np +/// >>> from differt_core.accel.bvh import TriangleBvh +/// >>> # A single triangle +/// >>> verts = np.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=np.float32) +/// >>> bvh = TriangleBvh(verts) +/// >>> bvh.num_triangles +/// 1 +#[cfg(not(tarpaulin_include))] +#[pyclass] +struct TriangleBvh { + inner: std::sync::Arc, + registry_id: Option, +} + +#[cfg(not(tarpaulin_include))] +impl Drop for TriangleBvh { + fn drop(&mut self) { + if let Some(id) = self.registry_id.take() { + registry_remove(id); + } + } +} + +#[cfg(not(tarpaulin_include))] +#[pymethods] +impl TriangleBvh { + #[new] + fn new(triangle_vertices: PyReadonlyArray2) -> PyResult { + let shape = triangle_vertices.shape(); + // Expect shape (num_triangles * 3, 3) or we reshape from (num_triangles, 3, 3) + // NumPy 3D arrays are passed as 2D with shape (N*3, 3) when using PyReadonlyArray2 + if shape[1] != 3 { + return Err(pyo3::exceptions::PyValueError::new_err( + "triangle_vertices must have shape (num_triangles * 3, 3) or (num_triangles, 3, 3)", + )); + } + + let data = triangle_vertices.as_slice().map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("Array must be contiguous: {e}")) + })?; + + let num_verts_rows = shape[0]; + if num_verts_rows % 3 != 0 { + return Err(pyo3::exceptions::PyValueError::new_err( + "First dimension must be divisible by 3 (num_triangles * 3 vertices)", + )); + } + + let num_triangles = num_verts_rows / 3; + let mut flat_tris: Vec<[f32; 9]> = Vec::with_capacity(num_triangles); + + for i in 0..num_triangles { + let base = i * 9; // 3 vertices * 3 coords + flat_tris.push([ + data[base], + data[base + 1], + data[base + 2], + data[base + 3], + data[base + 4], + data[base + 5], + data[base + 6], + data[base + 7], + data[base + 8], + ]); + } + + Ok(Self { + inner: std::sync::Arc::new(Bvh::new(&flat_tris)), + registry_id: None, + }) + } + + /// Number of triangles in the BVH. + #[getter] + fn num_triangles(&self) -> usize { + self.inner.tri_verts.len() + } + + /// Number of BVH nodes used. + #[getter] + fn num_nodes(&self) -> u32 { + self.inner.nodes_used + } + + /// Register this BVH in the global registry for XLA FFI access. + /// + /// Returns: + /// The integer ID that XLA FFI handlers use to look up this BVH. + /// + /// Examples: + /// >>> import numpy as np + /// >>> from differt_core.accel.bvh import TriangleBvh + /// >>> verts = np.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=np.float32) + /// >>> bvh = TriangleBvh(verts) + /// >>> bvh_id = bvh.register() + /// >>> bvh_id > 0 + /// True + /// >>> bvh.unregister() + fn register(&mut self) -> u64 { + if let Some(id) = self.registry_id { + return id; + } + let id = registry_insert(self.inner.clone()); + self.registry_id = Some(id); + id + } + + /// Remove this BVH from the global registry. + fn unregister(&mut self) { + if let Some(id) = self.registry_id.take() { + registry_remove(id); + } + } + + /// Find the nearest triangle hit by each ray. + /// + /// Args: + /// ray_origins: Ray origins with shape ``(num_rays, 3)``. + /// ray_directions: Ray directions with shape ``(num_rays, 3)``. + /// + /// Returns: + /// A tuple ``(hit_indices, hit_t)`` where ``hit_indices`` has shape + /// ``(num_rays,)`` with the triangle index (``-1`` if no hit) and + /// ``hit_t`` has shape ``(num_rays,)`` with the parametric distance. + /// + /// Examples: + /// >>> import numpy as np + /// >>> from differt_core.accel.bvh import TriangleBvh + /// >>> verts = np.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=np.float32) + /// >>> bvh = TriangleBvh(verts) + /// >>> origins = np.array([[0.1, 0.1, 1.0]], dtype=np.float32) + /// >>> dirs = np.array([[0, 0, -1]], dtype=np.float32) + /// >>> idx, t = bvh.nearest_hit(origins, dirs) + /// >>> int(idx[0]) + /// 0 + /// >>> float(t[0]) + /// 1.0 + /// Find the nearest triangle hit by each ray. + /// + /// Args: + /// ray_origins: Ray origins with shape ``(num_rays, 3)``. + /// ray_directions: Ray directions with shape ``(num_rays, 3)``. + /// active_mask: Optional boolean mask with shape ``(num_triangles,)``. + /// When provided, only triangles where the mask is ``True`` are + /// considered during traversal. + /// + /// Returns: + /// A tuple ``(hit_indices, hit_t)`` where ``hit_indices`` has shape + /// ``(num_rays,)`` with the triangle index (``-1`` if no hit) and + /// ``hit_t`` has shape ``(num_rays,)`` with the parametric distance. + /// + /// Examples: + /// >>> import numpy as np + /// >>> from differt_core.accel.bvh import TriangleBvh + /// >>> verts = np.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=np.float32) + /// >>> bvh = TriangleBvh(verts) + /// >>> origins = np.array([[0.1, 0.1, 1.0]], dtype=np.float32) + /// >>> dirs = np.array([[0, 0, -1]], dtype=np.float32) + /// >>> idx, t = bvh.nearest_hit(origins, dirs) + /// >>> int(idx[0]) + /// 0 + /// >>> float(t[0]) + /// 1.0 + #[pyo3(signature = (ray_origins, ray_directions, active_mask=None))] + fn nearest_hit<'py>( + &self, + py: Python<'py>, + ray_origins: PyReadonlyArray2, + ray_directions: PyReadonlyArray2, + active_mask: Option>, + ) -> PyResult<(Bound<'py, PyArray1>, Bound<'py, PyArray1>)> { + let origins = ray_origins.as_slice().map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("ray_origins must be contiguous: {e}")) + })?; + let dirs = ray_directions.as_slice().map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "ray_directions must be contiguous: {e}" + )) + })?; + + let mask_vec: Option> = match &active_mask { + Some(m) => { + let s: &[bool] = m.as_slice().map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "active_mask must be contiguous: {e}" + )) + })?; + Some(s.to_vec()) + }, + None => None, + }; + let mask_slice: Option<&[bool]> = mask_vec.as_deref(); + + let num_rays = ray_origins.shape()[0]; + let mut hit_indices = vec![-1i32; num_rays]; + let mut hit_t = vec![f32::INFINITY; num_rays]; + + for i in 0..num_rays { + let origin = Vec3::from_slice(&origins[i * 3..(i + 1) * 3]); + let dir = Vec3::from_slice(&dirs[i * 3..(i + 1) * 3]); + let (idx, t) = self.inner.nearest_hit(origin, dir, mask_slice); + hit_indices[i] = idx; + hit_t[i] = t; + } + + Ok(( + PyArray1::from_vec(py, hit_indices), + PyArray1::from_vec(py, hit_t), + )) + } + + /// Find candidate triangles whose expanded bounding boxes intersect each ray. + /// + /// This is used for differentiable mode: the expansion captures all triangles + /// with non-negligible gradient contribution. + /// + /// Args: + /// ray_origins: Ray origins with shape ``(num_rays, 3)``. + /// ray_directions: Ray directions with shape ``(num_rays, 3)``. + /// expansion: Bounding box expansion amount (related to smoothing_factor). + /// max_candidates: Maximum number of candidates per ray. + /// + /// Returns: + /// A tuple ``(candidate_indices, candidate_counts)`` where + /// ``candidate_indices`` has shape ``(num_rays, max_candidates)`` padded + /// with ``-1``, and ``candidate_counts`` has shape ``(num_rays,)``. + /// + /// Examples: + /// >>> import numpy as np + /// >>> from differt_core.accel.bvh import TriangleBvh + /// >>> verts = np.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=np.float32) + /// >>> bvh = TriangleBvh(verts) + /// >>> origins = np.array([[0.1, 0.1, 1.0]], dtype=np.float32) + /// >>> dirs = np.array([[0, 0, -1]], dtype=np.float32) + /// >>> idx, counts = bvh.get_candidates(origins, dirs, 0.0, 256) + /// >>> int(counts[0]) + /// 1 + /// >>> int(idx[0, 0]) + /// 0 + fn get_candidates<'py>( + &self, + py: Python<'py>, + ray_origins: PyReadonlyArray2, + ray_directions: PyReadonlyArray2, + expansion: f32, + max_candidates: usize, + ) -> PyResult<(Bound<'py, PyArray2>, Bound<'py, PyArray1>)> { + let origins = ray_origins.as_slice().map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("ray_origins must be contiguous: {e}")) + })?; + let dirs = ray_directions.as_slice().map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!( + "ray_directions must be contiguous: {e}" + )) + })?; + + let num_rays = ray_origins.shape()[0]; + let mut all_indices = vec![-1i32; num_rays * max_candidates]; + let mut all_counts = vec![0u32; num_rays]; + + for i in 0..num_rays { + let origin = Vec3::from_slice(&origins[i * 3..(i + 1) * 3]); + let dir = Vec3::from_slice(&dirs[i * 3..(i + 1) * 3]); + let (candidates, count) = + self.inner + .get_candidates(origin, dir, expansion, max_candidates); + all_counts[i] = count; + let row_start = i * max_candidates; + for (j, &idx) in candidates.iter().enumerate() { + all_indices[row_start + j] = idx; + } + } + + let indices_array = + numpy::ndarray::Array2::from_shape_vec((num_rays, max_candidates), all_indices) + .map_err(|e| { + pyo3::exceptions::PyValueError::new_err(format!("Shape error: {e}")) + })?; + + Ok(( + PyArray2::from_owned_array(py, indices_array), + PyArray1::from_vec(py, all_counts), + )) + } +} + +#[cfg(not(tarpaulin_include))] +#[pymodule(gil_used = false)] +pub(crate) fn bvh(m: &Bound<'_, PyModule>) -> PyResult<()> { + m.add_class::()?; + m.add_function(pyo3::wrap_pyfunction!( + super::ffi::bvh_nearest_hit_capsule, + m + )?)?; + m.add_function(pyo3::wrap_pyfunction!( + super::ffi::bvh_get_candidates_capsule, + m + )?)?; + Ok(()) +} + +// --------------------------------------------------------------------------- +// Tests +// --------------------------------------------------------------------------- + +#[cfg(test)] +mod tests { + use super::*; + + fn single_triangle() -> Vec<[f32; 9]> { + vec![[0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0]] + } + + fn cube_triangles() -> Vec<[f32; 9]> { + // 12 triangles forming a unit cube [0,1]^3 + let faces: Vec<([f32; 3], [f32; 3], [f32; 3])> = vec![ + // Front face (z=1) + ([0.0, 0.0, 1.0], [1.0, 0.0, 1.0], [1.0, 1.0, 1.0]), + ([0.0, 0.0, 1.0], [1.0, 1.0, 1.0], [0.0, 1.0, 1.0]), + // Back face (z=0) + ([0.0, 0.0, 0.0], [0.0, 1.0, 0.0], [1.0, 1.0, 0.0]), + ([0.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 0.0, 0.0]), + // Top face (y=1) + ([0.0, 1.0, 0.0], [0.0, 1.0, 1.0], [1.0, 1.0, 1.0]), + ([0.0, 1.0, 0.0], [1.0, 1.0, 1.0], [1.0, 1.0, 0.0]), + // Bottom face (y=0) + ([0.0, 0.0, 0.0], [1.0, 0.0, 0.0], [1.0, 0.0, 1.0]), + ([0.0, 0.0, 0.0], [1.0, 0.0, 1.0], [0.0, 0.0, 1.0]), + // Right face (x=1) + ([1.0, 0.0, 0.0], [1.0, 1.0, 0.0], [1.0, 1.0, 1.0]), + ([1.0, 0.0, 0.0], [1.0, 1.0, 1.0], [1.0, 0.0, 1.0]), + // Left face (x=0) + ([0.0, 0.0, 0.0], [0.0, 0.0, 1.0], [0.0, 1.0, 1.0]), + ([0.0, 0.0, 0.0], [0.0, 1.0, 1.0], [0.0, 1.0, 0.0]), + ]; + faces + .into_iter() + .map(|(a, b, c)| [a[0], a[1], a[2], b[0], b[1], b[2], c[0], c[1], c[2]]) + .collect() + } + + #[test] + fn test_bvh_construction_single_triangle() { + let bvh = Bvh::new(&single_triangle()); + assert_eq!(bvh.tri_verts.len(), 1); + assert!(bvh.nodes_used >= 1); + } + + #[test] + fn test_bvh_construction_cube() { + let bvh = Bvh::new(&cube_triangles()); + assert_eq!(bvh.tri_verts.len(), 12); + assert!(bvh.nodes_used >= 1); + } + + #[test] + fn test_bvh_construction_empty() { + let bvh = Bvh::new(&[]); + assert_eq!(bvh.tri_verts.len(), 0); + } + + #[test] + fn test_nearest_hit_single_triangle() { + let bvh = Bvh::new(&single_triangle()); + // Ray pointing down at (0.1, 0.1) + let origin = Vec3::new(0.1, 0.1, 1.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + let (idx, t) = bvh.nearest_hit(origin, dir, None); + assert_eq!(idx, 0); + assert!((t - 1.0).abs() < 1e-5); + } + + #[test] + fn test_nearest_hit_miss() { + let bvh = Bvh::new(&single_triangle()); + // Ray pointing away + let origin = Vec3::new(0.1, 0.1, 1.0); + let dir = Vec3::new(0.0, 0.0, 1.0); + let (idx, _t) = bvh.nearest_hit(origin, dir, None); + assert_eq!(idx, -1); + } + + #[test] + fn test_nearest_hit_cube() { + let bvh = Bvh::new(&cube_triangles()); + // Ray from outside hitting front face + let origin = Vec3::new(0.5, 0.5, 2.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + let (idx, t) = bvh.nearest_hit(origin, dir, None); + assert!(idx >= 0, "Should hit a front-face triangle"); + assert!( + (t - 1.0).abs() < 1e-5, + "Distance to front face should be 1.0" + ); + } + + #[test] + fn test_nearest_hit_picks_closest() { + let bvh = Bvh::new(&cube_triangles()); + // Ray going through both front and back faces -- should hit front (closer) + let origin = Vec3::new(0.5, 0.5, 2.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + let (idx, t) = bvh.nearest_hit(origin, dir, None); + assert!(idx >= 0); + assert!( + (t - 1.0).abs() < 1e-5, + "Should hit front face at t=1, got t={t}" + ); + } + + #[test] + fn test_get_candidates_no_expansion() { + let bvh = Bvh::new(&single_triangle()); + let origin = Vec3::new(0.1, 0.1, 1.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + let (candidates, count) = bvh.get_candidates(origin, dir, 0.0, 256); + assert!(count >= 1, "Should find at least the hit triangle"); + assert_eq!(candidates[0], 0); + } + + #[test] + fn test_get_candidates_with_expansion() { + // Many distant triangles to force BVH splits (need > MAX_LEAF_SIZE=4) + let mut tris = Vec::new(); + // 5 triangles near origin + for i in 0..5 { + let x = i as f32 * 0.5; + tris.push([x, 0.0, 0.0, x + 0.4, 0.0, 0.0, x, 0.4, 0.0]); + } + // 5 triangles far away (x=100) + for i in 0..5 { + let x = 100.0 + i as f32 * 0.5; + tris.push([x, 0.0, 0.0, x + 0.4, 0.0, 0.0, x, 0.4, 0.0]); + } + let bvh = Bvh::new(&tris); + + // Ray aimed at near triangles + let origin = Vec3::new(0.1, 0.1, 1.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + + // No expansion: should not include the far-away group + let (_, count_no_exp) = bvh.get_candidates(origin, dir, 0.0, 256); + assert!( + count_no_exp <= 5, + "Without expansion, should not include far triangles, got {count_no_exp}" + ); + + // Large expansion: should include all + let (_, count_large_exp) = bvh.get_candidates(origin, dir, 200.0, 256); + assert_eq!( + count_large_exp, 10, + "With large expansion, should find all 10" + ); + } + + #[test] + fn test_nearest_hit_matches_brute_force() { + let tris = cube_triangles(); + let bvh = Bvh::new(&tris); + + // Test several rays + let rays = vec![ + (Vec3::new(0.5, 0.5, 2.0), Vec3::new(0.0, 0.0, -1.0)), + (Vec3::new(0.5, 0.5, -1.0), Vec3::new(0.0, 0.0, 1.0)), + (Vec3::new(2.0, 0.5, 0.5), Vec3::new(-1.0, 0.0, 0.0)), + (Vec3::new(0.5, 2.0, 0.5), Vec3::new(0.0, -1.0, 0.0)), + (Vec3::new(5.0, 5.0, 5.0), Vec3::new(0.0, 0.0, 1.0)), // miss + ]; + + for (origin, dir) in &rays { + let (bvh_idx, bvh_t) = bvh.nearest_hit(*origin, *dir, None); + + // Brute force + let mut bf_idx = -1i32; + let mut bf_t = f32::INFINITY; + for (ti, tri) in tris.iter().enumerate() { + let v0 = Vec3::from_slice(&tri[0..3]); + let v1 = Vec3::from_slice(&tri[3..6]); + let v2 = Vec3::from_slice(&tri[6..9]); + let (t, hit) = ray_triangle_intersect(*origin, *dir, v0, v1, v2); + if hit && t < bf_t { + bf_t = t; + bf_idx = ti as i32; + } + } + + // Both should agree on hit/miss + assert_eq!( + bvh_idx >= 0, + bf_idx >= 0, + "Hit/miss mismatch for ray {origin:?} -> {dir:?}: bvh={bvh_idx}, bf={bf_idx}" + ); + if bf_idx >= 0 { + // t values must match (indices may differ for coplanar triangles) + assert!( + (bvh_t - bf_t).abs() < 1e-5, + "t mismatch: bvh={bvh_t}, bf={bf_t}" + ); + } + } + } + + #[test] + fn test_nearest_hit_empty_bvh() { + let bvh = Bvh::new(&[]); + let origin = Vec3::new(0.0, 0.0, 1.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + let (idx, t) = bvh.nearest_hit(origin, dir, None); + assert_eq!(idx, -1); + assert!(t.is_infinite()); + } + + #[test] + fn test_get_candidates_empty_bvh() { + let bvh = Bvh::new(&[]); + let origin = Vec3::new(0.0, 0.0, 1.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + let (candidates, count) = bvh.get_candidates(origin, dir, 1.0, 256); + assert!(candidates.is_empty()); + assert_eq!(count, 0); + } + + #[test] + fn test_nearest_hit_active_mask() { + // Two triangles stacked: tri 0 at z=0, tri 1 at z=-1 + let tris = vec![ + // Triangle 0: at z=0 + [0.0, 0.0, 0.0, 1.0, 0.0, 0.0, 0.0, 1.0, 0.0], + // Triangle 1: at z=-1 (behind tri 0 from above) + [0.0, 0.0, -1.0, 1.0, 0.0, -1.0, 0.0, 1.0, -1.0], + ]; + let bvh = Bvh::new(&tris); + + let origin = Vec3::new(0.1, 0.1, 2.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + + // No mask: hits tri 0 (nearest) + let (idx, t) = bvh.nearest_hit(origin, dir, None); + assert_eq!(idx, 0); + assert!((t - 2.0).abs() < 1e-5); + + // Mask out tri 0: should find tri 1 instead + let mask = vec![false, true]; + let (idx, t) = bvh.nearest_hit(origin, dir, Some(&mask)); + assert_eq!(idx, 1); + assert!((t - 3.0).abs() < 1e-5); + + // Mask out both: no hit + let mask = vec![false, false]; + let (idx, _t) = bvh.nearest_hit(origin, dir, Some(&mask)); + assert_eq!(idx, -1); + } + + // ----------------------------------------------------------------------- + // Registry tests + // ----------------------------------------------------------------------- + + #[test] + fn test_registry_insert_get_remove() { + let bvh = std::sync::Arc::new(Bvh::new(&single_triangle())); + let id = registry_insert(bvh); + assert!(id > 0); + + let retrieved = registry_get(id); + assert!(retrieved.is_some()); + + registry_remove(id); + let gone = registry_get(id); + assert!(gone.is_none()); + } + + #[test] + fn test_registry_get_nonexistent() { + assert!(registry_get(999_999_999).is_none()); + } + + #[test] + fn test_registry_remove_nonexistent() { + // Should not panic + registry_remove(999_999_999); + } + + #[test] + fn test_registry_multiple_entries() { + let bvh1 = std::sync::Arc::new(Bvh::new(&single_triangle())); + let bvh2 = std::sync::Arc::new(Bvh::new(&cube_triangles())); + let id1 = registry_insert(bvh1); + let id2 = registry_insert(bvh2); + assert_ne!(id1, id2); + + assert!(registry_get(id1).is_some()); + assert!(registry_get(id2).is_some()); + + registry_remove(id1); + assert!(registry_get(id1).is_none()); + assert!(registry_get(id2).is_some()); + + registry_remove(id2); + } + + // ----------------------------------------------------------------------- + // Ray-triangle intersection edge cases + // ----------------------------------------------------------------------- + + #[test] + fn test_ray_triangle_parallel() { + let v0 = Vec3::new(0.0, 0.0, 0.0); + let v1 = Vec3::new(1.0, 0.0, 0.0); + let v2 = Vec3::new(0.0, 1.0, 0.0); + // Ray parallel to triangle plane + let (_, hit) = ray_triangle_intersect( + Vec3::new(0.1, 0.1, 0.0), + Vec3::new(1.0, 0.0, 0.0), + v0, + v1, + v2, + ); + assert!(!hit); + } + + #[test] + fn test_ray_triangle_v_negative() { + let v0 = Vec3::new(0.0, 0.0, 0.0); + let v1 = Vec3::new(1.0, 0.0, 0.0); + let v2 = Vec3::new(0.0, 1.0, 0.0); + // Hits plane but v < 0 (outside triangle below edge) + let (_, hit) = ray_triangle_intersect( + Vec3::new(0.5, -0.5, 1.0), + Vec3::new(0.0, 0.0, -1.0), + v0, + v1, + v2, + ); + assert!(!hit); + } + + #[test] + fn test_ray_triangle_uv_sum_exceeds_one() { + let v0 = Vec3::new(0.0, 0.0, 0.0); + let v1 = Vec3::new(1.0, 0.0, 0.0); + let v2 = Vec3::new(0.0, 1.0, 0.0); + // u + v > 1 (near hypotenuse, outside) + let (_, hit) = ray_triangle_intersect( + Vec3::new(0.8, 0.8, 1.0), + Vec3::new(0.0, 0.0, -1.0), + v0, + v1, + v2, + ); + assert!(!hit); + } + + // ----------------------------------------------------------------------- + // BVH get_candidates with max_candidates limit + // ----------------------------------------------------------------------- + + #[test] + fn test_get_candidates_max_limit() { + let bvh = Bvh::new(&cube_triangles()); + let origin = Vec3::new(0.5, 0.5, 2.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + // Limit to 1 candidate + let (candidates, count) = bvh.get_candidates(origin, dir, 10.0, 1); + // count reflects total found, candidates vec is capped + assert!(count >= 1); + assert!(candidates.len() <= 1); + } + + // ----------------------------------------------------------------------- + // Vec3 / Aabb additional coverage + // ----------------------------------------------------------------------- + + #[test] + fn test_vec3_operations() { + let a = Vec3::new(1.0, 2.0, 3.0); + let b = Vec3::new(4.0, 5.0, 6.0); + let diff = a.sub(b); + assert!((diff.x - (-3.0)).abs() < 1e-6); + + let cross = a.cross(b); + // cross(a,b) = (2*6-3*5, 3*4-1*6, 1*5-2*4) = (-3, 6, -3) + assert!((cross.x - (-3.0)).abs() < 1e-6); + assert!((cross.y - 6.0).abs() < 1e-6); + assert!((cross.z - (-3.0)).abs() < 1e-6); + + assert!((a.dot(b) - 32.0).abs() < 1e-6); + } + + #[test] + fn test_aabb_expand_and_intersect() { + let mut bb = Aabb::empty(); + bb.grow_point(Vec3::new(0.0, 0.0, 0.0)); + bb.grow_point(Vec3::new(1.0, 1.0, 1.0)); + + // Ray that misses the original box but hits the expanded one + let origin = Vec3::new(1.5, 0.5, 5.0); + let dir = Vec3::new(0.0, 0.0, -1.0); + let inv_dir = Vec3::new(1.0 / dir.x, 1.0 / dir.y, 1.0 / dir.z); + + assert!(!bb.intersects_ray(origin, inv_dir)); + + let expanded = bb.expand(1.0); + assert!(expanded.intersects_ray(origin, inv_dir)); + } + + #[test] + fn test_axis_component_all_axes() { + let v = Vec3::new(1.0, 2.0, 3.0); + assert!((axis_component(v, 0) - 1.0).abs() < 1e-6); + assert!((axis_component(v, 1) - 2.0).abs() < 1e-6); + assert!((axis_component(v, 2) - 3.0).abs() < 1e-6); + // Default branch (axis >= 3 maps to z) + assert!((axis_component(v, 99) - 3.0).abs() < 1e-6); + } + + #[test] + fn test_bvh_node_is_leaf() { + let leaf = BvhNode { + bounds: Aabb::empty(), + left_or_first: 0, + count: 2, + }; + assert!(leaf.is_leaf()); + + let internal = BvhNode { + bounds: Aabb::empty(), + left_or_first: 1, + count: 0, + }; + assert!(!internal.is_leaf()); + } + + #[test] + fn test_ray_triangle_intersect_basic() { + let v0 = Vec3::new(0.0, 0.0, 0.0); + let v1 = Vec3::new(1.0, 0.0, 0.0); + let v2 = Vec3::new(0.0, 1.0, 0.0); + + // Hit + let (t, hit) = ray_triangle_intersect( + Vec3::new(0.1, 0.1, 1.0), + Vec3::new(0.0, 0.0, -1.0), + v0, + v1, + v2, + ); + assert!(hit); + assert!((t - 1.0).abs() < 1e-5); + + // Miss (outside triangle) + let (_, hit) = ray_triangle_intersect( + Vec3::new(2.0, 2.0, 1.0), + Vec3::new(0.0, 0.0, -1.0), + v0, + v1, + v2, + ); + assert!(!hit); + + // Miss (behind ray) + let (_, hit) = ray_triangle_intersect( + Vec3::new(0.1, 0.1, -1.0), + Vec3::new(0.0, 0.0, -1.0), + v0, + v1, + v2, + ); + assert!(!hit); + } +} diff --git a/differt-core/src/accel/ffi.rs b/differt-core/src/accel/ffi.rs new file mode 100644 index 00000000..e52bd902 --- /dev/null +++ b/differt-core/src/accel/ffi.rs @@ -0,0 +1,153 @@ +//! XLA FFI bridge for BVH acceleration. +//! +//! This module provides the cxx bridge between Rust BVH queries and +//! C++ XLA FFI handlers, enabling BVH queries inside JIT-compiled JAX functions. + +use pyo3::prelude::*; + +use super::bvh::{Vec3, registry_get}; + +#[cxx::bridge] +mod ffi_bridge { + extern "Rust" { + fn bvh_nearest_hit_ffi( + bvh_id: u64, + origins: &[f32], + directions: &[f32], + active_mask: &[u8], + hit_indices: &mut [i32], + hit_t: &mut [f32], + ); + + fn bvh_get_candidates_ffi( + bvh_id: u64, + expansion: f32, + max_candidates: i32, + origins: &[f32], + directions: &[f32], + candidate_indices: &mut [i32], + candidate_counts: &mut [i32], + ); + } + + unsafe extern "C++" { + include!("ffi.h"); + + type XLA_FFI_Error; + type XLA_FFI_CallFrame; + + unsafe fn BvhNearestHit(call_frame: *mut XLA_FFI_CallFrame) -> *mut XLA_FFI_Error; + unsafe fn BvhGetCandidates(call_frame: *mut XLA_FFI_CallFrame) -> *mut XLA_FFI_Error; + } +} + +/// FFI entry point for nearest-hit queries, called from C++ XLA handler. +/// +/// `active_mask` is a byte slice of length `num_triangles` (0 = inactive, nonzero = active). +/// An empty slice means no mask (all triangles active). +fn bvh_nearest_hit_ffi( + bvh_id: u64, + origins: &[f32], + directions: &[f32], + active_mask: &[u8], + hit_indices: &mut [i32], + hit_t: &mut [f32], +) { + let bvh = match registry_get(bvh_id) { + Some(b) => b, + None => { + hit_indices.fill(-1); + hit_t.fill(f32::INFINITY); + return; + }, + }; + + // Convert u8 mask to bool slice (empty = no mask) + let mask_bools: Vec; + let mask_opt = if active_mask.is_empty() { + None + } else { + mask_bools = active_mask.iter().map(|&b| b != 0).collect(); + Some(mask_bools.as_slice()) + }; + + let num_rays = hit_indices.len(); + for i in 0..num_rays { + let origin = Vec3::from_slice(&origins[i * 3..(i + 1) * 3]); + let dir = Vec3::from_slice(&directions[i * 3..(i + 1) * 3]); + let (idx, t) = bvh.nearest_hit(origin, dir, mask_opt); + hit_indices[i] = idx; + hit_t[i] = t; + } +} + +/// FFI entry point for candidate queries, called from C++ XLA handler. +fn bvh_get_candidates_ffi( + bvh_id: u64, + expansion: f32, + max_candidates: i32, + origins: &[f32], + directions: &[f32], + candidate_indices: &mut [i32], + candidate_counts: &mut [i32], +) { + let max_cand = max_candidates as usize; + let bvh = match registry_get(bvh_id) { + Some(b) => b, + None => { + candidate_indices.fill(-1); + candidate_counts.fill(0); + return; + }, + }; + + let num_rays = candidate_counts.len(); + for i in 0..num_rays { + let origin = Vec3::from_slice(&origins[i * 3..(i + 1) * 3]); + let dir = Vec3::from_slice(&directions[i * 3..(i + 1) * 3]); + let (candidates, count) = bvh.get_candidates(origin, dir, expansion, max_cand); + candidate_counts[i] = count as i32; + let row_offset = i * max_cand; + for j in 0..max_cand { + candidate_indices[row_offset + j] = if j < candidates.len() { + candidates[j] as i32 + } else { + -1 + }; + } + } +} + +// --------------------------------------------------------------------------- +// PyCapsule exports +// --------------------------------------------------------------------------- + +#[pyfunction] +pub(crate) fn bvh_nearest_hit_capsule(py: Python<'_>) -> PyResult { + use std::ffi::c_void; + let fn_ptr: *mut c_void = ffi_bridge::BvhNearestHit as *mut c_void; + unsafe { + let capsule = pyo3::ffi::PyCapsule_New(fn_ptr, std::ptr::null(), None); + if capsule.is_null() { + return Err(pyo3::exceptions::PyRuntimeError::new_err( + "Failed to create PyCapsule for BvhNearestHit", + )); + } + Ok(PyObject::from_owned_ptr(py, capsule)) + } +} + +#[pyfunction] +pub(crate) fn bvh_get_candidates_capsule(py: Python<'_>) -> PyResult { + use std::ffi::c_void; + let fn_ptr: *mut c_void = ffi_bridge::BvhGetCandidates as *mut c_void; + unsafe { + let capsule = pyo3::ffi::PyCapsule_New(fn_ptr, std::ptr::null(), None); + if capsule.is_null() { + return Err(pyo3::exceptions::PyRuntimeError::new_err( + "Failed to create PyCapsule for BvhGetCandidates", + )); + } + Ok(PyObject::from_owned_ptr(py, capsule)) + } +} diff --git a/differt-core/src/accel/mod.rs b/differt-core/src/accel/mod.rs new file mode 100644 index 00000000..d4e646a3 --- /dev/null +++ b/differt-core/src/accel/mod.rs @@ -0,0 +1,11 @@ +use pyo3::{prelude::*, wrap_pymodule}; + +pub mod bvh; +pub mod ffi; + +#[cfg(not(tarpaulin_include))] +#[pymodule(gil_used = false)] +pub(crate) fn accel(m: Bound<'_, PyModule>) -> PyResult<()> { + m.add_wrapped(wrap_pymodule!(bvh::bvh))?; + Ok(()) +} diff --git a/differt-core/src/ffi.cc b/differt-core/src/ffi.cc new file mode 100644 index 00000000..f45a941a --- /dev/null +++ b/differt-core/src/ffi.cc @@ -0,0 +1,107 @@ +/// XLA FFI handlers for BVH acceleration. +/// +/// These thin C++ wrappers decode XLA buffers and call into Rust via cxx. +/// The actual computation happens in Rust (src/accel/bvh.rs). + +#include "differt-core/src/accel/ffi.rs.h" // cxx-generated bridge header +#include "ffi.h" + +namespace ffi = xla::ffi; + +// --- BvhNearestHit --- + +ffi::Error BvhNearestHitImpl(uint64_t bvh_id, + ffi::Buffer origins, + ffi::Buffer directions, + ffi::Buffer active_mask, + ffi::ResultBuffer hit_indices, + ffi::ResultBuffer hit_t) { + auto origins_dims = origins.dimensions(); + auto dirs_dims = directions.dimensions(); + + if (origins_dims.size() != 2 || origins_dims[1] != 3) { + return ffi::Error::InvalidArgument( + "BvhNearestHit: ray_origins must have shape [num_rays, 3]"); + } + if (dirs_dims.size() != 2 || dirs_dims[1] != 3) { + return ffi::Error::InvalidArgument( + "BvhNearestHit: ray_directions must have shape [num_rays, 3]"); + } + + int64_t num_rays = origins_dims[0]; + + // Active mask: shape [num_triangles] or [0] (empty means all active) + auto mask_dims = active_mask.dimensions(); + size_t mask_len = 1; + for (size_t i = 0; i < mask_dims.size(); ++i) { + mask_len *= static_cast(mask_dims[i]); + } + rust::Slice mask_slice{ + reinterpret_cast(active_mask.typed_data()), mask_len}; + + rust::Slice origins_slice{origins.typed_data(), + static_cast(num_rays * 3)}; + rust::Slice dirs_slice{directions.typed_data(), + static_cast(num_rays * 3)}; + rust::Slice indices_slice{(*hit_indices).typed_data(), + static_cast(num_rays)}; + rust::Slice t_slice{(*hit_t).typed_data(), + static_cast(num_rays)}; + + bvh_nearest_hit_ffi(bvh_id, origins_slice, dirs_slice, mask_slice, + indices_slice, t_slice); + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + BvhNearestHit, BvhNearestHitImpl, + ffi::Ffi::Bind() + .Attr("bvh_id") + .Arg>() // ray_origins + .Arg>() // ray_directions + .Arg>() // active_mask + .Ret>() // hit_indices + .Ret>()); // hit_t + +// --- BvhGetCandidates --- + +ffi::Error BvhGetCandidatesImpl(uint64_t bvh_id, float expansion, + int32_t max_candidates, + ffi::Buffer origins, + ffi::Buffer directions, + ffi::ResultBuffer candidate_indices, + ffi::ResultBuffer candidate_counts) { + auto origins_dims = origins.dimensions(); + + if (origins_dims.size() != 2 || origins_dims[1] != 3) { + return ffi::Error::InvalidArgument( + "BvhGetCandidates: ray_origins must have shape [num_rays, 3]"); + } + + int64_t num_rays = origins_dims[0]; + + rust::Slice origins_slice{origins.typed_data(), + static_cast(num_rays * 3)}; + rust::Slice dirs_slice{directions.typed_data(), + static_cast(num_rays * 3)}; + rust::Slice indices_slice{ + (*candidate_indices).typed_data(), + static_cast(num_rays * max_candidates)}; + rust::Slice counts_slice{(*candidate_counts).typed_data(), + static_cast(num_rays)}; + + bvh_get_candidates_ffi(bvh_id, expansion, max_candidates, origins_slice, + dirs_slice, indices_slice, counts_slice); + return ffi::Error::Success(); +} + +XLA_FFI_DEFINE_HANDLER_SYMBOL( + BvhGetCandidates, BvhGetCandidatesImpl, + ffi::Ffi::Bind() + .Attr("bvh_id") + .Attr("expansion") + .Attr("max_candidates") + .Arg>() // ray_origins + .Arg>() // ray_directions + .Ret>() // candidate_indices + .Ret>()); // candidate_counts diff --git a/differt-core/src/lib.rs b/differt-core/src/lib.rs index c93ae6fc..63f58ad0 100644 --- a/differt-core/src/lib.rs +++ b/differt-core/src/lib.rs @@ -1,5 +1,6 @@ use pyo3::{prelude::*, wrap_pymodule}; +pub mod accel; pub mod geometry; pub mod rt; pub mod scene; @@ -19,6 +20,7 @@ fn _differt_core(m: Bound<'_, PyModule>) -> PyResult<()> { env!("CARGO_PKG_VERSION_PATCH").parse::().unwrap(), ); m.add("__version_info__", version_info)?; + m.add_wrapped(wrap_pymodule!(accel::accel))?; m.add_wrapped(wrap_pymodule!(geometry::geometry))?; m.add_wrapped(wrap_pymodule!(rt::rt))?; m.add_wrapped(wrap_pymodule!(scene::scene))?; diff --git a/differt/src/differt/accel/__init__.py b/differt/src/differt/accel/__init__.py new file mode 100644 index 00000000..bba205e7 --- /dev/null +++ b/differt/src/differt/accel/__init__.py @@ -0,0 +1 @@ +"""Acceleration structures and techniques for ray tracing.""" diff --git a/differt/src/differt/accel/bvh/__init__.py b/differt/src/differt/accel/bvh/__init__.py new file mode 100644 index 00000000..af826c0f --- /dev/null +++ b/differt/src/differt/accel/bvh/__init__.py @@ -0,0 +1,45 @@ +""" +This module provides bounding volume hierarchy (BVH) acceleration for DiffeRT's ray-triangle intersection queries. + +.. warning:: + These APIs are experimental and may change in the near future. + See the discussion in (TODO: issue link) for details. + +Example: + Build a BVH over two triangles forming a quad, then shoot rays at it: + + >>> import jax.numpy as jnp + >>> from differt.accel.bvh import TriangleBvh + >>> verts = jnp.array( + ... [ + ... [[0, 0, 0], [1, 0, 0], [0, 1, 0]], + ... [[1, 1, 0], [1, 0, 0], [0, 1, 0]], + ... ], + ... dtype=jnp.float32, + ... ) + >>> bvh = TriangleBvh(verts) + >>> bvh.num_triangles + 2 + >>> origins = jnp.array([[0.25, 0.25, 1.0], [5.0, 5.0, 1.0]]) + >>> dirs = jnp.array([[0.0, 0.0, -1.0], [0.0, 0.0, -1.0]]) + >>> hit_idx, hit_t = bvh.nearest_hit(origins, dirs) + >>> hit_idx + array([ 0, -1], dtype=int32) + >>> hit_t + array([ 1., inf], dtype=float32) + +""" + +__all__ = ( + "TriangleBvh", + "bvh_first_triangles_hit_by_rays", + "bvh_rays_intersect_any_triangle", + "bvh_triangles_visible_from_vertices", +) + +from ._accelerated import ( + bvh_first_triangles_hit_by_rays, + bvh_rays_intersect_any_triangle, + bvh_triangles_visible_from_vertices, +) +from ._triangle_bvh import TriangleBvh diff --git a/differt/src/differt/accel/bvh/_accelerated.py b/differt/src/differt/accel/bvh/_accelerated.py new file mode 100644 index 00000000..d1587d7a --- /dev/null +++ b/differt/src/differt/accel/bvh/_accelerated.py @@ -0,0 +1,396 @@ +from typing import Any + +import jax.numpy as jnp +import numpy as np +from jaxtyping import Array, ArrayLike, Bool, Float, Int + +from differt.geometry._utils import ( # TODO: fixme, this it to avoid circular import + fibonacci_lattice, + viewing_frustum, +) +from differt.rt import rays_intersect_triangles +from differt.utils import smoothing_function + +from ._triangle_bvh import TriangleBvh, compute_expansion_radius + + +def bvh_rays_intersect_any_triangle( + ray_origins: Float[ArrayLike, "*#batch 3"], + ray_directions: Float[ArrayLike, "*#batch 3"], + triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], + active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, + *, + hit_tol: Float[ArrayLike, ""] | None = None, + smoothing_factor: Float[ArrayLike, ""] | None = None, + bvh: TriangleBvh | None = None, + max_candidates: int = 512, + epsilon_grad: float = 1e-7, + **kwargs: Any, +) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: + r"""BVH-accelerated version of :func:`~differt.rt.rays_intersect_any_triangle`. + + When ``bvh`` is provided, uses BVH candidate selection to reduce the number + of triangles tested per ray from :math:`\mathcal{O}(N)` to :math:`\mathcal{O}(\log(N))`. + + Without smoothing (``smoothing_factor=None``), uses BVH nearest-hit to check + if any triangle blocks the ray. + + With smoothing (``smoothing_factor`` set), uses BVH with expanded boxes + to find candidate triangles, then runs the standard differentiable intersection + on candidates only. Gradients flow through the JAX intersection normally. + + Args: + ray_origins: An array of origin vertices. + ray_directions: An array of ray directions. + triangle_vertices: An array of triangle vertices. + active_triangles: Optional boolean mask for active triangles. + hit_tol: Tolerance for hit detection. + smoothing_factor: If set, uses smooth sigmoid approximations. + bvh: Pre-built BVH acceleration structure. + max_candidates: Maximum candidates per ray when smoothing is enabled. + epsilon_grad: Gradient truncation threshold for expansion radius. + kwargs: Keyword arguments passed to :func:`~differt.rt.rays_intersect_triangles`. + + Returns: + For each ray, whether it intersects with any of the triangles. + """ + if bvh is None: + from differt.rt._utils import rays_intersect_any_triangle # noqa: PLC0415 + + return rays_intersect_any_triangle( + ray_origins, + ray_directions, + triangle_vertices, + active_triangles, + hit_tol=hit_tol, + smoothing_factor=smoothing_factor, + **kwargs, + ) + + ray_origins_jnp = jnp.asarray(ray_origins) + ray_directions_jnp = jnp.asarray(ray_directions) + triangle_vertices_jnp = jnp.asarray(triangle_vertices) + + if hit_tol is None: + dtype = jnp.result_type( + ray_origins_jnp, ray_directions_jnp, triangle_vertices_jnp + ) + hit_tol = 10.0 * jnp.finfo(dtype).eps + + hit_threshold = 1.0 - jnp.asarray(hit_tol) + + # Compute batch shape from all inputs (matching brute-force broadcast semantics) + batch_shape = jnp.broadcast_shapes( + ray_origins_jnp.shape[:-1], + ray_directions_jnp.shape[:-1], + triangle_vertices_jnp.shape[:-3], + jnp.asarray(active_triangles).shape[:-1] + if active_triangles is not None + else (), + ) + ray_origins_jnp = jnp.broadcast_to(ray_origins_jnp, (*batch_shape, 3)) + ray_directions_jnp = jnp.broadcast_to(ray_directions_jnp, (*batch_shape, 3)) + + if smoothing_factor is None: + # No smoothing: use BVH nearest-hit as an "any" check. + # Pass active_triangles mask directly to Rust BVH so it skips + # inactive triangles and finds the nearest *active* hit. + flat_origins = np.asarray(ray_origins_jnp).reshape(-1, 3) + flat_dirs = np.asarray(ray_directions_jnp).reshape(-1, 3) + mask_np = None + if active_triangles is not None: + mask_np = np.ascontiguousarray(np.asarray(active_triangles).flatten()) + hit_indices, hit_t = bvh.nearest_hit(flat_origins, flat_dirs, mask_np) + + # Apply hit_threshold: only count hits with t < hit_threshold + any_hit = (hit_indices >= 0) & (hit_t < float(hit_threshold)) + + return jnp.asarray(any_hit.reshape(batch_shape)) + + # Smoothing/differentiable path: BVH candidate selection + JAX intersection + alpha = float(smoothing_factor) # type: ignore[arg-type] + + # Estimate triangle size for expansion radius + tri_np = np.asarray(triangle_vertices_jnp) + flat_tri = tri_np.reshape(-1, 3, 3) if tri_np.ndim > 3 else tri_np # noqa: PLR2004 + # Use mean edge length as characteristic size + edges = np.diff(flat_tri, axis=-2, append=flat_tri[..., :1, :]) + mean_tri_size = float(np.mean(np.linalg.norm(edges, axis=-1))) + expansion = compute_expansion_radius(alpha, mean_tri_size, epsilon_grad) + + # Check if expansion is too large (smoothing -> fallback to brute force) + scene_diag = float( + np.linalg.norm( + flat_tri.reshape(-1, 3).max(axis=0) - flat_tri.reshape(-1, 3).min(axis=0) + ) + ) + if expansion > scene_diag: + from differt.rt._utils import rays_intersect_any_triangle # noqa: PLC0415 + + return rays_intersect_any_triangle( + ray_origins, + ray_directions, + triangle_vertices, + active_triangles, + hit_tol=hit_tol, + smoothing_factor=smoothing_factor, + **kwargs, + ) + + # Get candidates from BVH (outside JIT -- returns numpy arrays) + flat_origins = np.asarray(ray_origins_jnp).reshape(-1, 3) + flat_dirs = np.asarray(ray_directions_jnp).reshape(-1, 3) + candidate_indices, candidate_counts = bvh.get_candidates( + flat_origins, flat_dirs, expansion, max_candidates + ) + + # If any ray has more candidates than max_candidates, fall back to brute force + # for correctness (truncation would give wrong gradients) + if np.any(candidate_counts > max_candidates): + import warnings # noqa: PLC0415 + + warnings.warn( + f"BVH candidate count ({int(candidate_counts.max())}) exceeds " + f"max_candidates ({max_candidates}). Falling back to brute force. " + f"Increase max_candidates or decrease smoothing_factor.", + stacklevel=2, + ) + from differt.rt._utils import rays_intersect_any_triangle # noqa: PLC0415 + + return rays_intersect_any_triangle( + ray_origins, + ray_directions, + triangle_vertices, + active_triangles, + hit_tol=hit_tol, + smoothing_factor=smoothing_factor, + **kwargs, + ) + + # Convert to JAX + cand_idx = jnp.asarray(candidate_indices.reshape(*batch_shape, max_candidates)) + cand_counts = jnp.asarray(candidate_counts.reshape(batch_shape)) + + # Gather candidate triangle vertices: shape [*batch, max_candidates, 3, 3] + # Use the non-batch triangle_vertices (first batch element if batched) + tri_flat = triangle_vertices_jnp + if tri_flat.ndim > 3: # noqa: PLR2004 + tri_flat = tri_flat.reshape(-1, 3, 3) + + # Clamp indices to valid range for gather (padding -1 -> 0, masked out later) + safe_idx = jnp.maximum(cand_idx, 0) + cand_verts = tri_flat[safe_idx] # [*batch, max_candidates, 3, 3] + + # Mask: which candidates are valid + arange = jnp.arange(max_candidates) + mask = ( + arange[None] < cand_counts[..., None] + if cand_counts.ndim == 1 + else arange < cand_counts[..., None] + ) + + # Active triangles filter + if active_triangles is not None: + active_jnp = jnp.asarray(active_triangles) + active_flat = active_jnp.reshape(-1) if active_jnp.ndim > 1 else active_jnp + cand_active = active_flat[safe_idx.reshape(-1, max_candidates)].reshape( + *batch_shape, max_candidates + ) + mask = mask & cand_active + + # Run differentiable intersection on candidates (pure JAX) + t, hit = rays_intersect_triangles( + ray_origins_jnp[..., None, :], # [*batch, 1, 3] + ray_directions_jnp[..., None, :], # [*batch, 1, 3] + cand_verts, # [*batch, max_candidates, 3, 3] + smoothing_factor=smoothing_factor, + **kwargs, + ) + + soft_hit = jnp.minimum(hit, smoothing_function(hit_threshold - t, smoothing_factor)) + return jnp.sum(soft_hit * mask, axis=-1).clip(max=1.0) + + +def bvh_triangles_visible_from_vertices( + vertices: Float[ArrayLike, "*#batch 3"], + triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], + active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, + num_rays: int = int(1e6), + *, + bvh: TriangleBvh | None = None, + **kwargs: Any, +) -> Bool[Array, "*batch num_triangles"]: + r"""BVH-accelerated version of :func:`~differt.rt.triangles_visible_from_vertices`. + + Uses BVH nearest-hit for :math:`\mathcal{O}(\log(N))` per ray instead of :math:`\mathcal{O}(N)`, avoiding JAX's + :math:`\mathcal{O}(\text{rays} \cdot \text{triangles})` memory allocation. + + Args: + vertices: An array of vertices, used as origins of the rays. + triangle_vertices: An array of triangle vertices. + active_triangles: An optional boolean mask for active triangles. + num_rays: The number of rays to launch per vertex. + bvh: Pre-built BVH acceleration structure. + kwargs: Additional keyword arguments (for API compatibility). + + Returns: + For each triangle, whether it is visible from any of the rays. + """ + if bvh is None: + from differt.rt._utils import triangles_visible_from_vertices # noqa: PLC0415 + + return triangles_visible_from_vertices( + vertices, + triangle_vertices, + active_triangles, + num_rays=num_rays, + **kwargs, + ) + + vertices_jnp = jnp.asarray(vertices) + triangle_vertices_jnp = jnp.asarray(triangle_vertices) + num_triangles = triangle_vertices_jnp.shape[-3] + + # Compute batch shape from all inputs (matching brute-force broadcast semantics) + batch_shape = jnp.broadcast_shapes( + vertices_jnp.shape[:-1], + triangle_vertices_jnp.shape[:-3], + jnp.asarray(active_triangles).shape[:-1] + if active_triangles is not None + else (), + ) + vertices_jnp = jnp.broadcast_to(vertices_jnp, (*batch_shape, 3)) + + # Compute viewing frustum and generate fibonacci lattice directions + triangle_centers = triangle_vertices_jnp.mean(axis=-2, keepdims=True) + world_vertices = jnp.concat( + (triangle_vertices_jnp, triangle_centers), axis=-2 + ).reshape(*triangle_vertices_jnp.shape[:-3], -1, 3) + + if active_triangles is not None: + active_jnp = jnp.asarray(active_triangles) + active_vertices = jnp.repeat(active_jnp, 4, axis=-1) + else: + active_vertices = None + + ray_origins = vertices_jnp + + frustum = viewing_frustum( + ray_origins, + world_vertices, + active_vertices=active_vertices, + ) + + ray_directions = jnp.vectorize( + lambda n, frustum: fibonacci_lattice(n, frustum=frustum), + excluded={0}, + signature="(2,3)->(n,3)", + )(num_rays, frustum) + + # Flatten batch dims for BVH queries + flat_origins = np.asarray(ray_origins).reshape(-1, 3) + flat_dirs = np.asarray(ray_directions).reshape(-1, num_rays, 3) + num_vertices = flat_origins.shape[0] + + # Tile origins and flatten: each origin gets num_rays copies + # Shape: (num_vertices * num_rays, 3) + all_origins = np.repeat(flat_origins, num_rays, axis=0) + all_dirs = flat_dirs.reshape(-1, 3) + + # Ensure contiguous for Rust + all_origins = np.ascontiguousarray(all_origins) + all_dirs = np.ascontiguousarray(all_dirs) + + # Single BVH call for all rays + hit_indices, _ = bvh.nearest_hit(all_origins, all_dirs) + + # Reshape to (num_vertices, num_rays) + hit_indices = hit_indices.reshape(num_vertices, num_rays) + + # Build visibility mask + visible = np.zeros((*batch_shape, num_triangles), dtype=bool) + flat_visible = visible.reshape(num_vertices, num_triangles) + + active_np = None + if active_triangles is not None: + active_np = np.asarray(active_triangles) + if active_np.ndim > 1: + active_np = active_np.reshape(-1) + + for i in range(num_vertices): + valid = hit_indices[i] >= 0 + valid_indices = hit_indices[i][valid] + if active_np is not None: + active_mask = active_np[valid_indices] + valid_indices = valid_indices[active_mask] + unique_hits = np.unique(valid_indices) + flat_visible[i, unique_hits] = True + + return jnp.asarray(visible) + + +def bvh_first_triangles_hit_by_rays( + ray_origins: Float[ArrayLike, "*#batch 3"], + ray_directions: Float[ArrayLike, "*#batch 3"], + triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], + active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, + *, + bvh: TriangleBvh | None = None, + **kwargs: Any, +) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"]]: + r"""BVH-accelerated version of :func:`~differt.rt.first_triangles_hit_by_rays`. + + Uses BVH traversal for :math:`\mathcal{O}(\log(N))` nearest-hit per ray instead of :math:`\mathcal{O}(N)`. + + Args: + ray_origins: An array of origin vertices. + ray_directions: An array of ray directions. + triangle_vertices: An array of triangle vertices. + active_triangles: Optional boolean mask for active triangles. + bvh: Pre-built BVH acceleration structure. + kwargs: Additional keyword arguments (for API compatibility). + + Returns: + A tuple ``(indices, t)`` of the nearest triangle index and distance. + """ + if bvh is None: + from differt.rt._utils import first_triangles_hit_by_rays # noqa: PLC0415 + + return first_triangles_hit_by_rays( + ray_origins, + ray_directions, + triangle_vertices, + active_triangles, + **kwargs, + ) + + ray_origins_jnp = jnp.asarray(ray_origins) + ray_directions_jnp = jnp.asarray(ray_directions) + + # Compute batch shape from all inputs (matching brute-force broadcast semantics) + triangle_vertices_jnp = jnp.asarray(triangle_vertices) + batch_shape = jnp.broadcast_shapes( + ray_origins_jnp.shape[:-1], + ray_directions_jnp.shape[:-1], + triangle_vertices_jnp.shape[:-3], + jnp.asarray(active_triangles).shape[:-1] + if active_triangles is not None + else (), + ) + ray_origins_jnp = jnp.broadcast_to(ray_origins_jnp, (*batch_shape, 3)) + ray_directions_jnp = jnp.broadcast_to(ray_directions_jnp, (*batch_shape, 3)) + + flat_origins = np.asarray(ray_origins_jnp).reshape(-1, 3) + flat_dirs = np.asarray(ray_directions_jnp).reshape(-1, 3) + + # Pass active_triangles mask directly to Rust BVH so it skips inactive + # triangles during traversal and finds the nearest *active* hit. + mask_np = None + if active_triangles is not None: + mask_np = np.ascontiguousarray(np.asarray(active_triangles).flatten()) + hit_indices, hit_t = bvh.nearest_hit(flat_origins, flat_dirs, mask_np) + + return ( + jnp.asarray(hit_indices.reshape(batch_shape)), + jnp.asarray(hit_t.reshape(batch_shape)), + ) diff --git a/differt/src/differt/accel/bvh/_ffi.py b/differt/src/differt/accel/bvh/_ffi.py new file mode 100644 index 00000000..f6c0060b --- /dev/null +++ b/differt/src/differt/accel/bvh/_ffi.py @@ -0,0 +1,121 @@ +"""JAX FFI wrappers for BVH acceleration. + +These functions call into Rust BVH queries via XLA FFI, enabling BVH +operations inside JIT-compiled JAX functions (``jax.jit``, ``jax.lax.scan``). +""" + +__all__ = ( + "ffi_get_candidates", + "ffi_nearest_hit", +) + +from typing import Any + +import jax +import jax.numpy as jnp +import numpy as np +from jaxtyping import Array, Float + +import differt_core.accel.bvh as _differt_core_bvh + +bvh_nearest_hit_capsule = _differt_core_bvh.bvh_nearest_hit_capsule +bvh_get_candidates_capsule = _differt_core_bvh.bvh_get_candidates_capsule + +jax.ffi.register_ffi_target( + "bvh_nearest_hit", bvh_nearest_hit_capsule(), platform="cpu" +) +jax.ffi.register_ffi_target( + "bvh_get_candidates", bvh_get_candidates_capsule(), platform="cpu" +) + + +def ffi_nearest_hit( + ray_origins: Float[Array, "num_rays 3"], + ray_directions: Float[Array, "num_rays 3"], + *, + bvh_id: int, + active_mask: Array | None = None, # TODO: use better type annotation +) -> Any: + """BVH nearest-hit via XLA FFI. Works inside ``jax.jit``. + + Args: + ray_origins: Ray origins. + ray_directions: Ray directions. + bvh_id: Registry ID from ``bvh.register()``. + active_mask: Optional boolean mask for active triangles. + When provided, only triangles where the mask is ``True`` are + considered during traversal, correctly finding the nearest + *active* hit. + + Returns: + A tuple ``(hit_indices, hit_t)`` with triangle index (``-1`` for miss) + and parametric distance. + """ + num_rays = ray_origins.shape[0] + out_types = [ + jax.ShapeDtypeStruct((num_rays,), jnp.int32), # hit_indices + jax.ShapeDtypeStruct((num_rays,), jnp.float32), # hit_t + ] + + call = jax.ffi.ffi_call( + "bvh_nearest_hit", + out_types, + vmap_method="broadcast_all", + ) + + # Pass active_mask as a PRED buffer; empty array means no mask + if active_mask is None: + mask_buf = jnp.empty((0,), dtype=jnp.bool_) + else: + mask_buf = active_mask.astype(jnp.bool_) + + return call( + ray_origins.astype(jnp.float32), + ray_directions.astype(jnp.float32), + mask_buf, + bvh_id=np.uint64(bvh_id), + ) + + +def ffi_get_candidates( + ray_origins: Float[Array, "num_rays 3"], + ray_directions: Float[Array, "num_rays 3"], + *, + bvh_id: int, + expansion: float = 0.0, + max_candidates: int = 256, +) -> Any: + """BVH candidate selection via XLA FFI. Works inside ``jax.jit``. + + Args: + ray_origins: Ray origins. + ray_directions: Ray directions. + bvh_id: Registry ID from ``bvh.register()``. + expansion: Bounding box expansion for differentiable mode. + max_candidates: Maximum candidates per ray. + + Returns: + A tuple ``(candidate_indices, candidate_counts)`` where indices + are padded with ``-1`` and counts indicate valid entries. + """ + num_rays = ray_origins.shape[0] + out_types = [ + jax.ShapeDtypeStruct( + (num_rays, max_candidates), jnp.int32 + ), # candidate_indices + jax.ShapeDtypeStruct((num_rays,), jnp.int32), # candidate_counts + ] + + call = jax.ffi.ffi_call( + "bvh_get_candidates", + out_types, + vmap_method="broadcast_all", + ) + + return call( + ray_origins.astype(jnp.float32), + ray_directions.astype(jnp.float32), + bvh_id=np.uint64(bvh_id), + expansion=np.float32(expansion), + max_candidates=np.int32(max_candidates), + ) diff --git a/differt/src/differt/accel/bvh/_triangle_bvh.py b/differt/src/differt/accel/bvh/_triangle_bvh.py new file mode 100644 index 00000000..0ac44b15 --- /dev/null +++ b/differt/src/differt/accel/bvh/_triangle_bvh.py @@ -0,0 +1,201 @@ +__all__ = ("TriangleBvh",) + +import math + +import numpy as np +from jaxtyping import ArrayLike, Float + +from differt_core.accel.bvh import TriangleBvh as _RustBvh + + +class TriangleBvh: + """BVH acceleration structure for triangle meshes. + + Builds a SAH-based Bounding Volume Hierarchy over triangle vertices. + Supports two query types: + + - :meth:`nearest_hit`: find the closest triangle per ray (for SBR) + - :meth:`get_candidates`: find candidate triangles per ray with expanded + bounding boxes (for differentiable mode) + + Args: + triangle_vertices: Triangle vertices. + + Example: + >>> import jax.numpy as jnp + >>> from differt.accel.bvh import TriangleBvh + >>> verts = jnp.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=jnp.float32) + >>> bvh = TriangleBvh(verts) + >>> bvh.num_triangles + 1 + """ + + def __init__( + self, triangle_vertices: Float[ArrayLike, "num_triangles 3 3"] + ) -> None: + # TODO: why would we pass 2-dimension input? + verts = np.asarray(triangle_vertices, dtype=np.float32) + if verts.ndim == 3: # noqa: PLR2004 + # Shape (num_triangles, 3, 3) -> (num_triangles * 3, 3) + verts = verts.reshape(-1, 3) + self._inner = _RustBvh(verts) + + @property + def num_triangles(self) -> int: + """Number of triangles in the BVH.""" + return self._inner.num_triangles + + @property + def num_nodes(self) -> int: + """Number of BVH nodes used.""" + return self._inner.num_nodes + + def nearest_hit( + self, + ray_origins: ArrayLike, + ray_directions: ArrayLike, + active_mask: ArrayLike | None = None, + ) -> tuple[np.ndarray, np.ndarray]: + """Find the nearest active triangle hit by each ray. + + Args: + ray_origins: Ray origins with shape ``(num_rays, 3)``. + ray_directions: Ray directions with shape ``(num_rays, 3)``. + active_mask: Optional boolean mask with shape ``(num_triangles,)``. + When provided, only triangles where the mask is ``True`` are + considered during traversal. + + Returns: + A tuple ``(hit_indices, hit_t)`` where ``hit_indices`` has + shape ``(num_rays,)`` with triangle index (``-1`` for miss) + and ``hit_t`` has shape ``(num_rays,)`` with parametric distance. + + Example: + >>> import jax.numpy as jnp + >>> from differt.accel.bvh import TriangleBvh + >>> verts = jnp.array( + ... [[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=jnp.float32 + ... ) + >>> bvh = TriangleBvh(verts) + >>> origins = jnp.array([[0.1, 0.1, 1.0]]) + >>> dirs = jnp.array([[0.0, 0.0, -1.0]]) + >>> idx, t = bvh.nearest_hit(origins, dirs) + >>> int(idx[0]) + 0 + """ + origins = np.asarray(ray_origins, dtype=np.float32) + dirs = np.asarray(ray_directions, dtype=np.float32) + mask = None + if active_mask is not None: + mask = np.ascontiguousarray(np.asarray(active_mask).flatten()) + if origins.ndim > 2: # noqa: PLR2004 + orig_shape = origins.shape[:-1] + origins = origins.reshape(-1, 3) + dirs = dirs.reshape(-1, 3) + idx, t = self._inner.nearest_hit(origins, dirs, mask) + return idx.reshape(orig_shape), t.reshape(orig_shape) + return self._inner.nearest_hit(origins, dirs, mask) + + def register(self) -> int: + """Register this BVH for XLA FFI access. + + Returns: + Integer ID for use with JAX FFI handlers. + + Example: + >>> import jax.numpy as jnp + >>> from differt.accel.bvh import TriangleBvh + >>> verts = jnp.array( + ... [[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=jnp.float32 + ... ) + >>> bvh = TriangleBvh(verts) + >>> bvh_id = bvh.register() + >>> bvh_id > 0 + True + >>> bvh.unregister() + """ + return self._inner.register() + + def unregister(self) -> None: + """Remove this BVH from the global registry.""" + self._inner.unregister() + + def get_candidates( + self, + ray_origins: ArrayLike, + ray_directions: ArrayLike, + expansion: float = 0.0, + max_candidates: int = 256, + ) -> tuple[np.ndarray, np.ndarray]: + """Find candidate triangles whose expanded AABBs intersect each ray. + + For differentiable mode, the expansion captures all triangles with + non-negligible gradient contribution. + + Args: + ray_origins: Ray origins with shape ``(num_rays, 3)``. + ray_directions: Ray directions with shape ``(num_rays, 3)``. + expansion: Bounding box expansion amount. + max_candidates: Maximum candidates per ray. + + Returns: + A tuple ``(candidate_indices, candidate_counts)`` where + ``candidate_indices`` has shape ``(num_rays, max_candidates)`` + padded with ``-1``, and ``candidate_counts`` has shape ``(num_rays,)``. + + Example: + >>> import jax.numpy as jnp + >>> from differt.accel.bvh import TriangleBvh + >>> verts = jnp.array( + ... [[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=jnp.float32 + ... ) + >>> bvh = TriangleBvh(verts) + >>> origins = jnp.array([[0.1, 0.1, 1.0]]) + >>> dirs = jnp.array([[0.0, 0.0, -1.0]]) + >>> idx, counts = bvh.get_candidates(origins, dirs, expansion=0.0) + >>> int(counts[0]) >= 1 + True + """ + origins = np.asarray(ray_origins, dtype=np.float32) + dirs = np.asarray(ray_directions, dtype=np.float32) + if origins.ndim > 2: # noqa: PLR2004 + orig_shape = origins.shape[:-1] + origins = origins.reshape(-1, 3) + dirs = dirs.reshape(-1, 3) + idx, counts = self._inner.get_candidates( + origins, dirs, expansion, max_candidates + ) + return ( + idx.reshape(*orig_shape, max_candidates), + counts.reshape(orig_shape), + ) + return self._inner.get_candidates(origins, dirs, expansion, max_candidates) + + +def compute_expansion_radius( + smoothing_factor: float, + triangle_size: float = 1.0, + epsilon_grad: float = 1e-7, +) -> float: + """Compute BVH expansion radius for differentiable mode. + + The expansion guarantees that all triangles with gradient contribution + above ``epsilon_grad`` are included in the candidate set. + + Args: + smoothing_factor: The smoothing parameter (alpha). + triangle_size: Approximate characteristic triangle size. + epsilon_grad: Gradient truncation threshold. + + Returns: + The expansion radius in the same units as triangle_size. + + Example: + >>> from differt.accel.bvh import compute_expansion_radius + >>> r = compute_expansion_radius(10.0, triangle_size=1.0) + >>> r > 0 + True + """ + if smoothing_factor <= 0: + return float("inf") + return triangle_size * math.log(1.0 / epsilon_grad) / smoothing_factor diff --git a/differt/src/differt/geometry/_triangle_mesh.py b/differt/src/differt/geometry/_triangle_mesh.py index adc57dfc..a3ada6d0 100644 --- a/differt/src/differt/geometry/_triangle_mesh.py +++ b/differt/src/differt/geometry/_triangle_mesh.py @@ -19,6 +19,7 @@ from jaxtyping import Array, ArrayLike, Bool, Float, Int, PRNGKeyArray import differt_core.geometry +from differt.accel.bvh import TriangleBvh from differt.plotting import PlotOutput, draw_mesh, draw_paths, draw_rays, reuse from ._utils import normalize, orthogonal_basis, rotation_matrix_along_axis @@ -544,6 +545,31 @@ def triangle_vertices(self) -> Float[Array, "num_triangles 3 3"]: return jnp.take(self.vertices, self.triangles, axis=0) + def build_bvh(self) -> TriangleBvh: + """ + Build a :class:`~differt.accel.bvh.TriangleBvh` acceleration structure for this mesh. + + See :mod:`differt.accel.bvh` for more details about the BVH implementation. + + Returns: + A triangle BVH instance. + + Example: + >>> from differt.geometry import TriangleMesh + >>> import jax.numpy as jnp + >>> mesh = TriangleMesh( + ... vertices=jnp.array( + ... [[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=jnp.float32 + ... ), + ... triangles=jnp.array([[0, 1, 2]]), + ... ) + >>> bvh = mesh.build_bvh() + >>> bvh.num_triangles + 1 + + """ + return TriangleBvh(self.triangle_vertices) + def set_assume_quads(self, flag: bool = True) -> Self: """ Return a new instance of this scene with :attr:`TriangleMesh.assume_quads` set to ``flag``. diff --git a/differt/src/differt/scene/_triangle_scene.py b/differt/src/differt/scene/_triangle_scene.py index 2aecd23a..392e761c 100644 --- a/differt/src/differt/scene/_triangle_scene.py +++ b/differt/src/differt/scene/_triangle_scene.py @@ -11,6 +11,7 @@ from jaxtyping import Array, ArrayLike, Bool, Float, Int import differt_core.scene +from differt.accel.bvh import TriangleBvh from differt.geometry import ( Paths, SBRPaths, @@ -66,6 +67,9 @@ def _compute_paths( smoothing_factor: Float[ArrayLike, " "], confidence_threshold: Float[ArrayLike, " "], batch_size: int | None, + bvh_id: int | None = ..., + bvh_expansion: float | None = ..., + bvh_max_candidates: int = ..., ) -> Paths[_F]: ... @@ -82,6 +86,9 @@ def _compute_paths( smoothing_factor: None, confidence_threshold: Float[ArrayLike, " "], batch_size: int | None, + bvh_id: int | None = ..., + bvh_expansion: float | None = ..., + bvh_max_candidates: int = ..., ) -> Paths[_B]: ... @@ -98,6 +105,9 @@ def _compute_paths( smoothing_factor: Float[ArrayLike, ""] | None, confidence_threshold: Float[ArrayLike, ""], batch_size: int | None, + bvh_id: int | None = None, + bvh_expansion: float | None = None, + bvh_max_candidates: int = 512, ) -> Paths[_M]: if min_len is None: dtype = jnp.result_type(mesh.vertices, tx_vertices, rx_vertices) @@ -249,7 +259,85 @@ def _compute_paths( # 3.3 - Identify paths that are blocked by other objects # [num_tx_vertices num_rx_vertices num_path_candidates] - if smoothing_factor is not None: + if bvh_id is not None and smoothing_factor is None: + # TODO: check whether we can avoid importing the hidden function here, just call a public function? Same for the next branch. + # BVH-accelerated blocking check (without smoothing, via XLA FFI) + from differt.accel.bvh._ffi import ffi_nearest_hit # noqa: PLC0415 + + batch_shape = ray_origins.shape[:-1] # [..., order+1] + flat_origins = ray_origins.reshape(-1, 3) + flat_dirs = ray_directions.reshape(-1, 3) + hit_idx, hit_t = ffi_nearest_hit( + flat_origins, + flat_dirs, + bvh_id=bvh_id, + active_mask=mesh.mask, + ) + # A ray is blocked if it hits something with t < 1 - hit_tol + hit_tol_val = ( + hit_tol + if hit_tol is not None + else 10.0 * jnp.finfo(jnp.result_type(ray_origins, ray_directions)).eps + ) + blocked_flat = (hit_idx >= 0) & (hit_t < (1.0 - hit_tol_val)) + blocked = blocked_flat.reshape(batch_shape).any(axis=-1) # Reduce on 'order' + elif ( + bvh_id is not None + and smoothing_factor is not None + and bvh_expansion is not None + ): + # BVH-accelerated blocking check with smoothing: candidate selection via FFI, + # differentiable intersection in JAX on the reduced candidate set. + from differt.accel.bvh._ffi import ffi_get_candidates # noqa: PLC0415 + + batch_shape = ray_origins.shape[:-1] + flat_origins = ray_origins.reshape(-1, 3) + flat_dirs = ray_directions.reshape(-1, 3) + + cand_idx, cand_counts = ffi_get_candidates( + flat_origins, + flat_dirs, + bvh_id=bvh_id, + expansion=bvh_expansion, + max_candidates=bvh_max_candidates, + ) + + # Gather candidate triangle vertices from the full mesh + tri_verts = mesh.triangle_vertices # [num_triangles, 3, 3] + safe_idx = jnp.maximum(cand_idx, 0) # clamp -1 padding -> 0 + cand_verts = tri_verts[safe_idx] # [N, max_cand, 3, 3] + + # Validity mask: which candidate slots are populated + arange = jnp.arange(bvh_max_candidates) + mask = arange[None, :] < cand_counts[:, None] + + # Active triangles filter + if mesh.mask is not None: + mask = mask & mesh.mask[safe_idx] + + # Hit threshold + hit_tol_val = ( + hit_tol + if hit_tol is not None + else 10.0 * jnp.finfo(jnp.result_type(ray_origins, ray_directions)).eps + ) + hit_threshold = 1.0 - jnp.asarray(hit_tol_val) + + # Differentiable intersection: broadcast rays [N,1,3] against candidates [N,max_cand,3,3] + t, hit = rays_intersect_triangles( + flat_origins[:, None, :], + flat_dirs[:, None, :], + cand_verts, + epsilon=epsilon, + smoothing_factor=smoothing_factor, + ) + + smoothed_hit = jnp.minimum( + hit, smoothing_function(hit_threshold - t, smoothing_factor) + ) + blocked_flat = jnp.sum(smoothed_hit * mask, axis=-1).clip(max=1.0) + blocked = blocked_flat.reshape(batch_shape).max(axis=-1, initial=0.0) + elif smoothing_factor is not None: blocked = rays_intersect_any_triangle( ray_origins, ray_directions, @@ -356,6 +444,7 @@ def _compute_paths_sbr( epsilon: Float[ArrayLike, ""] | None, max_dist: Float[ArrayLike, ""], batch_size: int | None, + bvh_id: int | None = None, ) -> SBRPaths: # TODO: type annotations for SBRPaths with mask dtype # 1 - Prepare arrays @@ -413,14 +502,29 @@ def scan_fun( # 1 - Compute next intersection with triangles # [num_tx_vertices num_rays] - triangles, t_hit = first_triangles_hit_by_rays( - ray_origins, - ray_directions, - triangle_vertices, - active_triangles=mesh.mask, - epsilon=epsilon, - batch_size=batch_size, - ) + if bvh_id is not None: + from differt.accel.bvh._ffi import ffi_nearest_hit # noqa: PLC0415 + + sbr_shape = ray_origins.shape[:-1] + flat_o = ray_origins.reshape(-1, 3) + flat_d = ray_directions.reshape(-1, 3) + idx, t = ffi_nearest_hit( + flat_o, + flat_d, + bvh_id=bvh_id, + active_mask=mesh.mask, + ) + triangles = idx.reshape(sbr_shape) + t_hit = t.reshape(sbr_shape) + else: + triangles, t_hit = first_triangles_hit_by_rays( + ray_origins, + ray_directions, + triangle_vertices, + active_triangles=mesh.mask, + epsilon=epsilon, + batch_size=batch_size, + ) # 2 - Check if the rays pass near RX @@ -808,6 +912,27 @@ def from_sionna(cls, sionna_scene: SionnaScene) -> Self: ), ) + def build_bvh(self) -> TriangleBvh: + """Build a BVH acceleration structure for the scene's triangle mesh. + + This delegates to :meth:`~differt.geometry.TriangleMesh.build_bvh`. + + See :mod:`differt.accel.bvh` for more details about the BVH implementation. + + Returns: + A triangle BVH instance. + + Example: + >>> from differt.scene import TriangleScene + >>> scene = TriangleScene.load_xml( + ... "differt/src/differt/scene/scenes/simple_reflector/simple_reflector.xml" + ... ) + >>> bvh = scene.build_bvh() + >>> bvh.num_triangles == scene.mesh.num_triangles + True + """ + return self.mesh.build_bvh() + @overload def compute_paths( self, @@ -825,6 +950,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, ""] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> Paths[_F]: ... @overload @@ -844,6 +970,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, " "] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> Paths[_B]: ... @overload @@ -863,6 +990,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, ""] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> Paths[_F]: ... @overload @@ -882,6 +1010,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, " "] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> Paths[_B]: ... @overload @@ -901,6 +1030,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, ""] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> SizedIterator[Paths[_F]]: ... @overload @@ -920,6 +1050,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, " "] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> SizedIterator[Paths[_B]]: ... @overload @@ -939,6 +1070,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, ""] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> Iterator[Paths[_F]]: ... @overload @@ -958,6 +1090,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, " "] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> Iterator[Paths[_B]]: ... @overload @@ -977,6 +1110,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, ""] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> Paths[_F]: ... @overload @@ -996,6 +1130,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, " "] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> Paths[_B]: ... @overload @@ -1015,6 +1150,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, ""] = ..., batch_size: int | None = ..., disconnect_inactive_triangles: bool = ..., + bvh: TriangleBvh | None = ..., ) -> SBRPaths: ... def compute_paths( @@ -1033,6 +1169,7 @@ def compute_paths( confidence_threshold: Float[ArrayLike, ""] = 0.5, batch_size: int | None = 512, disconnect_inactive_triangles: bool = False, + bvh: TriangleBvh | None = None, ) -> Paths[_M] | SizedIterator[Paths[_M]] | Iterator[Paths[_M]] | SBRPaths: """ Compute paths between all pairs of transmitters and receivers in the scene, that undergo a fixed number of interaction with objects. @@ -1146,6 +1283,18 @@ def compute_paths( For the ``'hybrid'`` method, inactive triangles are always disconnected regardless of this parameter value, as the method already depends on the mask. + bvh: An optional BVH acceleration structure from :meth:`build_bvh`. + + When provided, the BVH accelerates intersection queries: + + * ``'exhaustive'``: BVH accelerates the blocking check + via XLA FFI inside JIT. Without smoothing, uses nearest-hit. + With smoothing, uses candidate selection with differentiable + intersection on the reduced set. + * ``'hybrid'``: BVH accelerates both the visibility estimation + and the blocking check. + * ``'sbr'``: BVH replaces ``first_triangles_hit_by_rays`` in the + bounce loop (via XLA FFI inside ``lax.scan``). Returns: @@ -1184,6 +1333,34 @@ def compute_paths( tx_batch = self.transmitters.shape[:-1] rx_batch = self.receivers.shape[:-1] + # Extract BVH registry ID for FFI (if available) + bvh_id = None + bvh_expansion = None + if bvh is not None: + bvh_id = bvh.register() + # Compute expansion radius for smoothing BVH acceleration + # Requires XLA FFI (ffi_get_candidates) to work inside JIT + if bvh_id is not None and smoothing_factor is not None: + # TODO: maybe we should consider making this function public? + from differt.accel.bvh._triangle_bvh import ( # noqa: PLC0415 + compute_expansion_radius, + ) + + tri_np = np.asarray(self.mesh.triangle_vertices) + edges = np.diff(tri_np, axis=-2, append=tri_np[..., :1, :]) + mean_tri_size = float(np.mean(np.linalg.norm(edges, axis=-1))) + bvh_expansion = compute_expansion_radius( + float(np.asarray(smoothing_factor).real), mean_tri_size + ) + + # If expansion exceeds scene diagonal, BVH won't help + flat_pts = tri_np.reshape(-1, 3) + scene_diag = float( + np.linalg.norm(flat_pts.max(axis=0) - flat_pts.min(axis=0)) + ) + if bvh_expansion > scene_diag: + bvh_expansion = None + if method == "sbr": if order is None: msg = "Argument 'order' is required when 'method == \"sbr\"'." @@ -1198,6 +1375,7 @@ def compute_paths( epsilon=epsilon, max_dist=max_dist, batch_size=batch_size, + bvh_id=bvh_id, ).reshape(*tx_batch, *rx_batch, -1) # 0 - Constants arrays of chunks @@ -1215,23 +1393,44 @@ def compute_paths( msg = "Argument 'order' is required when 'method == \"hybrid\"'." raise ValueError(msg) - triangles_visible_from_tx = triangles_visible_from_vertices( - tx_vertices, - self.mesh.triangle_vertices, - active_triangles=self.mesh.mask, - num_rays=num_rays, - epsilon=epsilon, - batch_size=batch_size, - ).any(axis=0) # reduce on all transmitters + if bvh is not None: + from differt.accel.bvh._accelerated import ( # noqa: PLC0415 + bvh_triangles_visible_from_vertices, + ) - triangles_visible_from_rx = triangles_visible_from_vertices( - rx_vertices, - self.mesh.triangle_vertices, - active_triangles=self.mesh.mask, - num_rays=num_rays, - epsilon=epsilon, - batch_size=batch_size, - ).any(axis=0) # reduce on all receivers + triangles_visible_from_tx = bvh_triangles_visible_from_vertices( + tx_vertices, + self.mesh.triangle_vertices, + active_triangles=self.mesh.mask, + num_rays=num_rays, + bvh=bvh, + ).any(axis=0) # reduce on all transmitters + + triangles_visible_from_rx = bvh_triangles_visible_from_vertices( + rx_vertices, + self.mesh.triangle_vertices, + active_triangles=self.mesh.mask, + num_rays=num_rays, + bvh=bvh, + ).any(axis=0) # reduce on all receivers + else: + triangles_visible_from_tx = triangles_visible_from_vertices( + tx_vertices, + self.mesh.triangle_vertices, + active_triangles=self.mesh.mask, + num_rays=num_rays, + epsilon=epsilon, + batch_size=batch_size, + ).any(axis=0) # reduce on all transmitters + + triangles_visible_from_rx = triangles_visible_from_vertices( + rx_vertices, + self.mesh.triangle_vertices, + active_triangles=self.mesh.mask, + num_rays=num_rays, + epsilon=epsilon, + batch_size=batch_size, + ).any(axis=0) # reduce on all receivers if assume_quads: triangles_visible_from_tx = triangles_visible_from_tx.reshape( @@ -1292,6 +1491,9 @@ def compute_paths( smoothing_factor=smoothing_factor, confidence_threshold=confidence_threshold, batch_size=batch_size, + bvh_id=bvh_id, + bvh_expansion=bvh_expansion, + bvh_max_candidates=512, ).reshape(*tx_batch, *rx_batch, path_candidates.shape[0]) for path_candidates in path_candidates_iter ) @@ -1329,6 +1531,9 @@ def compute_paths( smoothing_factor=smoothing_factor, confidence_threshold=confidence_threshold, batch_size=batch_size, + bvh_id=bvh_id, + bvh_expansion=bvh_expansion, + bvh_max_candidates=512, ).reshape(*tx_batch, *rx_batch, path_candidates.shape[0]) def plot( diff --git a/differt/tests/accel/__init__.py b/differt/tests/accel/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/differt/tests/accel/test_bvh.py b/differt/tests/accel/test_bvh.py new file mode 100644 index 00000000..030c1dd5 --- /dev/null +++ b/differt/tests/accel/test_bvh.py @@ -0,0 +1,729 @@ +"""Tests for BVH acceleration structure. + +Validates that BVH-accelerated intersection queries produce the same results +as the brute-force implementations, for both non-smoothing and smoothing (differentiable) modes. +""" + +import sys + +import chex +import equinox as eqx +import jax +import jax.numpy as jnp +import numpy as np +import pytest + +from differt.accel.bvh import ( + TriangleBvh, + bvh_first_triangles_hit_by_rays, + bvh_rays_intersect_any_triangle, + bvh_triangles_visible_from_vertices, +) +from differt.accel.bvh._triangle_bvh import compute_expansion_radius +from differt.rt import ( + first_triangles_hit_by_rays, + rays_intersect_any_triangle, + triangles_visible_from_vertices, +) +from differt.scene import TriangleScene + + +@pytest.fixture +def single_triangle() -> jax.Array: + return jnp.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=jnp.float32) + + +@pytest.fixture +def three_triangles() -> jax.Array: + """Three triangles at different z-planes.""" + return jnp.array( + [ + [[0, 0, 0], [1, 0, 0], [0, 1, 0]], # z=0 + [[0, 0, 2], [1, 0, 2], [0, 1, 2]], # z=2 + [[5, 5, 0], [6, 5, 0], [5, 6, 0]], # far away + ], + dtype=jnp.float32, + ) + + +@pytest.fixture +def cube_scene() -> jax.Array: + """12-triangle unit cube.""" + faces = [ + ([0, 0, 1], [1, 0, 1], [1, 1, 1]), + ([0, 0, 1], [1, 1, 1], [0, 1, 1]), + ([0, 0, 0], [0, 1, 0], [1, 1, 0]), + ([0, 0, 0], [1, 1, 0], [1, 0, 0]), + ([0, 1, 0], [0, 1, 1], [1, 1, 1]), + ([0, 1, 0], [1, 1, 1], [1, 1, 0]), + ([0, 0, 0], [1, 0, 0], [1, 0, 1]), + ([0, 0, 0], [1, 0, 1], [0, 0, 1]), + ([1, 0, 0], [1, 1, 0], [1, 1, 1]), + ([1, 0, 0], [1, 1, 1], [1, 0, 1]), + ([0, 0, 0], [0, 0, 1], [0, 1, 1]), + ([0, 0, 0], [0, 1, 1], [0, 1, 0]), + ] + return jnp.array(faces, dtype=jnp.float32) + + +@pytest.fixture +def random_scene() -> jax.Array: + """50 random triangles in a 10x10x10 box.""" + key = jax.random.PRNGKey(42) + return jax.random.uniform(key, (50, 3, 3), minval=0.0, maxval=10.0) + + +class TestTriangleBvhConstruction: + def test_single_triangle(self, single_triangle: jax.Array) -> None: + bvh = TriangleBvh(single_triangle) + assert bvh.num_triangles == 1 + assert bvh.num_nodes >= 1 + + def test_cube(self, cube_scene: jax.Array) -> None: + bvh = TriangleBvh(cube_scene) + assert bvh.num_triangles == 12 + + def test_random(self, random_scene: jax.Array) -> None: + bvh = TriangleBvh(random_scene) + assert bvh.num_triangles == 50 + + def test_numpy_input(self) -> None: + verts = np.array([[[0, 0, 0], [1, 0, 0], [0, 1, 0]]], dtype=np.float32) + bvh = TriangleBvh(verts) + assert bvh.num_triangles == 1 + + +class TestNearestHit: + def test_single_triangle_hit(self, single_triangle: jax.Array) -> None: + bvh = TriangleBvh(single_triangle) + origins = jnp.array([[0.1, 0.1, 1.0]]) + dirs = jnp.array([[0.0, 0.0, -1.0]]) + + bvh_idx, bvh_t = bvh_first_triangles_hit_by_rays( + origins, dirs, single_triangle, bvh=bvh + ) + bf_idx, bf_t = first_triangles_hit_by_rays(origins, dirs, single_triangle) + + assert int(bvh_idx[0]) == int(bf_idx[0]) + chex.assert_trees_all_close(bvh_t[0], bf_t[0], atol=1e-4) + + def test_single_triangle_miss(self, single_triangle: jax.Array) -> None: + bvh = TriangleBvh(single_triangle) + origins = jnp.array([[0.1, 0.1, 1.0]]) + dirs = jnp.array([[0.0, 0.0, 1.0]]) # pointing away + + bvh_idx, _bvh_t = bvh_first_triangles_hit_by_rays( + origins, dirs, single_triangle, bvh=bvh + ) + bf_idx, _bf_t = first_triangles_hit_by_rays(origins, dirs, single_triangle) + + assert int(bvh_idx[0]) == int(bf_idx[0]) == -1 + + def test_cube_multiple_rays(self, cube_scene: jax.Array) -> None: + bvh = TriangleBvh(cube_scene) + origins = jnp.array( + [ + [0.5, 0.5, 2.0], + [0.5, 0.5, -1.0], + [2.0, 0.5, 0.5], + [0.5, 2.0, 0.5], + [5.0, 5.0, 5.0], # miss + ], + dtype=jnp.float32, + ) + dirs = jnp.array( + [ + [0.0, 0.0, -1.0], + [0.0, 0.0, 1.0], + [-1.0, 0.0, 0.0], + [0.0, -1.0, 0.0], + [0.0, 0.0, 1.0], # miss + ], + dtype=jnp.float32, + ) + + bvh_idx, bvh_t = bvh_first_triangles_hit_by_rays( + origins, dirs, cube_scene, bvh=bvh + ) + bf_idx, bf_t = first_triangles_hit_by_rays(origins, dirs, cube_scene) + + # Both should agree on hits vs misses + bvh_hit = bvh_idx >= 0 + bf_hit = bf_idx >= 0 + chex.assert_trees_all_equal(bvh_hit, bf_hit) + + # For hits, t values should match + for i in range(len(origins)): + if bvh_hit[i]: + chex.assert_trees_all_close(bvh_t[i], bf_t[i], atol=1e-4) + + def test_random_scene_many_rays(self, random_scene: jax.Array) -> None: + bvh = TriangleBvh(random_scene) + + key = jax.random.PRNGKey(123) + k1, k2 = jax.random.split(key) + origins = jax.random.uniform(k1, (100, 3), minval=-2.0, maxval=12.0) + dirs = jax.random.normal(k2, (100, 3)) + dirs = dirs / jnp.linalg.norm(dirs, axis=-1, keepdims=True) + + bvh_idx, bvh_t = bvh_first_triangles_hit_by_rays( + origins, dirs, random_scene, bvh=bvh + ) + bf_idx, bf_t = first_triangles_hit_by_rays(origins, dirs, random_scene) + + bvh_hit = bvh_idx >= 0 + bf_hit = bf_idx >= 0 + chex.assert_trees_all_equal(bvh_hit, bf_hit) + + hit_mask = bvh_hit & bf_hit + chex.assert_trees_all_close( + bvh_t[hit_mask], + bf_t[hit_mask], + atol=1e-4, + ) + + def test_fallback_without_bvh(self, single_triangle: jax.Array) -> None: + """Without bvh parameter, falls back to brute force.""" + origins = jnp.array([[0.1, 0.1, 1.0]]) + dirs = jnp.array([[0.0, 0.0, -1.0]]) + + idx, _t = bvh_first_triangles_hit_by_rays( + origins, dirs, single_triangle, bvh=None + ) + assert int(idx[0]) == 0 + + +class TestAnyIntersection: + def test_without_smoothing(self, three_triangles: jax.Array) -> None: + bvh = TriangleBvh(three_triangles) + # Ray from above, hits triangle at z=2 + origins = jnp.array([[0.1, 0.1, 3.0]]) + dirs = jnp.array([[0.0, 0.0, -1.0]]) + + bvh_any = bvh_rays_intersect_any_triangle( + origins, dirs, three_triangles, bvh=bvh + ) + bf_any = rays_intersect_any_triangle(origins, dirs, three_triangles) + + assert bool(bvh_any[0]) == bool(bf_any[0]) + + def test_without_smoothing_miss(self, three_triangles: jax.Array) -> None: + bvh = TriangleBvh(three_triangles) + origins = jnp.array([[0.1, 0.1, 3.0]]) + dirs = jnp.array([[0.0, 0.0, 1.0]]) # pointing away + + bvh_any = bvh_rays_intersect_any_triangle( + origins, dirs, three_triangles, bvh=bvh + ) + bf_any = rays_intersect_any_triangle(origins, dirs, three_triangles) + + assert bool(bvh_any[0]) == bool(bf_any[0]) == False # noqa: E712 + + @pytest.mark.parametrize("smoothing_factor", [1.0, 10.0, 100.0]) + def test_with_smoothing_matches_brute_force( + self, three_triangles: jax.Array, smoothing_factor: float + ) -> None: + bvh = TriangleBvh(three_triangles) + origins = jnp.array([[0.1, 0.1, 3.0]]) + dirs = jnp.array([[0.0, 0.0, -1.0]]) + + bvh_smoothed = bvh_rays_intersect_any_triangle( + origins, + dirs, + three_triangles, + smoothing_factor=smoothing_factor, + bvh=bvh, + ) + bf_smoothed = rays_intersect_any_triangle( + origins, + dirs, + three_triangles, + smoothing_factor=smoothing_factor, + ) + + chex.assert_trees_all_close(bvh_smoothed[0], bf_smoothed[0], atol=1e-3) + + def test_with_smoothing_random_scene(self, random_scene: jax.Array) -> None: + bvh = TriangleBvh(random_scene) + key = jax.random.PRNGKey(456) + k1, k2 = jax.random.split(key) + origins = jax.random.uniform(k1, (20, 3), minval=0.0, maxval=10.0) + dirs = jax.random.normal(k2, (20, 3)) + dirs = dirs / jnp.linalg.norm(dirs, axis=-1, keepdims=True) + + bvh_smoothed = bvh_rays_intersect_any_triangle( + origins, + dirs, + random_scene, + smoothing_factor=10.0, + bvh=bvh, + max_candidates=256, + ) + bf_smoothed = rays_intersect_any_triangle( + origins, dirs, random_scene, smoothing_factor=10.0 + ) + + chex.assert_trees_all_close(bvh_smoothed, bf_smoothed, atol=1e-2) + + def test_fallback_without_bvh(self, three_triangles: jax.Array) -> None: + # Ray from z=3 to z=-2 (length 5), triangle at z=2 is at t=0.2 + origins = jnp.array([[0.1, 0.1, 3.0]]) + dirs = jnp.array([[0.0, 0.0, -5.0]]) + + result = bvh_rays_intersect_any_triangle( + origins, dirs, three_triangles, bvh=None + ) + assert bool(result[0]) + + +class TestExpansionRadius: + def test_positive(self) -> None: + r = compute_expansion_radius(10.0, 1.0, 1e-7) + assert r > 0 + + def test_decreases_with_smoothing(self) -> None: + r1 = compute_expansion_radius(1.0, 1.0, 1e-7) + r2 = compute_expansion_radius(10.0, 1.0, 1e-7) + r3 = compute_expansion_radius(100.0, 1.0, 1e-7) + assert r1 > r2 > r3 + + def test_scales_with_triangle_size(self) -> None: + r1 = compute_expansion_radius(10.0, 1.0, 1e-7) + r2 = compute_expansion_radius(10.0, 2.0, 1e-7) + np.testing.assert_allclose(r2, 2 * r1) + + def test_zero_smoothing(self) -> None: + r = compute_expansion_radius(0.0, 1.0, 1e-7) + assert r == float("inf") + + +class TestVisibility: + def test_single_triangle_visible(self, single_triangle: jax.Array) -> None: + bvh = TriangleBvh(single_triangle) + origin = jnp.array([0.3, 0.3, 1.0]) + + bvh_vis = bvh_triangles_visible_from_vertices( + origin, single_triangle, bvh=bvh, num_rays=1000 + ) + assert bool(bvh_vis[0]) # triangle is visible from above + + def test_cube_all_visible(self, cube_scene: jax.Array) -> None: + bvh = TriangleBvh(cube_scene) + origin = jnp.array([0.5, 0.5, 2.0]) # above the cube + + bvh_vis = bvh_triangles_visible_from_vertices( + origin, cube_scene, bvh=bvh, num_rays=10000 + ) + # From above, the top face triangles should be visible + assert int(bvh_vis.sum()) >= 2 # at least the top face + + def test_matches_brute_force(self, cube_scene: jax.Array) -> None: + bvh = TriangleBvh(cube_scene) + origin = jnp.array([0.5, 0.5, 2.0]) + + bvh_vis = bvh_triangles_visible_from_vertices( + origin, cube_scene, bvh=bvh, num_rays=10000 + ) + bf_vis = triangles_visible_from_vertices(origin, cube_scene, num_rays=10000) + + # Both should see approximately the same set (statistical) + bvh_count = int(bvh_vis.sum()) + bf_count = int(bf_vis.sum()) + assert abs(bvh_count - bf_count) <= 2 # allow small difference + + def test_fallback_without_bvh(self, single_triangle: jax.Array) -> None: + origin = jnp.array([0.3, 0.3, 1.0]) + vis = bvh_triangles_visible_from_vertices( + origin, single_triangle, bvh=None, num_rays=1000 + ) + assert bool(vis[0]) + + def test_multiple_origins(self, cube_scene: jax.Array) -> None: + bvh = TriangleBvh(cube_scene) + origins = jnp.array([ + [0.5, 0.5, 2.0], # above + [0.5, 0.5, -1.0], # below + ]) + + bvh_vis = bvh_triangles_visible_from_vertices( + origins, cube_scene, bvh=bvh, num_rays=10000 + ) + assert bvh_vis.shape == (2, 12) + assert int(bvh_vis[0].sum()) >= 2 # top visible + assert int(bvh_vis[1].sum()) >= 2 # bottom visible + + +class TestComputePathsBvh: + # TODO: use get_sionna_scene (fixture is available) + def test_hybrid_with_bvh(self) -> None: + scene = TriangleScene.load_xml( + "differt/src/differt/scene/scenes/simple_reflector/simple_reflector.xml" + ) + scene = eqx.tree_at( + lambda s: s.transmitters, scene, jnp.array([[0.5, 0.5, 1.0]]) + ) + scene = eqx.tree_at(lambda s: s.receivers, scene, jnp.array([[-0.5, 0.5, 0.5]])) + bvh = scene.build_bvh() + + paths_bvh = scene.compute_paths(order=1, method="hybrid", bvh=bvh) + paths_bf = scene.compute_paths(order=1, method="hybrid") + + # Both should find the same valid paths + assert paths_bvh.mask.shape == paths_bf.mask.shape + chex.assert_trees_all_equal(paths_bvh.mask, paths_bf.mask) + + def test_exhaustive_with_bvh_ffi(self) -> None: + """Exhaustive method uses BVH FFI for blocking check.""" + scene = TriangleScene.load_xml( + "differt/src/differt/scene/scenes/simple_reflector/simple_reflector.xml" + ) + scene = eqx.tree_at( + lambda s: s.transmitters, scene, jnp.array([[0.5, 0.5, 1.0]]) + ) + scene = eqx.tree_at(lambda s: s.receivers, scene, jnp.array([[-0.5, 0.5, 0.5]])) + bvh = scene.build_bvh() + + # BVH should give same results as brute force without smoothing + paths_bvh = scene.compute_paths(order=1, method="exhaustive", bvh=bvh) + paths_bf = scene.compute_paths(order=1, method="exhaustive") + + chex.assert_trees_all_equal(paths_bvh.mask, paths_bf.mask) + + def test_sbr_with_bvh_ffi(self) -> None: + """SBR method uses BVH FFI in the bounce loop.""" + scene = TriangleScene.load_xml("differt/src/differt/scene/scenes/box/box.xml") + scene = eqx.tree_at( + lambda s: s.transmitters, scene, jnp.array([[0.5, 0.5, 2.0]]) + ) + scene = eqx.tree_at(lambda s: s.receivers, scene, jnp.array([[0.5, 0.5, -1.0]])) + bvh = scene.build_bvh() + + paths_bvh = scene.compute_paths(order=1, method="sbr", bvh=bvh, num_rays=1000) + paths_bf = scene.compute_paths(order=1, method="sbr", num_rays=1000) + # Both should produce SBRPaths with same shape and mostly matching data. + # Small differences are expected due to different Moller-Trumbore epsilon + # (BVH uses 1e-8, brute-force uses ~1.2e-6 for f32). + assert type(paths_bvh).__name__ == "SBRPaths" + assert type(paths_bf).__name__ == "SBRPaths" + assert paths_bvh.vertices.shape == paths_bf.vertices.shape + assert paths_bvh.objects.shape == paths_bf.objects.shape + # Most object indices should agree (allow small fraction to differ) + objs_bvh = np.asarray(paths_bvh.objects) + objs_bf = np.asarray(paths_bf.objects) + match_frac = np.mean(objs_bvh == objs_bf) + assert match_frac > 0.95, f"Object indices match only {match_frac:.1%}" + + def test_exhaustive_matches_without_bvh(self) -> None: + """Exhaustive with BVH produces same results as without.""" + scene = TriangleScene.load_xml( + "differt/src/differt/scene/scenes/simple_reflector/simple_reflector.xml" + ) + scene = eqx.tree_at( + lambda s: s.transmitters, scene, jnp.array([[0.5, 0.5, 1.0]]) + ) + scene = eqx.tree_at(lambda s: s.receivers, scene, jnp.array([[-0.5, 0.5, 0.5]])) + bvh = scene.build_bvh() + + paths_bvh = scene.compute_paths(order=1, method="exhaustive", bvh=bvh) + paths_bf = scene.compute_paths(order=1, method="exhaustive") + + chex.assert_trees_all_equal(paths_bvh.mask, paths_bf.mask) + + +class TestBatchedRays: + def test_nearest_hit_3d_origins(self, single_triangle: jax.Array) -> None: + """nearest_hit with ndim > 2 triggers the reshape branch.""" + bvh = TriangleBvh(single_triangle) + # Shape (2, 1, 3) -- batch dimension + origins = jnp.array([[[0.1, 0.1, 1.0]], [[0.1, 0.1, 1.0]]], dtype=jnp.float32) + dirs = jnp.array([[[0.0, 0.0, -1.0]], [[0.0, 0.0, 1.0]]], dtype=jnp.float32) + idx, _t = bvh.nearest_hit(origins, dirs) + assert idx.shape == (2, 1) + assert int(idx[0, 0]) == 0 # hit + assert int(idx[1, 0]) == -1 # miss (pointing away) + + def test_get_candidates_3d_origins(self, single_triangle: jax.Array) -> None: + """get_candidates with ndim > 2 triggers the reshape branch.""" + bvh = TriangleBvh(single_triangle) + origins = jnp.array([[[0.1, 0.1, 1.0]], [[5.0, 5.0, 1.0]]], dtype=jnp.float32) + dirs = jnp.array([[[0.0, 0.0, -1.0]], [[0.0, 0.0, -1.0]]], dtype=jnp.float32) + max_cands = 8 + idx, counts = bvh.get_candidates( + origins, dirs, expansion=0.0, max_candidates=max_cands + ) + assert idx.shape == (2, 1, max_cands) + assert counts.shape == (2, 1) + assert int(counts[0, 0]) >= 1 # near triangle + assert int(counts[1, 0]) == 0 # far away, no candidates + + +class TestExpansionRadiusEdgeCases: + def test_negative_smoothing(self) -> None: + r = compute_expansion_radius(-5.0, 1.0, 1e-7) + assert r == float("inf") + + +class TestAcceleratedBranches: + def test_smoothing_large_expansion_fallback( + self, three_triangles: jax.Array + ) -> None: + """Very small smoothing_factor -> huge expansion -> brute-force fallback.""" + bvh = TriangleBvh(three_triangles) + origins = jnp.array([[0.1, 0.1, 3.0]]) + dirs = jnp.array([[0.0, 0.0, -1.0]]) + + # smoothing_factor=0.001 produces expansion >> scene diagonal + result = bvh_rays_intersect_any_triangle( + origins, dirs, three_triangles, smoothing_factor=0.001, bvh=bvh + ) + bf_result = rays_intersect_any_triangle( + origins, dirs, three_triangles, smoothing_factor=0.001 + ) + chex.assert_trees_all_close(result[0], bf_result[0], atol=1e-3) + + def test_smoothing_max_candidates_exceeded(self, random_scene: jax.Array) -> None: + """max_candidates=1 with many overlapping triangles -> warning + fallback.""" + bvh = TriangleBvh(random_scene) + key = jax.random.PRNGKey(789) + k1, k2 = jax.random.split(key) + origins = jax.random.uniform(k1, (10, 3), minval=0.0, maxval=10.0) + dirs = jax.random.normal(k2, (10, 3)) + dirs = dirs / jnp.linalg.norm(dirs, axis=-1, keepdims=True) + + with pytest.warns(UserWarning, match="BVH candidate count"): + result = bvh_rays_intersect_any_triangle( + origins, + dirs, + random_scene, + smoothing_factor=100.0, + bvh=bvh, + max_candidates=1, + ) + assert result.shape == (10,) + + def test_without_smoothing_active_triangles( + self, three_triangles: jax.Array + ) -> None: + """active_triangles mask without smoothing for bvh_rays_intersect_any_triangle.""" + bvh = TriangleBvh(three_triangles) + # Ray from z=3 pointing down with length 5 (t < 1 for triangles at z=2 and z=0) + origins = jnp.array([[0.1, 0.1, 3.0]]) + dirs = jnp.array([[0.0, 0.0, -5.0]]) + + # All active: should hit + active_all = jnp.array([True, True, True]) + result_all = bvh_rays_intersect_any_triangle( + origins, dirs, three_triangles, active_triangles=active_all, bvh=bvh + ) + assert bool(result_all[0]) + + # Only the far-away triangle active: should miss + active_far = jnp.array([False, False, True]) + result_far = bvh_rays_intersect_any_triangle( + origins, dirs, three_triangles, active_triangles=active_far, bvh=bvh + ) + assert not bool(result_far[0]) + + def test_with_smoothing_active_triangles(self, three_triangles: jax.Array) -> None: + """active_triangles mask with smoothing for bvh_rays_intersect_any_triangle.""" + bvh = TriangleBvh(three_triangles) + origins = jnp.array([[0.1, 0.1, 3.0]]) + dirs = jnp.array([[0.0, 0.0, -1.0]]) + + active = jnp.array([True, True, False]) + result = bvh_rays_intersect_any_triangle( + origins, + dirs, + three_triangles, + active_triangles=active, + smoothing_factor=100.0, + bvh=bvh, + ) + assert result.shape == (1,) + assert float(result[0]) > 0 # should detect the hit + + def test_first_hit_with_active_triangles(self, three_triangles: jax.Array) -> None: + """active_triangles mask for bvh_first_triangles_hit_by_rays.""" + bvh = TriangleBvh(three_triangles) + # Ray from z=3 pointing down: nearest active hit changes with mask + origins = jnp.array([[0.1, 0.1, 3.0]]) + dirs = jnp.array([[0.0, 0.0, -5.0]]) # length 5 so t < 1 for all hits + + # Only triangle 0 (z=0) active, triangle 1 (z=2) inactive + active = jnp.array([True, False, True]) + idx, t = bvh_first_triangles_hit_by_rays( + origins, dirs, three_triangles, active_triangles=active, bvh=bvh + ) + assert int(idx[0]) == 0 # nearest active is z=0 + chex.assert_trees_all_close(t[0], jnp.array(0.6), atol=1e-4) + + def test_visibility_with_active_triangles(self, single_triangle: jax.Array) -> None: + """active_triangles mask for bvh_triangles_visible_from_vertices.""" + bvh = TriangleBvh(single_triangle) + origin = jnp.array([0.3, 0.3, 1.0]) + + active = jnp.array([True]) + vis_active = bvh_triangles_visible_from_vertices( + origin, single_triangle, active_triangles=active, bvh=bvh, num_rays=1000 + ) + assert bool(vis_active[0]) + + inactive = jnp.array([False]) + vis_inactive = bvh_triangles_visible_from_vertices( + origin, single_triangle, active_triangles=inactive, bvh=bvh, num_rays=1000 + ) + assert not bool(vis_inactive[0]) + + +class TestFfiRegistration: + # TODO: update this test, as we always register the FFI at import time now, so the monkeypatching may not work as intended / maybe this test is no longer relevant. + def test_ensure_registered_import_error( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Missing xla-ffi feature raises ImportError with helpful message.""" + import differt.accel.bvh._ffi as ffi_mod # noqa: PLC0415 + + monkeypatch.setattr(ffi_mod, "_FFI_REGISTERED", False) + + # Remove differt_core._differt_core from sys.modules so the import fails + monkeypatch.delitem(sys.modules, "differt_core._differt_core", raising=False) + monkeypatch.setitem(sys.modules, "differt_core._differt_core", None) + + with pytest.raises(ImportError, match="BVH XLA FFI not available"): + ffi_mod._ensure_registered() # noqa: SLF001 + + +class TestBuildBvh: + def test_build_bvh_from_scene(self) -> None: + scene = TriangleScene.load_xml( + "differt/src/differt/scene/scenes/simple_reflector/simple_reflector.xml" + ) + bvh = scene.build_bvh() + assert bvh.num_triangles == scene.mesh.num_triangles + assert bvh.num_nodes >= 1 + + +class TestSmoothingBvhComputePaths: + """Test differentiable BVH acceleration in compute_paths.""" + + @staticmethod + def _make_scene() -> TriangleScene: + scene = TriangleScene.load_xml( + "differt/src/differt/scene/scenes/simple_reflector/simple_reflector.xml" + ) + scene = eqx.tree_at( + lambda s: s.transmitters, scene, jnp.array([[0.5, 0.5, 1.0]]) + ) + return eqx.tree_at(lambda s: s.receivers, scene, jnp.array([[-0.5, 0.5, 0.5]])) + + def test_smoothing_bvh_matches_brute_force(self) -> None: + """Smoothing mode with BVH should match brute force within tolerance.""" + scene = self._make_scene() + bvh = scene.build_bvh() + + for sf in [1.0, 10.0, 100.0]: + paths_bvh = scene.compute_paths(order=1, smoothing_factor=sf, bvh=bvh) + paths_bf = scene.compute_paths(order=1, smoothing_factor=sf) + + chex.assert_trees_all_close( + paths_bvh.mask, + paths_bf.mask, + atol=1e-3, + ) + + def test_smoothing_bvh_gradient_flow(self) -> None: + """Gradients should flow through the smoothing BVH path.""" + scene = TriangleScene.load_xml( + "differt/src/differt/scene/scenes/simple_reflector/simple_reflector.xml" + ) + tx = jnp.array([[0.5, 0.5, 1.0]]) + scene = eqx.tree_at(lambda s: s.transmitters, scene, tx) + scene = eqx.tree_at(lambda s: s.receivers, scene, jnp.array([[-0.5, 0.5, 0.5]])) + bvh = scene.build_bvh() + + def loss_fn(tx_pos: jax.Array) -> jax.Array: + s = eqx.tree_at(lambda s: s.transmitters, scene, tx_pos) + paths = s.compute_paths(order=1, smoothing_factor=10.0, bvh=bvh) + return jnp.sum(paths.mask) + + grad = jax.grad(loss_fn)(tx) + assert jnp.all(jnp.isfinite(grad)), f"Non-finite gradients: {grad}" + + def test_smoothing_bvh_expansion_fallback(self) -> None: + """Very small smoothing_factor should fall back to brute force.""" + scene = self._make_scene() + bvh = scene.build_bvh() + + # smoothing_factor=0.001 produces huge expansion > scene diagonal + paths_bvh = scene.compute_paths(order=1, smoothing_factor=0.001, bvh=bvh) + paths_bf = scene.compute_paths(order=1, smoothing_factor=0.001) + + chex.assert_trees_all_close( + paths_bvh.mask, + paths_bf.mask, + atol=1e-6, + ) + + +class TestSmoothingBvhBranchCoverage: + """Exercise the smoothing BVH branch in _compute_paths with mocked FFI.""" + + @staticmethod + def _make_scene() -> TriangleScene: + scene = TriangleScene.load_xml( + "differt/src/differt/scene/scenes/simple_reflector/simple_reflector.xml" + ) + scene = eqx.tree_at( + lambda s: s.transmitters, scene, jnp.array([[0.5, 0.5, 1.0]]) + ) + return eqx.tree_at(lambda s: s.receivers, scene, jnp.array([[-0.5, 0.5, 0.5]])) + + def test_smoothing_bvh_branch_with_mocked_ffi( + self, monkeypatch: pytest.MonkeyPatch + ) -> None: + """Mock FFI to exercise the smoothing BVH code path for coverage.""" + import differt.accel.bvh._ffi as ffi_mod # noqa: PLC0415 + + scene = self._make_scene() + bvh = scene.build_bvh() + + # Create a mock ffi_get_candidates using jax.pure_callback + # so it works inside JIT + def mock_ffi_get_candidates( + ray_origins: jax.Array, + ray_directions: jax.Array, + **kwargs: object, + ) -> tuple[jax.Array, jax.Array]: + expansion = float(kwargs.get("expansion", 0.0)) # type: ignore[arg-type] + max_candidates = int(kwargs.get("max_candidates", 256)) # type: ignore[arg-type] + + def _callback( + origins: jax.Array, dirs: jax.Array + ) -> tuple[np.ndarray, np.ndarray]: + idx, counts = bvh.get_candidates( + np.asarray(origins), np.asarray(dirs), expansion, max_candidates + ) + return np.asarray(idx, dtype=np.int32), np.asarray( + counts, dtype=np.int32 + ) + + num_rays = ray_origins.shape[0] + result_shapes = ( + jax.ShapeDtypeStruct((num_rays, max_candidates), jnp.int32), + jax.ShapeDtypeStruct((num_rays,), jnp.int32), + ) + return jax.pure_callback( + _callback, result_shapes, ray_origins, ray_directions + ) + + monkeypatch.setattr(ffi_mod, "_ensure_registered", lambda: None) + monkeypatch.setattr(ffi_mod, "ffi_get_candidates", mock_ffi_get_candidates) + + paths_bvh = scene.compute_paths(order=1, smoothing_factor=10.0, bvh=bvh) + paths_bf = scene.compute_paths(order=1, smoothing_factor=10.0) + + chex.assert_trees_all_close( + paths_bvh.mask, + paths_bf.mask, + atol=1e-3, + ) diff --git a/differt/tests/benchmarks/test_rt.py b/differt/tests/benchmarks/test_rt.py index 3d4459da..7d7a1b20 100644 --- a/differt/tests/benchmarks/test_rt.py +++ b/differt/tests/benchmarks/test_rt.py @@ -7,6 +7,12 @@ from jaxtyping import Array, PRNGKeyArray from pytest_codspeed import BenchmarkFixture +from differt.accel.bvh import ( + TriangleBvh, + bvh_first_triangles_hit_by_rays, + bvh_rays_intersect_any_triangle, + bvh_triangles_visible_from_vertices, +) from differt.geometry import fibonacci_lattice from differt.rt import ( fermat_path_on_planar_mirrors, @@ -195,3 +201,85 @@ def bench_fun() -> Array: bench_fun() benchmark(bench_fun) + + +# --------------------------------------------------------------------------- +# BVH-accelerated benchmarks (PyO3 path, no FFI needed) +# --------------------------------------------------------------------------- + + +@pytest.mark.benchmark(group="rays_intersect_any_triangle_bvh") +def test_rays_intersect_any_triangle_bvh( + simple_street_canyon_scene: TriangleScene, + benchmark: BenchmarkFixture, + key: PRNGKeyArray, +) -> None: + scene = random_scene(simple_street_canyon_scene, key=key) + bvh = TriangleBvh(scene.mesh.triangle_vertices) + + ray_origins = scene.transmitters + ray_directions = fibonacci_lattice(1_000_000) + + @jax.block_until_ready + def bench_fun() -> Array: + return bvh_rays_intersect_any_triangle( + ray_origins, + ray_directions, + scene.mesh.triangle_vertices, + active_triangles=scene.mesh.mask, + bvh=bvh, + ) + + bench_fun() + + benchmark(bench_fun) + + +@pytest.mark.benchmark(group="first_triangles_hit_by_rays_bvh") +def test_first_triangles_hit_by_rays_bvh( + simple_street_canyon_scene: TriangleScene, + benchmark: BenchmarkFixture, + key: PRNGKeyArray, +) -> None: + scene = random_scene(simple_street_canyon_scene, key=key) + bvh = TriangleBvh(scene.mesh.triangle_vertices) + + ray_origins = scene.transmitters + ray_directions = fibonacci_lattice(1_000_000) + + @jax.block_until_ready + def bench_fun() -> tuple[Array, Array]: + return bvh_first_triangles_hit_by_rays( + ray_origins, + ray_directions, + scene.mesh.triangle_vertices, + active_triangles=scene.mesh.mask, + bvh=bvh, + ) + + bench_fun() + + benchmark(bench_fun) + + +@pytest.mark.benchmark(group="triangles_visible_from_vertices_bvh") +def test_transmitter_visibility_bvh( + simple_street_canyon_scene: TriangleScene, + benchmark: BenchmarkFixture, + key: PRNGKeyArray, +) -> None: + scene = random_scene(simple_street_canyon_scene, key=key) + bvh = TriangleBvh(scene.mesh.triangle_vertices) + + @jax.block_until_ready + def bench_fun() -> Array: + return bvh_triangles_visible_from_vertices( + scene.transmitters, + scene.mesh.triangle_vertices, + active_triangles=scene.mesh.mask, + bvh=bvh, + ) + + bench_fun() + + benchmark(bench_fun) diff --git a/docs/source/conf.py b/docs/source/conf.py index 3ef99d87..01dcd7a9 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -77,6 +77,7 @@ ("py:obj", "differt.rt.utils._T"), ("py:obj", "None.ArrayType"), ("py:class", "setup..ArrayType"), + ("py:class", "setup..ArrayLike"), ) linkcheck_ignore = ["https://doi.org/10.1002/2015RS005659"] diff --git a/docs/source/reference/differt.accel.bvh.rst b/docs/source/reference/differt.accel.bvh.rst new file mode 100644 index 00000000..bf0d5aef --- /dev/null +++ b/docs/source/reference/differt.accel.bvh.rst @@ -0,0 +1,34 @@ +``differt.accel.bvh`` module +============================ + +.. currentmodule:: differt.accel.bvh + +.. automodule:: differt.accel.bvh + +.. rubric:: BVH acceleration structure + +BVH acceleration structure wrapping the Rust implementation. + +Provides :class:`TriangleBvh` for accelerating ray-triangle intersection queries +from :math:`\mathcal{O}(\text{rays} \cdot \text{triangles})` to :math:`\mathcal{O}(\text{rays} \cdot \log(\text{triangles}))`. + +.. autosummary:: + :toctree: _autosummary + + TriangleBvh + +.. rubric:: BVH-accelerated intersection functions + +Drop-in replacements for the functions in :mod:`differt.rt`, +accelerated by a BVH for :math:`\mathcal{O}(\text{rays} \cdot \log(\text{triangles}))` instead of :math:`\mathcal{O}(\text{rays} \cdot \text{triangles})`. + +Without smoothing (``smoothing_factor=None``), the BVH does the full intersection. +With smoothing (``smoothing_factor`` set), the BVH selects candidates and the +existing JAX-based Möller-Trumbore runs on the reduced set. + +.. autosummary:: + :toctree: _autosummary + + bvh_first_triangles_hit_by_rays + bvh_rays_intersect_any_triangle + bvh_triangles_visible_from_vertices diff --git a/docs/source/reference/differt.accel.rst b/docs/source/reference/differt.accel.rst new file mode 100644 index 00000000..4a506f30 --- /dev/null +++ b/docs/source/reference/differt.accel.rst @@ -0,0 +1,15 @@ +``differt.accel`` module +======================== + + +.. currentmodule:: differt.accel + +.. automodule:: differt.accel + +Submodules +---------- + +.. toctree:: + :maxdepth: 1 + + differt.accel.bvh diff --git a/docs/source/reference/differt.rst b/docs/source/reference/differt.rst index df101e91..8b9a598a 100644 --- a/docs/source/reference/differt.rst +++ b/docs/source/reference/differt.rst @@ -11,6 +11,7 @@ Submodules .. toctree:: :maxdepth: 1 + differt.accel differt.em differt.geometry differt.plotting