diff --git a/CHANGELOG.md b/CHANGELOG.md index 9512baaf..5e4f5ff3 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -26,12 +26,21 @@ with one *slight* but **important** difference: - Improved Sionna-compatible XML scene parser to support top-level `` materials in addition to nested structures, enabling support for OSM buildings and other XML formats (by , in ). - Added fallback to black color `[0.0, 0.0, 0.0]` when material `` elements are missing, with appropriate warnings logged (by , in ). +- Added the `ray_batch_size` and `tri_batch_size` arguments to {func}`rays_intersect_any_triangle`, {func}`triangles_visible_from_vertices`, and {func}`first_triangles_hit_by_rays` to let users control ray and triangle chunk sizes independently (by , in ). + +### Changed + - Added the {meth}`TriangleMesh.clip`, {meth}`TriangleMesh.keep_all_within`, and {meth}`TriangleMesh.keep_any_within` methods to support clipping and filtering triangle meshes by axis-aligned bounds (by , in ). +- Updated the `batch_size` argument semantics for {func}`rays_intersect_any_triangle`, {func}`triangles_visible_from_vertices`, and {func}`first_triangles_hit_by_rays` so that `batch_size` acts as the default for whichever of `ray_batch_size` and `tri_batch_size` are left unspecified (by , in ). ### Chore - Added tests for the improved Sionna-compatible XML scene parser using OSM building data, ensuring correct parsing of materials and colors (by , in ). +### Perf + +- Reworked the ray-triangle intersection helpers to use double-chunk processing over rays and triangles, reducing peak memory usage while keeping performance high on CPU and GPU (by , in ). + ## [0.8.1](https://github.com/jeertmans/DiffeRT/compare/v0.8.0...v0.8.1) ### Changed diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index ad596846..7fd135f4 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -1,6 +1,7 @@ +import math import typing from collections.abc import Callable, Iterator, Sized -from functools import cache +from functools import cache, partial from typing import TYPE_CHECKING, Any, TypeVar, overload import equinox as eqx @@ -341,6 +342,102 @@ def rays_intersect_triangles( return t, hit +def _ray_intersect_any_triangle_batched( + ray_origin: Float[Array, "3"], + ray_direction: Float[Array, "3"], + triangle_vertices: Float[Array, "num_triangles 3 3"], + active_triangles: Bool[Array, " num_triangles"] | None, + *, + hit_threshold: Float[Array, ""] | None, + smoothing_factor: Float[ArrayLike, ""] | None, + **kwargs: Any, +) -> Bool[Array, ""] | Float[Array, ""]: + ts, hits = jax.vmap( + partial(rays_intersect_triangles, smoothing_factor=smoothing_factor, **kwargs), + in_axes=(None, None, 0), + )(ray_origin, ray_direction, triangle_vertices) + if smoothing_factor is not None: + return jnp.minimum( + hits, smoothing_function(hit_threshold - ts, smoothing_factor) + ).max(axis=-1, where=active_triangles) + return ((ts < hit_threshold) & hits).any(axis=-1, where=active_triangles) + + +def _ray_intersect_any_triangle( + ray_origin: Float[Array, "3"], + ray_direction: Float[Array, "3"], + triangle_vertices: Float[Array, "num_triangles 3 3"], + active_triangles: Bool[Array, " num_triangles"] | None, + *, + hit_threshold: Float[Array, ""] | None, + smoothing_factor: Float[ArrayLike, ""] | None, + batch_size: int, + **kwargs: Any, +) -> Bool[Array, ""] | Float[Array, ""]: + def reduce_fn( + left: Bool[Array, ""] | Float[Array, ""], + right: Bool[Array, ""] | Float[Array, ""], + ) -> Bool[Array, ""] | Float[Array, ""]: + if smoothing_factor is not None: + return jnp.maximum(left, right) + return left | right + + def body_fn( + batch_index: Int[Array, ""], + intersect_so_far: Bool[Array, ""] | Float[Array, ""], + ) -> Bool[Array, ""] | Float[Array, ""]: + start_index = batch_index * batch_size + batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim( + triangle_vertices, start_index, batch_size, axis=0 + ) + batch_of_active_triangles = ( + jax.lax.dynamic_slice_in_dim( + active_triangles, start_index, batch_size, axis=0 + ) + if active_triangles is not None + else None + ) + intersect_in_batch = _ray_intersect_any_triangle_batched( + ray_origin, + ray_direction, + batch_of_triangle_vertices, + batch_of_active_triangles, + hit_threshold=hit_threshold, + smoothing_factor=smoothing_factor, + **kwargs, + ) + + return reduce_fn(intersect_so_far, intersect_in_batch) + + num_triangles = triangle_vertices.shape[0] + num_batches, rem = divmod(num_triangles, batch_size) + + init_val = 0.0 if smoothing_factor is not None else False + + intersect = jax.lax.fori_loop( + 0, + num_batches, + body_fn, + init_val=init_val, + ) + + if rem > 0: + intersect = reduce_fn( + intersect, + _ray_intersect_any_triangle_batched( + ray_origin, + ray_direction, + triangle_vertices[-rem:, :], + (active_triangles[-rem:] if active_triangles is not None else None), + hit_threshold=hit_threshold, + smoothing_factor=smoothing_factor, + **kwargs, + ), + ) + + return intersect + + @overload def rays_intersect_any_triangle( ray_origins: Float[ArrayLike, "*#batch 3"], @@ -351,6 +448,8 @@ def rays_intersect_any_triangle( hit_tol: Float[ArrayLike, ""] | None = ..., smoothing_factor: None = ..., batch_size: int | None = ..., + ray_batch_size: int | None = ..., + tri_batch_size: int | None = ..., **kwargs: Any, ) -> Bool[Array, " *batch"]: ... @@ -365,6 +464,8 @@ def rays_intersect_any_triangle( hit_tol: Float[ArrayLike, ""] | None = ..., smoothing_factor: Float[ArrayLike, ""], batch_size: int | None = ..., + ray_batch_size: int | None = ..., + tri_batch_size: int | None = ..., **kwargs: Any, ) -> Float[Array, " *batch"]: ... @@ -378,7 +479,9 @@ def rays_intersect_any_triangle( *, hit_tol: Float[ArrayLike, ""] | None = None, smoothing_factor: Float[ArrayLike, ""] | None = None, - batch_size: int | None = 512, + batch_size: int | None = 1024, + ray_batch_size: int | None = None, + tri_batch_size: int | None = None, **kwargs: Any, ) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: """ @@ -414,13 +517,29 @@ def rays_intersect_any_triangle( between 0 (:data:`False`) and 1 (:data:`True`). For more details, refer to :ref:`smoothing`. - batch_size: The number of triangles to process in a single batch. + batch_size: The default batch size used when either ``ray_batch_size`` or + ``tri_batch_size`` is not specified. This allows to make a trade-off between memory + usage and performance. + + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are + used. Otherwise, this value is used as the default for whichever of + ``ray_batch_size`` and ``tri_batch_size`` are left unspecified. + ray_batch_size: The number of rays to process in a single batch. This allows to make a trade-off between memory usage and performance. - The batch size is automatically adjusted to be the minimum of the number of triangles - and the specified batch size. + The ray batch size is automatically adjusted to be the minimum of the number of rays + and the specified ray batch size. + + If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all rays are processed in a single batch. + tri_batch_size: The number of triangles to process in a single batch. + This allows to make a trade-off between memory usage and performance. - If :data:`None`, the batch size is set to the number of triangles. + The triangle batch size is automatically adjusted to be the minimum of the number of + triangles and the specified triangle batch size. + + If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all triangles are processed in a single batch. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. @@ -438,22 +557,33 @@ def rays_intersect_any_triangle( hit_threshold = 1.0 - jnp.asarray(hit_tol) num_triangles = triangle_vertices.shape[-3] - if batch_size is None: - batch_size = num_triangles - batch_size = max(min(batch_size, num_triangles), 1) - num_batches, rem = divmod(num_triangles, batch_size) if active_triangles is not None: active_triangles = jnp.asarray(active_triangles) - # Combine the batch dimensions batch = jnp.broadcast_shapes( ray_origins.shape[:-1], ray_directions.shape[:-1], triangle_vertices.shape[:-3], active_triangles.shape[:-1] if active_triangles is not None else (), ) - + num_rays = math.prod(batch) + ray_batch_size = ( + ray_batch_size + if ray_batch_size is not None + else batch_size + if batch_size is not None + else num_rays + ) + ray_batch_size = max(min(ray_batch_size, num_rays), 1) + tri_batch_size = ( + tri_batch_size + if tri_batch_size is not None + else batch_size + if batch_size is not None + else num_triangles + ) + tri_batch_size = max(min(tri_batch_size, num_triangles), 1) if num_triangles == 0: # If there are no triangles, there are no intersections return ( @@ -465,84 +595,396 @@ def rays_intersect_any_triangle( else jnp.zeros(batch, dtype=bool) ) - def map_fn( - ray_origins: Float[Array, "*#batch 3"], - ray_directions: Float[Array, "*#batch 3"], - triangle_vertices: Float[Array, "*#batch num_triangles 3 3"], - active_triangles: Bool[Array, "*#batch num_triangles"] | None = None, - ) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: - t, hit = rays_intersect_triangles( - ray_origins[..., None, :], - ray_directions[..., None, :], - triangle_vertices, + if num_rays > ray_batch_size: + xs = [] + argnames = [] + map_fn = partial( + _ray_intersect_any_triangle, + hit_threshold=hit_threshold, smoothing_factor=smoothing_factor, + batch_size=tri_batch_size, **kwargs, ) - if smoothing_factor is not None: - return jnp.minimum( - hit, smoothing_function(hit_threshold - t, smoothing_factor) - ).sum(axis=-1, where=active_triangles) - return ((t < hit_threshold) & hit).any(axis=-1, where=active_triangles) + if math.prod(ray_origins.shape[:-1]) > 1: + ray_origins = jnp.broadcast_to(ray_origins, (*batch, 3)) + xs.append(ray_origins.reshape(-1, 3)) + argnames.append("ray_origin") + else: + map_fn = partial(map_fn, ray_origin=ray_origins.reshape(3)) + if math.prod(ray_directions.shape[:-1]) > 1: + ray_directions = jnp.broadcast_to(ray_directions, (*batch, 3)) + xs.append(ray_directions.reshape(-1, 3)) + argnames.append("ray_direction") + else: + map_fn = partial(map_fn, ray_direction=ray_directions.reshape(3)) + if math.prod(triangle_vertices.shape[:-3]) > 1: + triangle_vertices = jnp.broadcast_to( + triangle_vertices, (*batch, num_triangles, 3, 3) + ) + xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) + argnames.append("triangle_vertices") + else: + map_fn = partial( + map_fn, triangle_vertices=triangle_vertices.reshape(num_triangles, 3, 3) + ) + if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: + active_triangles = jnp.broadcast_to( + active_triangles, (*batch, num_triangles) + ) + xs.append(active_triangles.reshape(-1, num_triangles)) + argnames.append("active_triangles") + else: + map_fn = partial( + map_fn, + active_triangles=active_triangles.reshape(num_triangles) + if active_triangles is not None + else None, + ) - def reduce_fn( - left: Bool[Array, " *batch"] | Float[Array, " *batch"], - right: Bool[Array, " *batch"] | Float[Array, " *batch"], - ) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: - if smoothing_factor is not None: - return (left + right).clip(max=1.0) - return left | right + def f(*args: Array) -> Bool[Array, ""] | Float[Array, ""]: + return map_fn(**dict(zip(argnames, args, strict=True))) + + return jax.lax.map( + lambda x: f(*x), tuple(xs), batch_size=ray_batch_size + ).reshape(batch) + return jnp.vectorize( + partial( + _ray_intersect_any_triangle, + hit_threshold=hit_threshold, + smoothing_factor=smoothing_factor, + batch_size=tri_batch_size, + **kwargs, + ), + signature="(3),(3),(n,3,3),(n)->()" + if active_triangles is not None + else "(3),(3),(n,3,3),()->()", + )(ray_origins, ray_directions, triangle_vertices, active_triangles) + + +def _first_triangle_hit_by_ray_batched( + ray_origin: Float[Array, "3"], + ray_direction: Float[Array, "3"], + triangle_vertices: Float[Array, "batch_size 3 3"], + active_triangles: Bool[Array, " batch_size"] | None, + triangle_indices: Int[Array, " batch_size"], + *, + dist_tol: Float[Array, ""], + **kwargs: Any, +) -> tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]]: + ts, hits = jax.vmap( + partial(rays_intersect_triangles, **kwargs), + in_axes=(None, None, 0), + )(ray_origin, ray_direction, triangle_vertices) + + if active_triangles is not None: + hits &= active_triangles + + ts = jnp.where(hits, ts, jnp.inf) + min_t = jnp.min(ts) + + center_distances = jnp.linalg.norm( + triangle_vertices.mean(axis=-2) - ray_origin, axis=-1 + ) + + is_close = jnp.abs(ts - min_t) < dist_tol + dist_to_check = jnp.where(is_close, center_distances, jnp.inf) + min_dist = jnp.min(dist_to_check) + + is_best = (dist_to_check == min_dist) & is_close + best_idx = jnp.argmax(is_best) + + any_hit = jnp.any(hits) + + return ( + jnp.where(any_hit, triangle_indices[best_idx], -1), + jnp.where(any_hit, min_t, jnp.inf), + jnp.where(any_hit, min_dist, jnp.inf), + ) + + +def _first_triangle_hit_by_ray( + ray_origin: Float[Array, "3"], + ray_direction: Float[Array, "3"], + triangle_vertices: Float[Array, "num_triangles 3 3"], + active_triangles: Bool[Array, " num_triangles"] | None, + *, + batch_size: int, + dist_tol: Float[Array, ""], + **kwargs: Any, +) -> tuple[Int[Array, ""], Float[Array, ""]]: + def combine_best( + left: tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]], + right: tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]], + ) -> tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]]: + idx1, t1, d1 = left + idx2, t2, d2 = right + + cond = jnp.where( + jnp.abs(t1 - t2) < dist_tol, + d1 < d2, + t1 < t2, + ) + + return ( + jnp.where(cond, idx1, idx2), + jnp.where(cond, t1, t2), + jnp.where(cond, d1, d2), + ) - def body_fun( + def body_fn( batch_index: Int[Array, ""], - intersect: Bool[Array, " *batch"] | Float[Array, " *batch"], - ) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: + best_so_far: tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]], + ) -> tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]]: start_index = batch_index * batch_size batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim( - triangle_vertices, start_index, batch_size, axis=-3 + triangle_vertices, start_index, batch_size, axis=0 ) batch_of_active_triangles = ( jax.lax.dynamic_slice_in_dim( - active_triangles, start_index, batch_size, axis=-1 + active_triangles, start_index, batch_size, axis=0 ) if active_triangles is not None else None ) - return reduce_fn( - intersect, - map_fn( - ray_origins=ray_origins, - ray_directions=ray_directions, - triangle_vertices=batch_of_triangle_vertices, - active_triangles=batch_of_active_triangles, - ), + batch_of_indices = jnp.arange(batch_size, dtype=int) + start_index + + best_in_batch = _first_triangle_hit_by_ray_batched( + ray_origin, + ray_direction, + batch_of_triangle_vertices, + batch_of_active_triangles, + batch_of_indices, + dist_tol=dist_tol, + **kwargs, ) + return combine_best(best_so_far, best_in_batch) + + num_triangles = triangle_vertices.shape[0] + num_batches, rem = divmod(num_triangles, batch_size) + init_val = ( - jnp.zeros(batch) - if smoothing_factor is not None - else jnp.zeros(batch, dtype=jnp.bool) + jnp.array(-1, dtype=int), + jnp.array(jnp.inf, dtype=ray_origin.dtype), + jnp.array(jnp.inf, dtype=ray_origin.dtype), ) - intersect = jax.lax.fori_loop( + best = jax.lax.fori_loop( 0, num_batches, - body_fun, + body_fn, init_val=init_val, ) if rem > 0: - return reduce_fn( - intersect, - map_fn( - ray_origins=ray_origins, - ray_directions=ray_directions, - triangle_vertices=triangle_vertices[..., -rem:, :, :], - active_triangles=active_triangles[..., -rem:] - if active_triangles is not None - else None, + start_index = num_batches * batch_size + best = combine_best( + best, + _first_triangle_hit_by_ray_batched( + ray_origin, + ray_direction, + triangle_vertices[-rem:, :], + (active_triangles[-rem:] if active_triangles is not None else None), + jnp.arange(rem, dtype=int) + start_index, + dist_tol=dist_tol, + **kwargs, ), ) - return intersect + + return best[0], best[1] + + +@eqx.filter_jit +def first_triangles_hit_by_rays( + ray_origins: Float[ArrayLike, "*#batch 3"], + ray_directions: Float[ArrayLike, "*#batch 3"], + triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], + active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, + batch_size: int | None = 1024, + ray_batch_size: int | None = None, + tri_batch_size: int | None = None, + **kwargs: Any, +) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"]]: + """ + Return the first triangle hit by each ray. + + This function should be used when allocating an array of size + ``*batch num_triangles 3`` (or bigger) is not possible, and you are only interested in + getting the first triangle hit by the ray. + + If two or more triangles are hit at the same distance, the one with the closest center to the ray origin is selected. Two triangles are considered to be hit at the same distance if their distances differ by less than ``100 * eps``, where ``eps`` is the machine epsilon of the floating point dtype. + + Args: + ray_origins: An array of origin vertices. + ray_directions: An array of ray direction. The ray ends + should be equal to ``ray_origins + ray_directions``. + triangle_vertices: An array of triangle vertices. + active_triangles: An optional array of boolean values indicating + which triangles are active, i.e., should be considered for intersection. + + If not specified, all triangles are considered active. + batch_size: The default batch size used when either ``ray_batch_size`` or + ``tri_batch_size`` is not specified. This allows to make a trade-off between memory + usage and performance. + + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are + used. Otherwise, this value is used as the default for whichever of + ``ray_batch_size`` and ``tri_batch_size`` are left unspecified. + ray_batch_size: The number of rays to process in a single batch. + This allows to chunk rays and reduce peak memory usage. + + If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all rays are processed in a single batch. + tri_batch_size: The number of triangles to process in a single batch. + This allows to chunk triangles and reduce peak memory usage. + + If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all triangles are processed in a single batch. + kwargs: Keyword arguments passed to + :func:`rays_intersect_triangles`. + + Returns: + For each ray, return the index and the distance to the first triangle hit. + + If no triangle is hit, the index is set to ``-1`` and + the distance is set to :data:`inf`. + """ + ray_origins = jnp.asarray(ray_origins) + ray_directions = jnp.asarray(ray_directions) + triangle_vertices = jnp.asarray(triangle_vertices) + + if active_triangles is not None: + active_triangles = jnp.asarray(active_triangles) + + dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) + dist_tol = jnp.asarray(100 * jnp.finfo(dtype).eps) + + num_triangles = triangle_vertices.shape[-3] + + batch = jnp.broadcast_shapes( + ray_origins.shape[:-1], + ray_directions.shape[:-1], + triangle_vertices.shape[:-3], + active_triangles.shape[:-1] if active_triangles is not None else (), + ) + num_rays = math.prod(batch) + ray_batch_size = ( + ray_batch_size + if ray_batch_size is not None + else batch_size + if batch_size is not None + else num_rays + ) + ray_batch_size = max(min(ray_batch_size, num_rays), 1) + tri_batch_size = ( + tri_batch_size + if tri_batch_size is not None + else batch_size + if batch_size is not None + else num_triangles + ) + tri_batch_size = max(min(tri_batch_size, num_triangles), 1) + if num_triangles == 0: + # If there are no triangles, there are no hits + return ( + jnp.full(batch, -1, dtype=int), + jnp.full( + batch, + jnp.inf, + dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), + ), + ) + + xs = [] + argnames = [] + map_fn = partial( + _first_triangle_hit_by_ray, + batch_size=tri_batch_size, + dist_tol=dist_tol, + **kwargs, + ) + if math.prod(ray_origins.shape[:-1]) > 1: + xs.append(ray_origins.reshape(-1, 3)) + argnames.append("ray_origin") + else: + map_fn = partial(map_fn, ray_origin=ray_origins.reshape(3)) + if math.prod(ray_directions.shape[:-1]) > 1: + xs.append(ray_directions.reshape(-1, 3)) + argnames.append("ray_direction") + else: + map_fn = partial(map_fn, ray_direction=ray_directions.reshape(3)) + if math.prod(triangle_vertices.shape[:-3]) > 1: + xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) + argnames.append("triangle_vertices") + else: + map_fn = partial( + map_fn, triangle_vertices=triangle_vertices.reshape(num_triangles, 3, 3) + ) + if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: + xs.append(active_triangles.reshape(-1, num_triangles)) + argnames.append("active_triangles") + else: + map_fn = partial( + map_fn, + active_triangles=active_triangles.reshape(num_triangles) + if active_triangles is not None + else None, + ) + + def f(*args: Array) -> tuple[Int[Array, ""], Float[Array, ""]]: + return map_fn(**dict(zip(argnames, args, strict=True))) + + if not xs: + indices, ts = _first_triangle_hit_by_ray( + ray_origins.reshape(3), + ray_directions.reshape(3), + triangle_vertices.reshape(num_triangles, 3, 3), + active_triangles.reshape(num_triangles) + if active_triangles is not None + else None, + batch_size=tri_batch_size, + dist_tol=dist_tol, + **kwargs, + ) + return indices.reshape(batch), ts.reshape(batch) + + if num_rays > ray_batch_size: + indices, ts = jax.lax.map(lambda x: f(*x), tuple(xs), batch_size=ray_batch_size) + else: + indices, ts = jax.vmap(f)(*xs) + + return indices.reshape(batch), ts.reshape(batch) + + +def _triangles_visible_from_vertex( + vertex: Float[Array, "3"], + triangle_vertices: Float[Array, "num_triangles 3 3"], + active_triangles: Bool[Array, " num_triangles"] | None, + *, + num_rays: int, + num_triangles: int, + ray_batch_size: int | None, + tri_batch_size: int | None, + world_vertices: Float[Array, "num_world_vertices 3"], + active_vertices: Bool[Array, " num_world_vertices"] | None, + **kwargs: Any, +) -> Bool[Array, " num_triangles"]: + frustum = viewing_frustum(vertex, world_vertices, active_vertices=active_vertices) + ray_directions = fibonacci_lattice(num_rays, frustum=frustum) + + indices, _ = first_triangles_hit_by_rays( + vertex, + ray_directions, + triangle_vertices, + active_triangles=active_triangles, + ray_batch_size=ray_batch_size, + tri_batch_size=tri_batch_size, + **kwargs, + ) + + valid_hits = indices >= 0 + safe_indices = jnp.where(valid_hits, indices, 0) + return jnp.zeros(num_triangles, dtype=bool).at[safe_indices].max(valid_hits) @eqx.filter_jit @@ -551,7 +993,9 @@ def triangles_visible_from_vertices( triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, num_rays: int = int(1e6), - batch_size: int | None = 512, + batch_size: int | None = 1024, + ray_batch_size: int | None = None, + tri_batch_size: int | None = None, **kwargs: Any, ) -> Bool[Array, "*batch num_triangles"]: """ @@ -578,13 +1022,29 @@ def triangles_visible_from_vertices( num_rays: The number of rays to launch. The larger, the more accurate. - batch_size: The number of rays to process in a single batch. + batch_size: The default batch size used when either ``ray_batch_size`` or + ``tri_batch_size`` is not specified. This allows to make a trade-off between memory + usage and performance. + + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are + used. Otherwise, this value is used as the default for whichever of + ``ray_batch_size`` and ``tri_batch_size`` are left unspecified. + ray_batch_size: The number of rays to process in a single batch. + This allows to make a trade-off between memory usage and performance. + + The ray batch size is automatically adjusted to be the minimum of the number of rays + and the specified ray batch size. + + If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all rays are processed in a single batch. + tri_batch_size: The number of triangles to process in a single batch. This allows to make a trade-off between memory usage and performance. - The batch size is automatically adjusted to be the minimum of the number of rays - and the specified batch size. + The triangle batch size is automatically adjusted to be the minimum of the number of + triangles and the specified triangle batch size. - If :data:`None`, the batch size is set to the number of rays. + If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all triangles are processed in a single batch. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. @@ -672,10 +1132,8 @@ def triangles_visible_from_vertices( """ vertices = jnp.asarray(vertices) triangle_vertices = jnp.asarray(triangle_vertices) - triangle_centers = triangle_vertices.mean(axis=-2, keepdims=True) - world_vertices = jnp.concat((triangle_vertices, triangle_centers), axis=-2).reshape( - *triangle_vertices.shape[:-3], -1, 3 - ) + + num_triangles = triangle_vertices.shape[-3] if active_triangles is not None: active_triangles = jnp.asarray(active_triangles) @@ -683,307 +1141,111 @@ def triangles_visible_from_vertices( else: active_vertices = None - # [*batch 3] - ray_origins = vertices - - # [*batch 2 3] - frustum = viewing_frustum( - ray_origins, - world_vertices, - active_vertices=active_vertices, + triangle_centers = triangle_vertices.mean(axis=-2, keepdims=True) + world_vertices = jnp.concat((triangle_vertices, triangle_centers), axis=-2).reshape( + *triangle_vertices.shape[:-3], -1, 3 ) - batch_size = num_rays if batch_size is None else min(batch_size, num_rays) - num_batches, rem = divmod(num_rays, batch_size) - - # [*batch num_rays 3] - ray_directions = jnp.vectorize( - lambda n, frustum: fibonacci_lattice(n, frustum=frustum), - excluded={0}, - signature="(2,3)->(n,3)", - )(num_rays, frustum) - - # Combine the batch dimensions batch = jnp.broadcast_shapes( - ray_origins.shape[:-2], - ray_directions.shape[:-2], + vertices.shape[:-1], triangle_vertices.shape[:-3], active_triangles.shape[:-1] if active_triangles is not None else (), ) - def update_visible_triangles( - visible_triangles: Bool[Array, "*#batch num_triangles"], - visible_indices: Int[Array, "*batch batch_size"], - ) -> Bool[Array, "*#batch num_triangles"]: - indices = jnp.indices(visible_triangles.shape, sparse=True) - indices = (*indices[:-1], visible_indices) - return visible_triangles.at[indices].set(True, wrap_negative_indices=False) - - def map_fn( - ray_origins: Float[Array, "*#batch 3"], - ray_directions: Float[Array, "*#batch batch_size 3"], - triangle_vertices: Float[Array, "*#batch num_triangles 3 3"], - active_triangles: Bool[Array, "*#batch num_triangles"] | None = None, - ) -> Int[Array, "*batch batch_size"]: - indices, _ = first_triangles_hit_by_rays( - ray_origins[..., None, :], - ray_directions, - triangle_vertices[..., None, :, :, :], - active_triangles=active_triangles[..., None, :] - if active_triangles is not None - else None, - batch_size=None, - **kwargs, - ) - return indices - - def body_fun( - batch_index: Int[Array, ""], - visible_triangles: Bool[Array, "*batch num_triangles"], - ) -> Bool[Array, "*batch num_triangles"]: - start_index = batch_index * batch_size - batch_of_ray_directions = jax.lax.dynamic_slice_in_dim( - ray_directions, start_index, batch_size, axis=-2 - ) - visible_indices = map_fn( - ray_origins, - batch_of_ray_directions, - triangle_vertices, - active_triangles, - ) - return update_visible_triangles(visible_triangles, visible_indices) - - init_val = jnp.zeros((*batch, triangle_vertices.shape[-3]), dtype=jnp.bool) - - visible_triangles = jax.lax.fori_loop( - 0, - num_batches, - body_fun, - init_val=init_val, + ray_batch_size = ( + ray_batch_size + if ray_batch_size is not None + else batch_size + if batch_size is not None + else num_rays ) - - if rem > 0: - visible_indices = map_fn( - ray_origins, - ray_directions[..., -rem:, :], - triangle_vertices, - active_triangles, - ) - return update_visible_triangles(visible_triangles, visible_indices) - return visible_triangles - - -@eqx.filter_jit -def first_triangles_hit_by_rays( - ray_origins: Float[ArrayLike, "*#batch 3"], - ray_directions: Float[ArrayLike, "*#batch 3"], - triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], - active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, - batch_size: int | None = 512, - **kwargs: Any, -) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"]]: - """ - Return the first triangle hit by each ray. - - This function should be used when allocating an array of size - ``*batch num_triangles 3`` (or bigger) is not possible, and you are only interested in - getting the first triangle hit by the ray. - - If two or more triangles are hit at the same distance, the one with the closest center to the ray origin is selected. Two triangles are considered to be hit at the same distance if their distances differ by less than ``100 * eps``, or ten times the ``epsilon`` keyword argument passed to :func:`rays_intersect_triangles`. - - Args: - ray_origins: An array of origin vertices. - ray_directions: An array of ray direction. The ray ends - should be equal to ``ray_origins + ray_directions``. - triangle_vertices: An array of triangle vertices. - active_triangles: An optional array of boolean values indicating - which triangles are active, i.e., should be considered for intersection. - - If not specified, all triangles are considered active. - batch_size: The number of triangles to process in a single batch. - This allows to make a trade-off between memory usage and performance. - - The batch size is automatically adjusted to be the minimum of the number of triangles - and the specified batch size. - - If :data:`None`, the batch size is set to the number of triangles. - kwargs: Keyword arguments passed to - :func:`rays_intersect_triangles`. - - Returns: - For each ray, return the index and to distance to the first triangle hit. - - If no triangle is hit, the index is set to ``-1`` and - the distance is set to :data:`inf`. - """ - ray_origins = jnp.asarray(ray_origins) - ray_directions = jnp.asarray(ray_directions) - triangle_vertices = jnp.asarray(triangle_vertices) - - if epsilon := kwargs.get("epsilon"): - epsilon = 10 * jnp.asarray(epsilon) - else: - dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) - epsilon = jnp.asarray(100 * jnp.finfo(dtype).eps) - - num_triangles = triangle_vertices.shape[-3] - if batch_size is None: - batch_size = num_triangles - batch_size = max(min(batch_size, num_triangles), 1) - num_batches, rem = divmod(num_triangles, batch_size) - - if active_triangles is not None: - active_triangles = jnp.asarray(active_triangles) - - # Combine the batch dimensions - batch = jnp.broadcast_shapes( - ray_origins.shape[:-1], - ray_directions.shape[:-1], - triangle_vertices.shape[:-3], - active_triangles.shape[:-1] if active_triangles is not None else (), + ray_batch_size = max(min(ray_batch_size, num_rays), 1) + tri_batch_size = ( + tri_batch_size + if tri_batch_size is not None + else batch_size + if batch_size is not None + else num_triangles ) + tri_batch_size = max(min(tri_batch_size, num_triangles), 1) if num_triangles == 0: - # If there are no triangles, there are no hits - return ( - jnp.full(batch, -1, dtype=jnp.int32), - jnp.full( - batch, - jnp.inf, - dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), - ), - ) - - def reduce_fn( - left: tuple[ - Int[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *#batch"], - ], - right: tuple[ - Int[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *#batch"], - ], - ) -> tuple[ - Int[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *#batch"], - ]: - left_indices, left_t, left_center_distances, eps = left - right_indices, right_t, right_center_distances, _ = right - cond: Array = jnp.where( - jnp.abs(left_t - right_t) < eps, - left_center_distances < right_center_distances, - left_t < right_t, - ) - t = jnp.where(cond, left_t, right_t) - indices = jnp.where(cond, left_indices, right_indices) - t = jnp.minimum(left_t, right_t) - center_distances = jnp.where( - cond, left_center_distances, right_center_distances - ) - is_finite = jnp.isfinite(t) - indices = jnp.where(is_finite, indices, -1) - t = jnp.where(is_finite, t, jnp.inf) - return indices, t, center_distances, eps - - def map_fn( - ray_origins: Float[Array, "*#batch 3"], - ray_directions: Float[Array, "*#batch 3"], - triangle_vertices: Float[Array, "*#batch num_triangles 3 3"], - active_triangles: Bool[Array, "*#batch num_triangles"] | None = None, - ) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"]]: - t, hit = rays_intersect_triangles( - ray_origins[..., None, :], - ray_directions[..., None, :], - triangle_vertices, - **kwargs, - ) - if active_triangles is not None: - hit &= active_triangles - t = jnp.where(hit, t, jnp.inf) - indices = jnp.arange(triangle_vertices.shape[-3]) - indices = jnp.broadcast_to(indices, t.shape) - center_distances = jnp.linalg.norm( - triangle_vertices.mean(axis=-2) - ray_origins[..., None, :], axis=-1 - ) - center_distances = jnp.broadcast_to(center_distances, t.shape) - eps = jnp.broadcast_to(epsilon, t.shape) - return jax.lax.reduce( - (indices, t, center_distances, eps), - (-1, jnp.inf, jnp.inf, epsilon), - reduce_fn, - dimensions=(t.ndim - 1,), - )[:3] - - def body_fun( - batch_index: Int[Array, ""], - carry: tuple[ - Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"] - ], - ) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"]]: - start_index = batch_index * batch_size - batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim( - triangle_vertices, start_index, batch_size, axis=-3 + return jnp.zeros((*batch, 0), dtype=jnp.bool_) + + xs = [] + argnames = [] + map_fn = partial( + _triangles_visible_from_vertex, + num_rays=num_rays, + num_triangles=num_triangles, + ray_batch_size=ray_batch_size, + tri_batch_size=tri_batch_size, + **kwargs, + ) + if math.prod(vertices.shape[:-1]) > 1: + xs.append(vertices.reshape(-1, 3)) + argnames.append("vertex") + else: + map_fn = partial(map_fn, vertex=vertices.reshape(3)) + if math.prod(triangle_vertices.shape[:-3]) > 1: + xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) + argnames.append("triangle_vertices") + else: + map_fn = partial( + map_fn, triangle_vertices=triangle_vertices.reshape(num_triangles, 3, 3) ) - batch_of_active_triangles = ( - jax.lax.dynamic_slice_in_dim( - active_triangles, start_index, batch_size, axis=-1 - ) + if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: + xs.append(active_triangles.reshape(-1, num_triangles)) + argnames.append("active_triangles") + else: + map_fn = partial( + map_fn, + active_triangles=active_triangles.reshape(num_triangles) if active_triangles is not None - else None + else None, ) - indices, t, center_distances = map_fn( - ray_origins, - ray_directions, - batch_of_triangle_vertices, - batch_of_active_triangles, + + # Also need to handle world_vertices and active_vertices broadcasting + if math.prod(world_vertices.shape[:-2]) > 1: + xs.append(world_vertices.reshape(-1, *world_vertices.shape[-2:])) + argnames.append("world_vertices") + else: + map_fn = partial(map_fn, world_vertices=world_vertices.reshape(-1, 3)) + + if active_vertices is not None and math.prod(active_vertices.shape[:-1]) > 1: + xs.append(active_vertices.reshape(-1, active_vertices.shape[-1])) + argnames.append("active_vertices") + else: + map_fn = partial( + map_fn, + active_vertices=active_vertices.reshape(-1) + if active_vertices is not None + else None, ) - # TODO: use *carry when ty supports starred expressions - return reduce_fn( - (carry[0], carry[1], carry[2], epsilon), - (indices + start_index, t, center_distances, epsilon), - )[:3] - init_val = ( - -jnp.ones(batch, dtype=jnp.int32), - jnp.full( - batch, - jnp.inf, - dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), - ), - jnp.full( - batch, - jnp.inf, - dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), - ), - ) + def f(args: tuple[Array, ...]) -> Bool[Array, "num_triangles"]: + return map_fn(**dict(zip(argnames, args, strict=True))) - indices, t, center_distances = jax.lax.fori_loop( - 0, - num_batches, - body_fun, - init_val=init_val, - ) + if not xs: + # Case where everything is shared or num_vertices == 1 + return _triangles_visible_from_vertex( + vertices.reshape(3), + triangle_vertices.reshape(num_triangles, 3, 3), + active_triangles.reshape(num_triangles) + if active_triangles is not None + else None, + num_rays=num_rays, + num_triangles=num_triangles, + ray_batch_size=ray_batch_size, + tri_batch_size=tri_batch_size, + world_vertices=world_vertices.reshape(-1, 3), + active_vertices=active_vertices.reshape(-1) + if active_vertices is not None + else None, + **kwargs, + ).reshape((*batch, num_triangles)) - if rem > 0: - rem_indices, rem_t, rem_center_distances = map_fn( - ray_origins, - ray_directions, - triangle_vertices[..., -rem:, :, :], - active_triangles[..., -rem:] if active_triangles is not None else None, - ) - return reduce_fn( - (indices, t, center_distances, epsilon), - ( - rem_indices + num_batches * batch_size, - rem_t, - rem_center_distances, - epsilon, - ), - )[:2] - return (indices, t) + return jax.lax.map(f, tuple(xs), batch_size=batch_size).reshape(( + *batch, + num_triangles, + )) diff --git a/differt/src/differt/scene/_triangle_scene.py b/differt/src/differt/scene/_triangle_scene.py index ad6ebea6..7169bdcf 100644 --- a/differt/src/differt/scene/_triangle_scene.py +++ b/differt/src/differt/scene/_triangle_scene.py @@ -1031,7 +1031,7 @@ def compute_paths( max_dist: Float[ArrayLike, ""] = 1e-3, smoothing_factor: Float[ArrayLike, ""] | None = None, confidence_threshold: Float[ArrayLike, ""] = 0.5, - batch_size: int | None = 512, + batch_size: int | None = 1024, disconnect_inactive_triangles: bool = False, ) -> Paths[_M] | SizedIterator[Paths[_M]] | Iterator[Paths[_M]] | SBRPaths: """ diff --git a/differt/tests/rt/test_utils.py b/differt/tests/rt/test_utils.py index 2e5723b5..3576c624 100644 --- a/differt/tests/rt/test_utils.py +++ b/differt/tests/rt/test_utils.py @@ -571,4 +571,6 @@ def test_first_triangles_hit_by_rays( # TODO: fixme, we need to fix the index if two or more triangles are hit at the same t # chex.assert_trees_all_equal(got_indices, expected_indices) - chex.assert_trees_all_close(got_t, expected_t, rtol=1e-5) + dtype = got_t.dtype + dist_tol = 100 * jnp.finfo(dtype).eps + chex.assert_trees_all_close(got_t, expected_t, rtol=1e-5, atol=dist_tol)