From a772402fb33abfe75a3980abe73c3faa4aec6d33 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 10:36:50 +0200 Subject: [PATCH 01/13] feat(lib): implement double-chunked ray-triangle tests for better perf. --- differt/src/differt/rt/_utils.py | 612 ++++++++++++++++++++----------- 1 file changed, 401 insertions(+), 211 deletions(-) diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index ad596846..718e9ba2 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -1,3 +1,4 @@ +import math import typing from collections.abc import Callable, Iterator, Sized from functools import cache @@ -351,6 +352,8 @@ def rays_intersect_any_triangle( hit_tol: Float[ArrayLike, ""] | None = ..., smoothing_factor: None = ..., batch_size: int | None = ..., + ray_batch_size: int | None = ..., + tri_batch_size: int | None = ..., **kwargs: Any, ) -> Bool[Array, " *batch"]: ... @@ -365,6 +368,7 @@ def rays_intersect_any_triangle( hit_tol: Float[ArrayLike, ""] | None = ..., smoothing_factor: Float[ArrayLike, ""], batch_size: int | None = ..., + ray_batch_size: int | None = ..., **kwargs: Any, ) -> Float[Array, " *batch"]: ... @@ -378,7 +382,9 @@ def rays_intersect_any_triangle( *, hit_tol: Float[ArrayLike, ""] | None = None, smoothing_factor: Float[ArrayLike, ""] | None = None, - batch_size: int | None = 512, + batch_size: int | None = 1024, + ray_batch_size: int | None = None, + tri_batch_size: int | None = None, **kwargs: Any, ) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: """ @@ -414,13 +420,24 @@ def rays_intersect_any_triangle( between 0 (:data:`False`) and 1 (:data:`True`). For more details, refer to :ref:`smoothing`. - batch_size: The number of triangles to process in a single batch. + batch_size: The default batch size used when either the ray or triangle batch size is + not specified. This allows to make a trade-off between memory usage and performance. + + If :data:`None`, a heuristic based on the input sizes is used. + ray_batch_size: The number of rays to process in a single batch. + This allows to make a trade-off between memory usage and performance. + + The ray batch size is automatically adjusted to be the minimum of the number of rays + and the specified ray batch size. + + If :data:`None`, the ray batch size is set to the number of rays. + tri_batch_size: The number of triangles to process in a single batch. This allows to make a trade-off between memory usage and performance. - The batch size is automatically adjusted to be the minimum of the number of triangles - and the specified batch size. + The triangle batch size is automatically adjusted to be the minimum of the number of + triangles and the specified triangle batch size. - If :data:`None`, the batch size is set to the number of triangles. + If :data:`None`, the triangle batch size defaults to :data:`batch_size`. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. @@ -438,15 +455,10 @@ def rays_intersect_any_triangle( hit_threshold = 1.0 - jnp.asarray(hit_tol) num_triangles = triangle_vertices.shape[-3] - if batch_size is None: - batch_size = num_triangles - batch_size = max(min(batch_size, num_triangles), 1) - num_batches, rem = divmod(num_triangles, batch_size) if active_triangles is not None: active_triangles = jnp.asarray(active_triangles) - # Combine the batch dimensions batch = jnp.broadcast_shapes( ray_origins.shape[:-1], ray_directions.shape[:-1], @@ -454,6 +466,18 @@ def rays_intersect_any_triangle( active_triangles.shape[:-1] if active_triangles is not None else (), ) + num_rays = math.prod(batch) + # If user passed None explicitly, process everything in a single batch + if batch_size is None: + batch_size = max(num_rays, num_triangles) + if ray_batch_size is None: + ray_batch_size = batch_size + if tri_batch_size is None: + tri_batch_size = batch_size + + ray_chunk_size = max(min(ray_batch_size, num_rays), 1) + tri_chunk_size = max(min(tri_batch_size, num_triangles), 1) + if num_triangles == 0: # If there are no triangles, there are no intersections return ( @@ -465,84 +489,122 @@ def rays_intersect_any_triangle( else jnp.zeros(batch, dtype=bool) ) - def map_fn( - ray_origins: Float[Array, "*#batch 3"], - ray_directions: Float[Array, "*#batch 3"], - triangle_vertices: Float[Array, "*#batch num_triangles 3 3"], - active_triangles: Bool[Array, "*#batch num_triangles"] | None = None, - ) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: - t, hit = rays_intersect_triangles( - ray_origins[..., None, :], - ray_directions[..., None, :], - triangle_vertices, - smoothing_factor=smoothing_factor, - **kwargs, + ray_origins = jnp.broadcast_to(ray_origins, (*batch, 3)).reshape(-1, 3) + ray_directions = jnp.broadcast_to(ray_directions, (*batch, 3)).reshape(-1, 3) + + pad_rays_len = (ray_chunk_size - (num_rays % ray_chunk_size)) % ray_chunk_size + if pad_rays_len > 0: + ray_origins = jnp.pad(ray_origins, ((0, pad_rays_len), (0, 0))) + ray_directions = jnp.pad(ray_directions, ((0, pad_rays_len), (0, 0))) + + num_rays_chunks = ray_origins.shape[0] // ray_chunk_size + blocked_ro = ray_origins.reshape(num_rays_chunks, ray_chunk_size, 3) + blocked_rd = ray_directions.reshape(num_rays_chunks, ray_chunk_size, 3) + + pad_tris_len = (tri_chunk_size - (num_triangles % tri_chunk_size)) % tri_chunk_size + + has_batch_tris = triangle_vertices.shape[:-3] != () + if has_batch_tris: + triangle_vertices = jnp.broadcast_to( + triangle_vertices, (*batch, num_triangles, 3, 3) + ).reshape(-1, num_triangles, 3, 3) + if pad_rays_len > 0 or pad_tris_len > 0: + triangle_vertices = jnp.pad( + triangle_vertices, + ((0, pad_rays_len), (0, pad_tris_len), (0, 0), (0, 0)), + ) + blocked_tris = triangle_vertices.reshape( + num_rays_chunks, ray_chunk_size, -1, tri_chunk_size, 3, 3 ) - if smoothing_factor is not None: - return jnp.minimum( - hit, smoothing_function(hit_threshold - t, smoothing_factor) - ).sum(axis=-1, where=active_triangles) - return ((t < hit_threshold) & hit).any(axis=-1, where=active_triangles) - - def reduce_fn( - left: Bool[Array, " *batch"] | Float[Array, " *batch"], - right: Bool[Array, " *batch"] | Float[Array, " *batch"], - ) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: - if smoothing_factor is not None: - return (left + right).clip(max=1.0) - return left | right + else: + if pad_tris_len > 0: + triangle_vertices = jnp.pad( + triangle_vertices, ((0, pad_tris_len), (0, 0), (0, 0)) + ) + blocked_tris = triangle_vertices.reshape(-1, tri_chunk_size, 3, 3) - def body_fun( - batch_index: Int[Array, ""], - intersect: Bool[Array, " *batch"] | Float[Array, " *batch"], - ) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: - start_index = batch_index * batch_size - batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim( - triangle_vertices, start_index, batch_size, axis=-3 - ) - batch_of_active_triangles = ( - jax.lax.dynamic_slice_in_dim( - active_triangles, start_index, batch_size, axis=-1 + if active_triangles is not None: + has_batch_active = active_triangles.shape[:-1] != () + if has_batch_active: + active_triangles = jnp.broadcast_to( + active_triangles, (*batch, num_triangles) + ).reshape(-1, num_triangles) + if pad_rays_len > 0 or pad_tris_len > 0: + active_triangles = jnp.pad( + active_triangles, ((0, pad_rays_len), (0, pad_tris_len)) + ) + blocked_active = active_triangles.reshape( + num_rays_chunks, ray_chunk_size, -1, tri_chunk_size ) - if active_triangles is not None - else None - ) - return reduce_fn( - intersect, - map_fn( - ray_origins=ray_origins, - ray_directions=ray_directions, - triangle_vertices=batch_of_triangle_vertices, - active_triangles=batch_of_active_triangles, - ), + else: + if pad_tris_len > 0: + active_triangles = jnp.pad(active_triangles, ((0, pad_tris_len),)) + blocked_active = active_triangles.reshape(-1, tri_chunk_size) + else: + has_batch_active = False + blocked_active = None + + xs_rays = [blocked_ro, blocked_rd] + if has_batch_tris: + xs_rays.append(blocked_tris) + if active_triangles is not None and has_batch_active: + xs_rays.append(blocked_active) + + def scan_rays(carry_rays, ray_chunk): + ro_block = ray_chunk[0] + rd_block = ray_chunk[1] + + idx = 2 + if has_batch_tris: + tris_block_batch = jnp.swapaxes(ray_chunk[idx], 0, 1) + idx += 1 + else: + tris_block_batch = blocked_tris + + if active_triangles is not None and has_batch_active: + active_block_batch = jnp.swapaxes(ray_chunk[idx], 0, 1) + elif active_triangles is not None: + active_block_batch = blocked_active + else: + active_block_batch = None + + xs_tris = [tris_block_batch] + if active_block_batch is not None: + xs_tris.append(active_block_batch) + + def scan_tris(carry_tris, tris_chunk): + tris_block = tris_chunk[0] + active_block = tris_chunk[1] if len(tris_chunk) > 1 else None + + t, hit = rays_intersect_triangles( + ro_block[..., None, :], + rd_block[..., None, :], + tris_block, + smoothing_factor=smoothing_factor, + **kwargs, + ) + + if smoothing_factor is not None: + block_hits = jnp.minimum( + hit, smoothing_function(hit_threshold - t, smoothing_factor) + ).sum(axis=-1, where=active_block) + return (carry_tris + block_hits).clip(max=1.0), None + block_hits = ((t < hit_threshold) & hit).any(axis=-1, where=active_block) + return carry_tris | block_hits, None + + init_val = ( + jnp.zeros(ray_chunk_size) + if smoothing_factor is not None + else jnp.zeros(ray_chunk_size, dtype=bool) ) - init_val = ( - jnp.zeros(batch) - if smoothing_factor is not None - else jnp.zeros(batch, dtype=jnp.bool) - ) + hits_for_chunk, _ = jax.lax.scan(scan_tris, init=init_val, xs=tuple(xs_tris)) + return carry_rays, hits_for_chunk - intersect = jax.lax.fori_loop( - 0, - num_batches, - body_fun, - init_val=init_val, - ) + _, all_hits = jax.lax.scan(scan_rays, init=None, xs=tuple(xs_rays)) - if rem > 0: - return reduce_fn( - intersect, - map_fn( - ray_origins=ray_origins, - ray_directions=ray_directions, - triangle_vertices=triangle_vertices[..., -rem:, :, :], - active_triangles=active_triangles[..., -rem:] - if active_triangles is not None - else None, - ), - ) - return intersect + hits = all_hits.reshape(-1) + return hits[:num_rays].reshape(batch) @eqx.filter_jit @@ -551,7 +613,9 @@ def triangles_visible_from_vertices( triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, num_rays: int = int(1e6), - batch_size: int | None = 512, + batch_size: int | None = 1024, + ray_batch_size: int | None = None, + tri_batch_size: int | None = None, **kwargs: Any, ) -> Bool[Array, "*batch num_triangles"]: """ @@ -578,13 +642,24 @@ def triangles_visible_from_vertices( num_rays: The number of rays to launch. The larger, the more accurate. - batch_size: The number of rays to process in a single batch. + batch_size: The default batch size used when either the ray or triangle batch size is + not specified. This allows to make a trade-off between memory usage and performance. + + If :data:`None`, a heuristic based on the input sizes is used. + ray_batch_size: The number of rays to process in a single batch. + This allows to make a trade-off between memory usage and performance. + + The ray batch size is automatically adjusted to be the minimum of the number of rays + and the specified ray batch size. + + If :data:`None`, the ray batch size is set to the number of rays. + tri_batch_size: The number of triangles to process in a single batch. This allows to make a trade-off between memory usage and performance. - The batch size is automatically adjusted to be the minimum of the number of rays - and the specified batch size. + The triangle batch size is automatically adjusted to be the minimum of the number of + triangles and the specified triangle batch size. - If :data:`None`, the batch size is set to the number of rays. + If :data:`None`, the triangle batch size defaults to :data:`batch_size`. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. @@ -693,9 +768,6 @@ def triangles_visible_from_vertices( active_vertices=active_vertices, ) - batch_size = num_rays if batch_size is None else min(batch_size, num_rays) - num_batches, rem = divmod(num_rays, batch_size) - # [*batch num_rays 3] ray_directions = jnp.vectorize( lambda n, frustum: fibonacci_lattice(n, frustum=frustum), @@ -711,65 +783,72 @@ def triangles_visible_from_vertices( active_triangles.shape[:-1] if active_triangles is not None else (), ) + num_triangles = triangle_vertices.shape[-3] + # None means process everything in a single batch + if batch_size is None: + batch_size = max(num_rays, num_triangles) + if ray_batch_size is None: + ray_batch_size = batch_size + if tri_batch_size is None: + tri_batch_size = batch_size + ray_batch_size = max(min(ray_batch_size, num_rays), 1) + tri_batch_size = max(min(tri_batch_size, num_triangles), 1) + + num_ray_batches, rem_rays = divmod(num_rays, ray_batch_size) + def update_visible_triangles( visible_triangles: Bool[Array, "*#batch num_triangles"], - visible_indices: Int[Array, "*batch batch_size"], + ray_directions_batch: Float[Array, "*#batch batch_rays 3"], ) -> Bool[Array, "*#batch num_triangles"]: - indices = jnp.indices(visible_triangles.shape, sparse=True) - indices = (*indices[:-1], visible_indices) - return visible_triangles.at[indices].set(True, wrap_negative_indices=False) - - def map_fn( - ray_origins: Float[Array, "*#batch 3"], - ray_directions: Float[Array, "*#batch batch_size 3"], - triangle_vertices: Float[Array, "*#batch num_triangles 3 3"], - active_triangles: Bool[Array, "*#batch num_triangles"] | None = None, - ) -> Int[Array, "*batch batch_size"]: + """Check which triangles are visible from rays in this batch.""" indices, _ = first_triangles_hit_by_rays( ray_origins[..., None, :], - ray_directions, + ray_directions_batch, triangle_vertices[..., None, :, :, :], active_triangles=active_triangles[..., None, :] if active_triangles is not None else None, - batch_size=None, + batch_size=tri_batch_size, + ray_batch_size=ray_directions_batch.shape[-2], + tri_batch_size=tri_batch_size, **kwargs, ) - return indices + # indices: [*batch ray_batch_size], value >= 0 means triangle index was hit + # Convert to per-triangle: for each triangle, check if any ray hit it + indices_expanded = indices[..., None] # [*batch ray_batch_size 1] + triangles_range = jnp.arange(num_triangles) # [num_triangles] + indices_one_hot = (indices_expanded == triangles_range) & ( + indices_expanded >= 0 + ) + # Reduce over rays: [*batch num_triangles] + hit_any_ray = jnp.any(indices_one_hot, axis=-2) + return visible_triangles | hit_any_ray def body_fun( batch_index: Int[Array, ""], visible_triangles: Bool[Array, "*batch num_triangles"], ) -> Bool[Array, "*batch num_triangles"]: - start_index = batch_index * batch_size + start_index = batch_index * ray_batch_size batch_of_ray_directions = jax.lax.dynamic_slice_in_dim( - ray_directions, start_index, batch_size, axis=-2 - ) - visible_indices = map_fn( - ray_origins, - batch_of_ray_directions, - triangle_vertices, - active_triangles, + ray_directions, start_index, ray_batch_size, axis=-2 ) - return update_visible_triangles(visible_triangles, visible_indices) + return update_visible_triangles(visible_triangles, batch_of_ray_directions) - init_val = jnp.zeros((*batch, triangle_vertices.shape[-3]), dtype=jnp.bool) + init_val = jnp.zeros((*batch, num_triangles), dtype=jnp.bool_) visible_triangles = jax.lax.fori_loop( 0, - num_batches, + num_ray_batches, body_fun, init_val=init_val, ) - if rem > 0: - visible_indices = map_fn( - ray_origins, - ray_directions[..., -rem:, :], - triangle_vertices, - active_triangles, + if rem_rays > 0: + batch_of_ray_directions = ray_directions[..., -rem_rays:, :] + visible_triangles = update_visible_triangles( + visible_triangles, batch_of_ray_directions ) - return update_visible_triangles(visible_triangles, visible_indices) + return visible_triangles @@ -779,7 +858,9 @@ def first_triangles_hit_by_rays( ray_directions: Float[ArrayLike, "*#batch 3"], triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, - batch_size: int | None = 512, + batch_size: int | None = 1024, + ray_batch_size: int | None = None, + tri_batch_size: int | None = None, **kwargs: Any, ) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"]]: """ @@ -800,13 +881,18 @@ def first_triangles_hit_by_rays( which triangles are active, i.e., should be considered for intersection. If not specified, all triangles are considered active. - batch_size: The number of triangles to process in a single batch. - This allows to make a trade-off between memory usage and performance. + batch_size: The default batch size used when either the ray or triangle batch size is + not specified. This allows to make a trade-off between memory usage and performance. + + If :data:`None`, a heuristic based on the input sizes is used. + ray_batch_size: The number of rays to process in a single batch. + This allows to chunk rays and reduce peak memory usage. - The batch size is automatically adjusted to be the minimum of the number of triangles - and the specified batch size. + If :data:`None`, all rays are processed together. + tri_batch_size: The number of triangles to process in a single batch. + This allows to chunk triangles and reduce peak memory usage. - If :data:`None`, the batch size is set to the number of triangles. + If :data:`None`, the triangle batch size defaults to :data:`batch_size`. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. @@ -827,10 +913,6 @@ def first_triangles_hit_by_rays( epsilon = jnp.asarray(100 * jnp.finfo(dtype).eps) num_triangles = triangle_vertices.shape[-3] - if batch_size is None: - batch_size = num_triangles - batch_size = max(min(batch_size, num_triangles), 1) - num_batches, rem = divmod(num_triangles, batch_size) if active_triangles is not None: active_triangles = jnp.asarray(active_triangles) @@ -843,6 +925,18 @@ def first_triangles_hit_by_rays( active_triangles.shape[:-1] if active_triangles is not None else (), ) + num_rays = math.prod(batch) + # None means process everything in a single batch + if batch_size is None: + batch_size = max(num_rays, num_triangles) + if ray_batch_size is None: + ray_batch_size = batch_size + if tri_batch_size is None: + tri_batch_size = batch_size + + ray_chunk_size = max(min(ray_batch_size, num_rays), 1) + tri_chunk_size = max(min(tri_batch_size, num_triangles), 1) + if num_triangles == 0: # If there are no triangles, there are no hits return ( @@ -891,99 +985,195 @@ def reduce_fn( t = jnp.where(is_finite, t, jnp.inf) return indices, t, center_distances, eps - def map_fn( - ray_origins: Float[Array, "*#batch 3"], - ray_directions: Float[Array, "*#batch 3"], - triangle_vertices: Float[Array, "*#batch num_triangles 3 3"], - active_triangles: Bool[Array, "*#batch num_triangles"] | None = None, - ) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"]]: - t, hit = rays_intersect_triangles( - ray_origins[..., None, :], - ray_directions[..., None, :], - triangle_vertices, - **kwargs, + def _process_ray_batch( + ray_origins_batch: Float[Array, "*#batch 3"], + ray_directions_batch: Float[Array, "*#batch 3"], + triangle_vertices_batch: Float[Array, "*#batch num_triangles 3 3"], + active_triangles_batch: Bool[Array, "*#batch num_triangles"] | None = None, + ) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"]]: + """Process one batch of rays through all triangles.""" + + def map_fn( + ray_origins: Float[Array, "*#batch 3"], + ray_directions: Float[Array, "*#batch 3"], + triangle_vertices: Float[Array, "*#batch num_triangles 3 3"], + active_triangles: Bool[Array, "*#batch num_triangles"] | None = None, + ) -> tuple[ + Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"] + ]: + t, hit = rays_intersect_triangles( + ray_origins[..., None, :], + ray_directions[..., None, :], + triangle_vertices, + **kwargs, + ) + if active_triangles is not None: + hit &= active_triangles + t = jnp.where(hit, t, jnp.inf) + indices = jnp.arange(triangle_vertices.shape[-3]) + indices = jnp.broadcast_to(indices, t.shape) + center_distances = jnp.linalg.norm( + triangle_vertices.mean(axis=-2) - ray_origins[..., None, :], axis=-1 + ) + center_distances = jnp.broadcast_to(center_distances, t.shape) + eps = jnp.broadcast_to(epsilon, t.shape) + return jax.lax.reduce( + (indices, t, center_distances, eps), + (-1, jnp.inf, jnp.inf, epsilon), + reduce_fn, + dimensions=(t.ndim - 1,), + )[:3] + + num_triangle_batches, rem_triangles = divmod(num_triangles, tri_chunk_size) + + def body_fun( + batch_index: Int[Array, ""], + carry: tuple[ + Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"] + ], + ) -> tuple[ + Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"] + ]: + start_index = batch_index * tri_chunk_size + batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim( + triangle_vertices_batch, start_index, tri_chunk_size, axis=-3 + ) + batch_of_active_triangles = ( + jax.lax.dynamic_slice_in_dim( + active_triangles_batch, start_index, tri_chunk_size, axis=-1 + ) + if active_triangles_batch is not None + else None + ) + indices, t, center_distances = map_fn( + ray_origins_batch, + ray_directions_batch, + batch_of_triangle_vertices, + batch_of_active_triangles, + ) + return reduce_fn( + (carry[0], carry[1], carry[2], epsilon), + (indices + start_index, t, center_distances, epsilon), + )[:3] + + init_val = ( + -jnp.ones(ray_origins_batch.shape[:-1], dtype=jnp.int32), + jnp.full( + ray_origins_batch.shape[:-1], + jnp.inf, + dtype=jnp.result_type( + ray_origins_batch, ray_directions_batch, triangle_vertices_batch + ), + ), + jnp.full( + ray_origins_batch.shape[:-1], + jnp.inf, + dtype=jnp.result_type( + ray_origins_batch, ray_directions_batch, triangle_vertices_batch + ), + ), ) - if active_triangles is not None: - hit &= active_triangles - t = jnp.where(hit, t, jnp.inf) - indices = jnp.arange(triangle_vertices.shape[-3]) - indices = jnp.broadcast_to(indices, t.shape) - center_distances = jnp.linalg.norm( - triangle_vertices.mean(axis=-2) - ray_origins[..., None, :], axis=-1 + + indices, t, center_distances = jax.lax.fori_loop( + 0, + num_triangle_batches, + body_fun, + init_val=init_val, ) - center_distances = jnp.broadcast_to(center_distances, t.shape) - eps = jnp.broadcast_to(epsilon, t.shape) - return jax.lax.reduce( - (indices, t, center_distances, eps), - (-1, jnp.inf, jnp.inf, epsilon), - reduce_fn, - dimensions=(t.ndim - 1,), - )[:3] - def body_fun( + if rem_triangles > 0: + rem_indices, rem_t, rem_center_distances = map_fn( + ray_origins_batch, + ray_directions_batch, + triangle_vertices_batch[..., -rem_triangles:, :, :], + active_triangles_batch[..., -rem_triangles:] + if active_triangles_batch is not None + else None, + ) + indices, t, _ = reduce_fn( + (indices, t, center_distances, epsilon), + ( + rem_indices + num_triangle_batches * tri_chunk_size, + rem_t, + rem_center_distances, + epsilon, + ), + )[:3] + + return (indices, t) + + ray_origins = jnp.broadcast_to(ray_origins, (*batch, 3)).reshape(-1, 3) + ray_directions = jnp.broadcast_to(ray_directions, (*batch, 3)).reshape(-1, 3) + triangle_vertices = jnp.broadcast_to( + triangle_vertices, (*batch, num_triangles, 3, 3) + ).reshape(-1, num_triangles, 3, 3) + if active_triangles is not None: + active_triangles = jnp.broadcast_to( + active_triangles, (*batch, num_triangles) + ).reshape(-1, num_triangles) + + num_ray_batches, rem_rays = divmod(num_rays, ray_chunk_size) + + init_indices = -jnp.ones(num_rays, dtype=jnp.int32) + init_t = jnp.full( + num_rays, + jnp.inf, + dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), + ) + + def ray_body_fun( batch_index: Int[Array, ""], - carry: tuple[ - Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"] - ], - ) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"]]: - start_index = batch_index * batch_size - batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim( - triangle_vertices, start_index, batch_size, axis=-3 + carry: tuple[Int[Array, "*"], Float[Array, "*"]], + ) -> tuple[Int[Array, "*"], Float[Array, "*"]]: + start_index = batch_index * ray_chunk_size + ray_origins_batch = jax.lax.dynamic_slice_in_dim( + ray_origins, start_index, ray_chunk_size, axis=0 ) - batch_of_active_triangles = ( + ray_directions_batch = jax.lax.dynamic_slice_in_dim( + ray_directions, start_index, ray_chunk_size, axis=0 + ) + triangle_vertices_batch = jax.lax.dynamic_slice_in_dim( + triangle_vertices, start_index, ray_chunk_size, axis=0 + ) + active_triangles_batch = ( jax.lax.dynamic_slice_in_dim( - active_triangles, start_index, batch_size, axis=-1 + active_triangles, start_index, ray_chunk_size, axis=0 ) if active_triangles is not None else None ) - indices, t, center_distances = map_fn( - ray_origins, - ray_directions, - batch_of_triangle_vertices, - batch_of_active_triangles, + indices_batch, t_batch = _process_ray_batch( + ray_origins_batch, + ray_directions_batch, + triangle_vertices_batch, + active_triangles_batch, ) - # TODO: use *carry when ty supports starred expressions - return reduce_fn( - (carry[0], carry[1], carry[2], epsilon), - (indices + start_index, t, center_distances, epsilon), - )[:3] - - init_val = ( - -jnp.ones(batch, dtype=jnp.int32), - jnp.full( - batch, - jnp.inf, - dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), - ), - jnp.full( - batch, - jnp.inf, - dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), - ), - ) + indices = jax.lax.dynamic_update_slice(carry[0], indices_batch, (start_index,)) + t = jax.lax.dynamic_update_slice(carry[1], t_batch, (start_index,)) + return (indices, t) - indices, t, center_distances = jax.lax.fori_loop( + indices, t = jax.lax.fori_loop( 0, - num_batches, - body_fun, - init_val=init_val, + num_ray_batches, + ray_body_fun, + init_val=(init_indices, init_t), ) - if rem > 0: - rem_indices, rem_t, rem_center_distances = map_fn( - ray_origins, - ray_directions, - triangle_vertices[..., -rem:, :, :], - active_triangles[..., -rem:] if active_triangles is not None else None, + if rem_rays > 0: + start_index = num_ray_batches * ray_chunk_size + ray_origins_batch = ray_origins[-rem_rays:, :] + ray_directions_batch = ray_directions[-rem_rays:, :] + triangle_vertices_batch = triangle_vertices[-rem_rays:, :, :, :] + active_triangles_batch = ( + active_triangles[-rem_rays:, :] if active_triangles is not None else None ) - return reduce_fn( - (indices, t, center_distances, epsilon), - ( - rem_indices + num_batches * batch_size, - rem_t, - rem_center_distances, - epsilon, - ), - )[:2] - return (indices, t) + rem_indices, rem_t = _process_ray_batch( + ray_origins_batch, + ray_directions_batch, + triangle_vertices_batch, + active_triangles_batch, + ) + indices = jax.lax.dynamic_update_slice(indices, rem_indices, (start_index,)) + t = jax.lax.dynamic_update_slice(t, rem_t, (start_index,)) + + return (indices.reshape(batch), t.reshape(batch)) From 6b072773d7aac04784d3c2b7f1e48d690fc250ff Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 10:56:20 +0200 Subject: [PATCH 02/13] fix(docs): update batch_size argument semantics in ray-triangle functions --- CHANGELOG.md | 4 ++++ differt/src/differt/rt/_utils.py | 35 +++++++++++++++++++------------- 2 files changed, 25 insertions(+), 14 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 5d8eb564..6e456606 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,6 +22,10 @@ with one *slight* but **important** difference: ## [Unreleased](https://github.com/jeertmans/DiffeRT/compare/v0.8.1...HEAD) +### Changed + +- Updated the `batch_size` argument semantics for {func}`rays_intersect_any_triangle`, {func}`triangles_visible_from_vertices`, and {func}`first_triangles_hit_by_rays`: `None` now means the explicit `ray_batch_size` and `tri_batch_size` values are used, while any other value is propagated to both batch-size arguments (by , in ). + ### Added - Improved Sionna-compatible XML scene parser to support top-level `` materials in addition to nested structures, enabling support for OSM buildings and other XML formats (by , in ). diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index 718e9ba2..669c1c10 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -420,24 +420,26 @@ def rays_intersect_any_triangle( between 0 (:data:`False`) and 1 (:data:`True`). For more details, refer to :ref:`smoothing`. - batch_size: The default batch size used when either the ray or triangle batch size is - not specified. This allows to make a trade-off between memory usage and performance. + batch_size: The default batch size used when either ``ray_batch_size`` or + ``tri_batch_size`` is not specified. This allows to make a trade-off between memory + usage and performance. - If :data:`None`, a heuristic based on the input sizes is used. + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are used. + Otherwise, they are both overwritten with this value. ray_batch_size: The number of rays to process in a single batch. This allows to make a trade-off between memory usage and performance. The ray batch size is automatically adjusted to be the minimum of the number of rays and the specified ray batch size. - If :data:`None`, the ray batch size is set to the number of rays. + If :data:`None`, it defaults to :data:`batch_size`. tri_batch_size: The number of triangles to process in a single batch. This allows to make a trade-off between memory usage and performance. The triangle batch size is automatically adjusted to be the minimum of the number of triangles and the specified triangle batch size. - If :data:`None`, the triangle batch size defaults to :data:`batch_size`. + If :data:`None`, it defaults to :data:`batch_size`. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. @@ -548,6 +550,7 @@ def rays_intersect_any_triangle( if has_batch_tris: xs_rays.append(blocked_tris) if active_triangles is not None and has_batch_active: + # TODO: fix type checking issue here, as blocked_active is inferred as Array | None, but we know it's not None in this branch... do we? xs_rays.append(blocked_active) def scan_rays(carry_rays, ray_chunk): @@ -642,24 +645,26 @@ def triangles_visible_from_vertices( num_rays: The number of rays to launch. The larger, the more accurate. - batch_size: The default batch size used when either the ray or triangle batch size is - not specified. This allows to make a trade-off between memory usage and performance. + batch_size: The default batch size used when either ``ray_batch_size`` or + ``tri_batch_size`` is not specified. This allows to make a trade-off between memory + usage and performance. - If :data:`None`, a heuristic based on the input sizes is used. + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are used. + Otherwise, they are both overwritten with this value. ray_batch_size: The number of rays to process in a single batch. This allows to make a trade-off between memory usage and performance. The ray batch size is automatically adjusted to be the minimum of the number of rays and the specified ray batch size. - If :data:`None`, the ray batch size is set to the number of rays. + If :data:`None`, it defaults to ``batch_size``. tri_batch_size: The number of triangles to process in a single batch. This allows to make a trade-off between memory usage and performance. The triangle batch size is automatically adjusted to be the minimum of the number of triangles and the specified triangle batch size. - If :data:`None`, the triangle batch size defaults to :data:`batch_size`. + If :data:`None`, it defaults to :data:`batch_size`. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. @@ -881,14 +886,16 @@ def first_triangles_hit_by_rays( which triangles are active, i.e., should be considered for intersection. If not specified, all triangles are considered active. - batch_size: The default batch size used when either the ray or triangle batch size is - not specified. This allows to make a trade-off between memory usage and performance. + batch_size: The default batch size used when either ``ray_batch_size`` or + ``tri_batch_size`` is not specified. This allows to make a trade-off between memory + usage and performance. - If :data:`None`, a heuristic based on the input sizes is used. + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are used. + Otherwise, they are both overwritten with this value. ray_batch_size: The number of rays to process in a single batch. This allows to chunk rays and reduce peak memory usage. - If :data:`None`, all rays are processed together. + If :data:`None`, it defaults to ``batch_size``. tri_batch_size: The number of triangles to process in a single batch. This allows to chunk triangles and reduce peak memory usage. From 7b7e1971473db78b9793160d4da2b1d88dbcfb11 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 11:04:35 +0200 Subject: [PATCH 03/13] chore(docs): update changelog --- CHANGELOG.md | 13 +++++++++---- 1 file changed, 9 insertions(+), 4 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 6e456606..1167df65 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -22,19 +22,24 @@ with one *slight* but **important** difference: ## [Unreleased](https://github.com/jeertmans/DiffeRT/compare/v0.8.1...HEAD) -### Changed - -- Updated the `batch_size` argument semantics for {func}`rays_intersect_any_triangle`, {func}`triangles_visible_from_vertices`, and {func}`first_triangles_hit_by_rays`: `None` now means the explicit `ray_batch_size` and `tri_batch_size` values are used, while any other value is propagated to both batch-size arguments (by , in ). - ### Added - Improved Sionna-compatible XML scene parser to support top-level `` materials in addition to nested structures, enabling support for OSM buildings and other XML formats (by , in ). - Added fallback to black color `[0.0, 0.0, 0.0]` when material `` elements are missing, with appropriate warnings logged (by , in ). +- Added the `ray_batch_size` and `tri_batch_size` arguments to {func}`rays_intersect_any_triangle`, {func}`triangles_visible_from_vertices`, and {func}`first_triangles_hit_by_rays` to let users control ray and triangle chunk sizes independently (by , in ). + +### Changed + +- Updated the `batch_size` argument semantics for {func}`rays_intersect_any_triangle`, {func}`triangles_visible_from_vertices`, and {func}`first_triangles_hit_by_rays`: `None` now means the explicit `ray_batch_size` and `tri_batch_size` values are used, while any other value is propagated to both batch-size arguments (by , in ). ### Chore - Added tests for the improved Sionna-compatible XML scene parser using OSM building data, ensuring correct parsing of materials and colors (by , in ). +### Perf + +- Reworked the ray-triangle intersection helpers to use double-chunk processing over rays and triangles, reducing peak memory usage while keeping performance high on CPU and GPU (by , in ). + ## [0.8.1](https://github.com/jeertmans/DiffeRT/compare/v0.8.0...v0.8.1) ### Changed From fa073382aed15fe32690f3336fcccc46e6217226 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 11:28:56 +0200 Subject: [PATCH 04/13] fix(lib): apply review suggestions --- CHANGELOG.md | 2 +- differt/src/differt/rt/_utils.py | 64 ++++++++++++++++++++++---------- 2 files changed, 45 insertions(+), 21 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 1167df65..33a2522b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -30,7 +30,7 @@ with one *slight* but **important** difference: ### Changed -- Updated the `batch_size` argument semantics for {func}`rays_intersect_any_triangle`, {func}`triangles_visible_from_vertices`, and {func}`first_triangles_hit_by_rays`: `None` now means the explicit `ray_batch_size` and `tri_batch_size` values are used, while any other value is propagated to both batch-size arguments (by , in ). +- Updated the `batch_size` argument semantics for {func}`rays_intersect_any_triangle`, {func}`triangles_visible_from_vertices`, and {func}`first_triangles_hit_by_rays` so that `batch_size` acts as the default for whichever of `ray_batch_size` and `tri_batch_size` are left unspecified (by , in ). ### Chore diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index 669c1c10..bc688041 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -369,6 +369,7 @@ def rays_intersect_any_triangle( smoothing_factor: Float[ArrayLike, ""], batch_size: int | None = ..., ray_batch_size: int | None = ..., + tri_batch_size: int | None = ..., **kwargs: Any, ) -> Float[Array, " *batch"]: ... @@ -424,8 +425,9 @@ def rays_intersect_any_triangle( ``tri_batch_size`` is not specified. This allows to make a trade-off between memory usage and performance. - If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are used. - Otherwise, they are both overwritten with this value. + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are + used. Otherwise, this value is used as the default for whichever of + ``ray_batch_size`` and ``tri_batch_size`` are left unspecified. ray_batch_size: The number of rays to process in a single batch. This allows to make a trade-off between memory usage and performance. @@ -550,10 +552,12 @@ def rays_intersect_any_triangle( if has_batch_tris: xs_rays.append(blocked_tris) if active_triangles is not None and has_batch_active: - # TODO: fix type checking issue here, as blocked_active is inferred as Array | None, but we know it's not None in this branch... do we? - xs_rays.append(blocked_active) + xs_rays.append(typing.cast("Array", blocked_active)) - def scan_rays(carry_rays, ray_chunk): + def scan_rays( + carry_rays: Any, + ray_chunk: tuple[Array, ...], + ) -> tuple[Any, Array]: ro_block = ray_chunk[0] rd_block = ray_chunk[1] @@ -575,7 +579,10 @@ def scan_rays(carry_rays, ray_chunk): if active_block_batch is not None: xs_tris.append(active_block_batch) - def scan_tris(carry_tris, tris_chunk): + def scan_tris( + carry_tris: Any, + tris_chunk: tuple[Array, ...], + ) -> tuple[Any, None]: tris_block = tris_chunk[0] active_block = tris_chunk[1] if len(tris_chunk) > 1 else None @@ -649,8 +656,9 @@ def triangles_visible_from_vertices( ``tri_batch_size`` is not specified. This allows to make a trade-off between memory usage and performance. - If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are used. - Otherwise, they are both overwritten with this value. + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are + used. Otherwise, this value is used as the default for whichever of + ``ray_batch_size`` and ``tri_batch_size`` are left unspecified. ray_batch_size: The number of rays to process in a single batch. This allows to make a trade-off between memory usage and performance. @@ -799,13 +807,16 @@ def triangles_visible_from_vertices( ray_batch_size = max(min(ray_batch_size, num_rays), 1) tri_batch_size = max(min(tri_batch_size, num_triangles), 1) + if num_triangles == 0: + return jnp.zeros((*batch, 0), dtype=jnp.bool_) + num_ray_batches, rem_rays = divmod(num_rays, ray_batch_size) def update_visible_triangles( visible_triangles: Bool[Array, "*#batch num_triangles"], ray_directions_batch: Float[Array, "*#batch batch_rays 3"], ) -> Bool[Array, "*#batch num_triangles"]: - """Check which triangles are visible from rays in this batch.""" + # Check which triangles are visible from rays in this batch. indices, _ = first_triangles_hit_by_rays( ray_origins[..., None, :], ray_directions_batch, @@ -819,14 +830,26 @@ def update_visible_triangles( **kwargs, ) # indices: [*batch ray_batch_size], value >= 0 means triangle index was hit - # Convert to per-triangle: for each triangle, check if any ray hit it - indices_expanded = indices[..., None] # [*batch ray_batch_size 1] - triangles_range = jnp.arange(num_triangles) # [num_triangles] - indices_one_hot = (indices_expanded == triangles_range) & ( - indices_expanded >= 0 - ) - # Reduce over rays: [*batch num_triangles] - hit_any_ray = jnp.any(indices_one_hot, axis=-2) + # Convert to per-triangle using a bincount-based reduction to avoid materializing a + # [*batch ray_batch_size num_triangles] one-hot tensor. + valid_hits = indices >= 0 + safe_indices = jnp.where(valid_hits, indices, 0) + + flat_safe_indices = safe_indices.reshape(-1, safe_indices.shape[-1]) + flat_valid_hits = valid_hits.reshape(-1, valid_hits.shape[-1]) + + def count_hits(row_indices: Array, row_valid_hits: Array) -> Array: + return ( + jnp.bincount( + row_indices, + weights=row_valid_hits.astype(jnp.int32), + length=num_triangles, + ) + > 0 + ) + + hit_any_ray = jax.vmap(count_hits)(flat_safe_indices, flat_valid_hits) + hit_any_ray = hit_any_ray.reshape((*indices.shape[:-1], num_triangles)) return visible_triangles | hit_any_ray def body_fun( @@ -890,8 +913,9 @@ def first_triangles_hit_by_rays( ``tri_batch_size`` is not specified. This allows to make a trade-off between memory usage and performance. - If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are used. - Otherwise, they are both overwritten with this value. + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are + used. Otherwise, this value is used as the default for whichever of + ``ray_batch_size`` and ``tri_batch_size`` are left unspecified. ray_batch_size: The number of rays to process in a single batch. This allows to chunk rays and reduce peak memory usage. @@ -998,7 +1022,7 @@ def _process_ray_batch( triangle_vertices_batch: Float[Array, "*#batch num_triangles 3 3"], active_triangles_batch: Bool[Array, "*#batch num_triangles"] | None = None, ) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"]]: - """Process one batch of rays through all triangles.""" + # Process one batch of rays through all triangles. def map_fn( ray_origins: Float[Array, "*#batch 3"], From ac23eb3674bf19d9cd1dd6ed11d02b55b6aba456 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 11:38:05 +0200 Subject: [PATCH 05/13] fix(lib): batch size selection logic --- differt/src/differt/rt/_utils.py | 65 +++++++++++++++++++++----------- 1 file changed, 43 insertions(+), 22 deletions(-) diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index bc688041..12b5e815 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -471,14 +471,20 @@ def rays_intersect_any_triangle( ) num_rays = math.prod(batch) - # If user passed None explicitly, process everything in a single batch - if batch_size is None: - batch_size = max(num_rays, num_triangles) - if ray_batch_size is None: - ray_batch_size = batch_size - if tri_batch_size is None: - tri_batch_size = batch_size - + ray_batch_size = ( + ray_batch_size + if ray_batch_size is not None + else batch_size + if batch_size is not None + else num_rays + ) + tri_batch_size = ( + tri_batch_size + if tri_batch_size is not None + else batch_size + if batch_size is not None + else num_triangles + ) ray_chunk_size = max(min(ray_batch_size, num_rays), 1) tri_chunk_size = max(min(tri_batch_size, num_triangles), 1) @@ -797,13 +803,21 @@ def triangles_visible_from_vertices( ) num_triangles = triangle_vertices.shape[-3] - # None means process everything in a single batch - if batch_size is None: - batch_size = max(num_rays, num_triangles) - if ray_batch_size is None: - ray_batch_size = batch_size - if tri_batch_size is None: - tri_batch_size = batch_size + num_rays = math.prod(batch) + ray_batch_size = ( + ray_batch_size + if ray_batch_size is not None + else batch_size + if batch_size is not None + else num_rays + ) + tri_batch_size = ( + tri_batch_size + if tri_batch_size is not None + else batch_size + if batch_size is not None + else num_triangles + ) ray_batch_size = max(min(ray_batch_size, num_rays), 1) tri_batch_size = max(min(tri_batch_size, num_triangles), 1) @@ -957,13 +971,20 @@ def first_triangles_hit_by_rays( ) num_rays = math.prod(batch) - # None means process everything in a single batch - if batch_size is None: - batch_size = max(num_rays, num_triangles) - if ray_batch_size is None: - ray_batch_size = batch_size - if tri_batch_size is None: - tri_batch_size = batch_size + ray_batch_size = ( + ray_batch_size + if ray_batch_size is not None + else batch_size + if batch_size is not None + else num_rays + ) + tri_batch_size = ( + tri_batch_size + if tri_batch_size is not None + else batch_size + if batch_size is not None + else num_triangles + ) ray_chunk_size = max(min(ray_batch_size, num_rays), 1) tri_chunk_size = max(min(tri_batch_size, num_triangles), 1) From 0b97994cf4547aa2e331d37eb705d02fcb8aac58 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 11:57:09 +0200 Subject: [PATCH 06/13] fix(docs): typo --- differt/src/differt/rt/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index 12b5e815..2862e9e2 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -678,7 +678,7 @@ def triangles_visible_from_vertices( The triangle batch size is automatically adjusted to be the minimum of the number of triangles and the specified triangle batch size. - If :data:`None`, it defaults to :data:`batch_size`. + If :data:`None`, it defaults to ``batch_size``. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. @@ -937,7 +937,7 @@ def first_triangles_hit_by_rays( tri_batch_size: The number of triangles to process in a single batch. This allows to chunk triangles and reduce peak memory usage. - If :data:`None`, the triangle batch size defaults to :data:`batch_size`. + If :data:`None`, the triangle batch size defaults to ``batch_size``. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. From 0aad6e2cb1b5f6b1531d3057b80caac0e7e6468b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 18:59:42 +0200 Subject: [PATCH 07/13] refactor(lib): improve functions (wip) --- differt/src/differt/rt/_utils.py | 943 +++++++++++++++++-------------- 1 file changed, 507 insertions(+), 436 deletions(-) diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index 2862e9e2..e35c8e5c 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -1,7 +1,7 @@ import math import typing from collections.abc import Callable, Iterator, Sized -from functools import cache +from functools import cache, partial from typing import TYPE_CHECKING, Any, TypeVar, overload import equinox as eqx @@ -9,7 +9,7 @@ import jax.numpy as jnp from jaxtyping import Array, ArrayLike, Bool, Float, Int -from differt.geometry import fibonacci_lattice, viewing_frustum +from differt.geometry import viewing_frustum from differt.utils import smoothing_function from differt_core.rt import CompleteGraph @@ -342,6 +342,235 @@ def rays_intersect_triangles( return t, hit +def _ray_intersect_any_triangle_batched( + ray_origin: Float[Array, "3"], + ray_direction: Float[Array, "3"], + triangle_vertices: Float[Array, "num_triangles 3 3"], + active_triangles: Bool[Array, " num_triangles"] | None, + *, + hit_threshold: Float[Array, ""] | None, + smoothing_factor: Float[ArrayLike, ""] | None, + **kwargs: Any, +) -> Bool[Array, ""] | Float[Array, ""]: + ts, hits = jax.vmap( + partial(rays_intersect_triangles, smoothing_factor=smoothing_factor, **kwargs), + in_axes=(None, None, 0), + )(ray_origin, ray_direction, triangle_vertices) + if smoothing_factor is not None: + return jnp.minimum( + hits, smoothing_function(hit_threshold - ts, smoothing_factor) + ).max(axis=-1, where=active_triangles) + return ((ts < hit_threshold) & hits).any(axis=-1, where=active_triangles) + + +def _ray_intersect_any_triangle( + ray_origin: Float[Array, "3"], + ray_direction: Float[Array, "3"], + triangle_vertices: Float[Array, "num_triangles 3 3"], + active_triangles: Bool[Array, " num_triangles"] | None, + *, + hit_threshold: Float[Array, ""] | None, + smoothing_factor: Float[ArrayLike, ""] | None, + batch_size: int, + **kwargs: Any, +) -> Bool[Array, ""] | Float[Array, ""]: + def reduce_fn( + left: Bool[Array, ""] | Float[Array, ""], + right: Bool[Array, ""] | Float[Array, ""], + ) -> Bool[Array, ""] | Float[Array, ""]: + if smoothing_factor is not None: + return jnp.maximum(left, right) + return left | right + + def body_fn( + batch_index: Int[Array, ""], + intersect_so_far: Bool[Array, ""] | Float[Array, ""], + ) -> Bool[Array, ""] | Float[Array, ""]: + start_index = batch_index * batch_size + batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim( + triangle_vertices, start_index, batch_size, axis=0 + ) + batch_of_active_triangles = ( + jax.lax.dynamic_slice_in_dim( + active_triangles, start_index, batch_size, axis=0 + ) + if active_triangles is not None + else None + ) + intersect_in_batch = _ray_intersect_any_triangle_batched( + ray_origin, + ray_direction, + batch_of_triangle_vertices, + batch_of_active_triangles, + hit_threshold=hit_threshold, + smoothing_factor=smoothing_factor, + **kwargs, + ) + + return reduce_fn(intersect_so_far, intersect_in_batch) + + num_triangles = triangle_vertices.shape[0] + num_batches, rem = divmod(num_triangles, batch_size) + + init_val = 0.0 if smoothing_factor is not None else False + + intersect = jax.lax.fori_loop( + 0, + num_batches, + body_fn, + init_val=init_val, + ) + + if rem > 0: + intersect = reduce_fn( + intersect, + _ray_intersect_any_triangle_batched( + ray_origin, + ray_direction, + triangle_vertices[-rem:, :], + (active_triangles[-rem:] if active_triangles is not None else None), + hit_threshold=hit_threshold, + smoothing_factor=smoothing_factor, + **kwargs, + ), + ) + + return intersect + + +def _first_triangle_hit_by_ray_batched( + ray_origin, + ray_direction, + triangle_vertices, + active_triangles, + triangle_indices, + *, + dist_tol, + **kwargs, +): + ts, hits = jax.vmap( + partial(rays_intersect_triangles, **kwargs), + in_axes=(None, None, 0), + )(ray_origin, ray_direction, triangle_vertices) + + if active_triangles is not None: + hits &= active_triangles + + ts = jnp.where(hits, ts, jnp.inf) + min_t = jnp.min(ts) + + center_distances = jnp.linalg.norm( + triangle_vertices.mean(axis=-2) - ray_origin, axis=-1 + ) + + is_close = jnp.abs(ts - min_t) < dist_tol + dist_to_check = jnp.where(is_close, center_distances, jnp.inf) + min_dist = jnp.min(dist_to_check) + + is_best = (dist_to_check == min_dist) & is_close + best_idx = jnp.argmax(is_best) + + any_hit = jnp.any(hits) + + return ( + jnp.where(any_hit, triangle_indices[best_idx], -1), + jnp.where(any_hit, min_t, jnp.inf), + jnp.where(any_hit, min_dist, jnp.inf), + ) + + +def _first_triangle_hit_by_ray( + ray_origin, + ray_direction, + triangle_vertices, + active_triangles, + *, + batch_size, + dist_tol, + **kwargs, +): + def combine_best( + left: tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]], + right: tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]], + ) -> tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]]: + idx1, t1, d1 = left + idx2, t2, d2 = right + + cond = jnp.where( + jnp.abs(t1 - t2) < dist_tol, + d1 < d2, + t1 < t2, + ) + + return ( + jnp.where(cond, idx1, idx2), + jnp.where(cond, t1, t2), + jnp.where(cond, d1, d2), + ) + + def body_fn( + batch_index: Int[Array, ""], + best_so_far: tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]], + ) -> tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]]: + start_index = batch_index * batch_size + batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim( + triangle_vertices, start_index, batch_size, axis=0 + ) + batch_of_active_triangles = ( + jax.lax.dynamic_slice_in_dim( + active_triangles, start_index, batch_size, axis=0 + ) + if active_triangles is not None + else None + ) + batch_of_indices = jnp.arange(batch_size, dtype=jnp.int32) + start_index + + best_in_batch = _first_triangle_hit_by_ray_batched( + ray_origin, + ray_direction, + batch_of_triangle_vertices, + batch_of_active_triangles, + batch_of_indices, + dist_tol=dist_tol, + **kwargs, + ) + + return combine_best(best_so_far, best_in_batch) + + num_triangles = triangle_vertices.shape[0] + num_batches, rem = divmod(num_triangles, batch_size) + + init_val = ( + jnp.array(-1, dtype=jnp.int32), + jnp.array(jnp.inf, dtype=ray_origin.dtype), + jnp.array(jnp.inf, dtype=ray_origin.dtype), + ) + + best = jax.lax.fori_loop( + 0, + num_batches, + body_fn, + init_val=init_val, + ) + + if rem > 0: + start_index = num_batches * batch_size + best = combine_best( + best, + _first_triangle_hit_by_ray_batched( + ray_origin, + ray_direction, + triangle_vertices[-rem:, :], + (active_triangles[-rem:] if active_triangles is not None else None), + jnp.arange(rem, dtype=jnp.int32) + start_index, + dist_tol=dist_tol, + **kwargs, + ), + ) + + return best[0], best[1] + + @overload def rays_intersect_any_triangle( ray_origins: Float[ArrayLike, "*#batch 3"], @@ -469,7 +698,6 @@ def rays_intersect_any_triangle( triangle_vertices.shape[:-3], active_triangles.shape[:-1] if active_triangles is not None else (), ) - num_rays = math.prod(batch) ray_batch_size = ( ray_batch_size @@ -478,6 +706,7 @@ def rays_intersect_any_triangle( if batch_size is not None else num_rays ) + ray_batch_size = min(ray_batch_size, num_rays) tri_batch_size = ( tri_batch_size if tri_batch_size is not None @@ -485,9 +714,7 @@ def rays_intersect_any_triangle( if batch_size is not None else num_triangles ) - ray_chunk_size = max(min(ray_batch_size, num_rays), 1) - tri_chunk_size = max(min(tri_batch_size, num_triangles), 1) - + tri_batch_size = min(tri_batch_size, num_triangles) if num_triangles == 0: # If there are no triangles, there are no intersections return ( @@ -499,128 +726,164 @@ def rays_intersect_any_triangle( else jnp.zeros(batch, dtype=bool) ) - ray_origins = jnp.broadcast_to(ray_origins, (*batch, 3)).reshape(-1, 3) - ray_directions = jnp.broadcast_to(ray_directions, (*batch, 3)).reshape(-1, 3) - - pad_rays_len = (ray_chunk_size - (num_rays % ray_chunk_size)) % ray_chunk_size - if pad_rays_len > 0: - ray_origins = jnp.pad(ray_origins, ((0, pad_rays_len), (0, 0))) - ray_directions = jnp.pad(ray_directions, ((0, pad_rays_len), (0, 0))) - - num_rays_chunks = ray_origins.shape[0] // ray_chunk_size - blocked_ro = ray_origins.reshape(num_rays_chunks, ray_chunk_size, 3) - blocked_rd = ray_directions.reshape(num_rays_chunks, ray_chunk_size, 3) - - pad_tris_len = (tri_chunk_size - (num_triangles % tri_chunk_size)) % tri_chunk_size - - has_batch_tris = triangle_vertices.shape[:-3] != () - if has_batch_tris: - triangle_vertices = jnp.broadcast_to( - triangle_vertices, (*batch, num_triangles, 3, 3) - ).reshape(-1, num_triangles, 3, 3) - if pad_rays_len > 0 or pad_tris_len > 0: - triangle_vertices = jnp.pad( - triangle_vertices, - ((0, pad_rays_len), (0, pad_tris_len), (0, 0), (0, 0)), - ) - blocked_tris = triangle_vertices.reshape( - num_rays_chunks, ray_chunk_size, -1, tri_chunk_size, 3, 3 + if num_rays > ray_batch_size: + xs = [] + argnames = [] + map_fn = partial( + _ray_intersect_any_triangle, + hit_threshold=hit_threshold, + smoothing_factor=smoothing_factor, + batch_size=tri_batch_size, + **kwargs, ) - else: - if pad_tris_len > 0: - triangle_vertices = jnp.pad( - triangle_vertices, ((0, pad_tris_len), (0, 0), (0, 0)) + if math.prod(ray_origins.shape[:-1]) > 1: + ray_origins = jnp.broadcast_to(ray_origins, (*batch, 3)) + xs.append(ray_origins.reshape(-1, 3)) + argnames.append("ray_origin") + else: + map_fn = partial(map_fn, ray_origin=ray_origins) + if math.prod(ray_directions.shape[:-1]) > 1: + ray_directions = jnp.broadcast_to(ray_directions, (*batch, 3)) + xs.append(ray_directions.reshape(-1, 3)) + argnames.append("ray_direction") + else: + map_fn = partial(map_fn, ray_direction=ray_directions) + if math.prod(triangle_vertices.shape[:-3]) > 1: + triangle_vertices = jnp.broadcast_to( + triangle_vertices, (*batch, num_triangles, 3, 3) ) - blocked_tris = triangle_vertices.reshape(-1, tri_chunk_size, 3, 3) - - if active_triangles is not None: - has_batch_active = active_triangles.shape[:-1] != () - if has_batch_active: + xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) + argnames.append("triangle_vertices") + else: + map_fn = partial(map_fn, triangle_vertices=triangle_vertices) + if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: active_triangles = jnp.broadcast_to( active_triangles, (*batch, num_triangles) - ).reshape(-1, num_triangles) - if pad_rays_len > 0 or pad_tris_len > 0: - active_triangles = jnp.pad( - active_triangles, ((0, pad_rays_len), (0, pad_tris_len)) - ) - blocked_active = active_triangles.reshape( - num_rays_chunks, ray_chunk_size, -1, tri_chunk_size ) + xs.append(active_triangles.reshape(-1, num_triangles)) + argnames.append("active_triangles") else: - if pad_tris_len > 0: - active_triangles = jnp.pad(active_triangles, ((0, pad_tris_len),)) - blocked_active = active_triangles.reshape(-1, tri_chunk_size) - else: - has_batch_active = False - blocked_active = None - - xs_rays = [blocked_ro, blocked_rd] - if has_batch_tris: - xs_rays.append(blocked_tris) - if active_triangles is not None and has_batch_active: - xs_rays.append(typing.cast("Array", blocked_active)) - - def scan_rays( - carry_rays: Any, - ray_chunk: tuple[Array, ...], - ) -> tuple[Any, Array]: - ro_block = ray_chunk[0] - rd_block = ray_chunk[1] - - idx = 2 - if has_batch_tris: - tris_block_batch = jnp.swapaxes(ray_chunk[idx], 0, 1) - idx += 1 - else: - tris_block_batch = blocked_tris + map_fn = partial(map_fn, active_triangles=active_triangles) - if active_triangles is not None and has_batch_active: - active_block_batch = jnp.swapaxes(ray_chunk[idx], 0, 1) - elif active_triangles is not None: - active_block_batch = blocked_active - else: - active_block_batch = None - - xs_tris = [tris_block_batch] - if active_block_batch is not None: - xs_tris.append(active_block_batch) - - def scan_tris( - carry_tris: Any, - tris_chunk: tuple[Array, ...], - ) -> tuple[Any, None]: - tris_block = tris_chunk[0] - active_block = tris_chunk[1] if len(tris_chunk) > 1 else None - - t, hit = rays_intersect_triangles( - ro_block[..., None, :], - rd_block[..., None, :], - tris_block, - smoothing_factor=smoothing_factor, - **kwargs, - ) + def f(args): + return map_fn(**dict(zip(argnames, args, strict=True))) + + return jax.lax.map(f, tuple(xs), batch_size=ray_batch_size).reshape(batch) + return jnp.vectorize( + partial( + _ray_intersect_any_triangle, + hit_threshold=hit_threshold, + smoothing_factor=smoothing_factor, + batch_size=tri_batch_size, + **kwargs, + ), + signature="(3),(3),(n,3,3),(n)->()" + if active_triangles is not None + else "(3),(3),(n,3,3),()->()", + )(ray_origins, ray_directions, triangle_vertices, active_triangles) - if smoothing_factor is not None: - block_hits = jnp.minimum( - hit, smoothing_function(hit_threshold - t, smoothing_factor) - ).sum(axis=-1, where=active_block) - return (carry_tris + block_hits).clip(max=1.0), None - block_hits = ((t < hit_threshold) & hit).any(axis=-1, where=active_block) - return carry_tris | block_hits, None - init_val = ( - jnp.zeros(ray_chunk_size) - if smoothing_factor is not None - else jnp.zeros(ray_chunk_size, dtype=bool) +def _fibonacci_lattice_chunk( + start_index, + chunk_size, + total_n, + frustum, +): + phi = 1.618033988749895 # golden ratio + i = jnp.arange(chunk_size) + start_index + + lat = jnp.arccos(1 - 2 * i / total_n) + lon = 2 * jnp.pi * i / phi + + pa = jnp.stack((lat, lon), axis=-1) + + if frustum is not None: + pa %= frustum[1, -2:] - frustum[0, -2:] + pa += frustum[0, -2:] + + from differt.geometry import spherical_to_cartesian + + return spherical_to_cartesian(pa) + + +def _triangles_visible_from_vertex( + vertex, + triangle_vertices, + active_triangles, + *, + num_rays, + num_triangles, + ray_batch_size, + tri_batch_size, + world_vertices, + active_vertices, + **kwargs, +): + frustum = viewing_frustum(vertex, world_vertices, active_vertices=active_vertices) + + def body_fn( + batch_index: Int[Array, ""], + visible_so_far: Bool[Array, " num_triangles"], + ) -> Bool[Array, " num_triangles"]: + start_index = batch_index * ray_batch_size + ray_directions = _fibonacci_lattice_chunk( + start_index, ray_batch_size, num_rays, frustum ) - hits_for_chunk, _ = jax.lax.scan(scan_tris, init=init_val, xs=tuple(xs_tris)) - return carry_rays, hits_for_chunk + indices, _ = first_triangles_hit_by_rays( + vertex, + ray_directions, + triangle_vertices, + active_triangles=active_triangles, + ray_batch_size=None, + tri_batch_size=tri_batch_size, + **kwargs, + ) - _, all_hits = jax.lax.scan(scan_rays, init=None, xs=tuple(xs_rays)) + valid_hits = indices >= 0 + safe_indices = jnp.where(valid_hits, indices, 0) + visible_in_batch = ( + jnp.zeros(num_triangles, dtype=bool).at[safe_indices].max(valid_hits) + ) - hits = all_hits.reshape(-1) - return hits[:num_rays].reshape(batch) + return visible_so_far | visible_in_batch + + num_ray_batches, rem_rays = divmod(num_rays, ray_batch_size) + + visible = jax.lax.fori_loop( + 0, + num_ray_batches, + body_fn, + init_val=jnp.zeros(num_triangles, dtype=bool), + ) + + if rem_rays > 0: + start_index = num_ray_batches * ray_batch_size + ray_directions = _fibonacci_lattice_chunk( + start_index, rem_rays, num_rays, frustum + ) + indices, _ = first_triangles_hit_by_rays( + vertex, + ray_directions, + triangle_vertices, + active_triangles=active_triangles, + ray_batch_size=None, + tri_batch_size=tri_batch_size, + **kwargs, + ) + valid_hits = indices >= 0 + safe_indices = jnp.where(valid_hits, indices, 0) + visible_in_batch = ( + jnp.bincount( + safe_indices, + weights=valid_hits.astype(jnp.int32), + length=num_triangles, + ) + > 0 + ) + visible |= visible_in_batch + + return visible @eqx.filter_jit @@ -766,10 +1029,8 @@ def triangles_visible_from_vertices( """ vertices = jnp.asarray(vertices) triangle_vertices = jnp.asarray(triangle_vertices) - triangle_centers = triangle_vertices.mean(axis=-2, keepdims=True) - world_vertices = jnp.concat((triangle_vertices, triangle_centers), axis=-2).reshape( - *triangle_vertices.shape[:-3], -1, 3 - ) + + num_triangles = triangle_vertices.shape[-3] if active_triangles is not None: active_triangles = jnp.asarray(active_triangles) @@ -777,33 +1038,18 @@ def triangles_visible_from_vertices( else: active_vertices = None - # [*batch 3] - ray_origins = vertices - - # [*batch 2 3] - frustum = viewing_frustum( - ray_origins, - world_vertices, - active_vertices=active_vertices, + triangle_centers = triangle_vertices.mean(axis=-2, keepdims=True) + world_vertices = jnp.concat((triangle_vertices, triangle_centers), axis=-2).reshape( + *triangle_vertices.shape[:-3], -1, 3 ) - # [*batch num_rays 3] - ray_directions = jnp.vectorize( - lambda n, frustum: fibonacci_lattice(n, frustum=frustum), - excluded={0}, - signature="(2,3)->(n,3)", - )(num_rays, frustum) - - # Combine the batch dimensions batch = jnp.broadcast_shapes( - ray_origins.shape[:-2], - ray_directions.shape[:-2], + vertices.shape[:-1], triangle_vertices.shape[:-3], active_triangles.shape[:-1] if active_triangles is not None else (), ) + num_vertices = math.prod(batch) - num_triangles = triangle_vertices.shape[-3] - num_rays = math.prod(batch) ray_batch_size = ( ray_batch_size if ray_batch_size is not None @@ -811,6 +1057,7 @@ def triangles_visible_from_vertices( if batch_size is not None else num_rays ) + ray_batch_size = min(ray_batch_size, num_rays) tri_batch_size = ( tri_batch_size if tri_batch_size is not None @@ -818,80 +1065,85 @@ def triangles_visible_from_vertices( if batch_size is not None else num_triangles ) - ray_batch_size = max(min(ray_batch_size, num_rays), 1) - tri_batch_size = max(min(tri_batch_size, num_triangles), 1) + tri_batch_size = min(tri_batch_size, num_triangles) if num_triangles == 0: return jnp.zeros((*batch, 0), dtype=jnp.bool_) - num_ray_batches, rem_rays = divmod(num_rays, ray_batch_size) - - def update_visible_triangles( - visible_triangles: Bool[Array, "*#batch num_triangles"], - ray_directions_batch: Float[Array, "*#batch batch_rays 3"], - ) -> Bool[Array, "*#batch num_triangles"]: - # Check which triangles are visible from rays in this batch. - indices, _ = first_triangles_hit_by_rays( - ray_origins[..., None, :], - ray_directions_batch, - triangle_vertices[..., None, :, :, :], - active_triangles=active_triangles[..., None, :] + xs = [] + argnames = [] + map_fn = partial( + _triangles_visible_from_vertex, + num_rays=num_rays, + num_triangles=num_triangles, + ray_batch_size=ray_batch_size, + tri_batch_size=tri_batch_size, + **kwargs, + ) + if math.prod(vertices.shape[:-1]) > 1: + xs.append(vertices.reshape(-1, 3)) + argnames.append("vertex") + else: + map_fn = partial(map_fn, vertex=vertices.reshape(3)) + if math.prod(triangle_vertices.shape[:-3]) > 1: + xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) + argnames.append("triangle_vertices") + else: + map_fn = partial( + map_fn, triangle_vertices=triangle_vertices.reshape(num_triangles, 3, 3) + ) + if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: + xs.append(active_triangles.reshape(-1, num_triangles)) + argnames.append("active_triangles") + else: + map_fn = partial( + map_fn, + active_triangles=active_triangles.reshape(num_triangles) if active_triangles is not None else None, - batch_size=tri_batch_size, - ray_batch_size=ray_directions_batch.shape[-2], - tri_batch_size=tri_batch_size, - **kwargs, ) - # indices: [*batch ray_batch_size], value >= 0 means triangle index was hit - # Convert to per-triangle using a bincount-based reduction to avoid materializing a - # [*batch ray_batch_size num_triangles] one-hot tensor. - valid_hits = indices >= 0 - safe_indices = jnp.where(valid_hits, indices, 0) - flat_safe_indices = safe_indices.reshape(-1, safe_indices.shape[-1]) - flat_valid_hits = valid_hits.reshape(-1, valid_hits.shape[-1]) - - def count_hits(row_indices: Array, row_valid_hits: Array) -> Array: - return ( - jnp.bincount( - row_indices, - weights=row_valid_hits.astype(jnp.int32), - length=num_triangles, - ) - > 0 - ) - - hit_any_ray = jax.vmap(count_hits)(flat_safe_indices, flat_valid_hits) - hit_any_ray = hit_any_ray.reshape((*indices.shape[:-1], num_triangles)) - return visible_triangles | hit_any_ray + # Also need to handle world_vertices and active_vertices broadcasting + if math.prod(world_vertices.shape[:-2]) > 1: + xs.append(world_vertices.reshape(-1, *world_vertices.shape[-2:])) + argnames.append("world_vertices") + else: + map_fn = partial(map_fn, world_vertices=world_vertices.reshape(-1, 3)) - def body_fun( - batch_index: Int[Array, ""], - visible_triangles: Bool[Array, "*batch num_triangles"], - ) -> Bool[Array, "*batch num_triangles"]: - start_index = batch_index * ray_batch_size - batch_of_ray_directions = jax.lax.dynamic_slice_in_dim( - ray_directions, start_index, ray_batch_size, axis=-2 + if active_vertices is not None and math.prod(active_vertices.shape[:-1]) > 1: + xs.append(active_vertices.reshape(-1, active_vertices.shape[-1])) + argnames.append("active_vertices") + else: + map_fn = partial( + map_fn, + active_vertices=active_vertices.reshape(-1) + if active_vertices is not None + else None, ) - return update_visible_triangles(visible_triangles, batch_of_ray_directions) - - init_val = jnp.zeros((*batch, num_triangles), dtype=jnp.bool_) - visible_triangles = jax.lax.fori_loop( - 0, - num_ray_batches, - body_fun, - init_val=init_val, - ) + def f(args): + return map_fn(**dict(zip(argnames, args, strict=True))) - if rem_rays > 0: - batch_of_ray_directions = ray_directions[..., -rem_rays:, :] - visible_triangles = update_visible_triangles( - visible_triangles, batch_of_ray_directions - ) + if not xs: + # Case where everything is shared or num_vertices == 1 + return _triangles_visible_from_vertex( + vertices.reshape(3), + triangle_vertices.reshape(num_triangles, 3, 3), + active_triangles.reshape(num_triangles) + if active_triangles is not None + else None, + num_rays=num_rays, + num_triangles=num_triangles, + ray_batch_size=ray_batch_size, + tri_batch_size=tri_batch_size, + world_vertices=world_vertices.reshape(-1, 3), + active_vertices=active_vertices.reshape(-1) + if active_vertices is not None + else None, + **kwargs, + ).reshape((*batch, num_triangles)) - return visible_triangles + return jax.lax.map(f, tuple(xs)).reshape((*batch, num_triangles)) @eqx.filter_jit @@ -912,7 +1164,7 @@ def first_triangles_hit_by_rays( ``*batch num_triangles 3`` (or bigger) is not possible, and you are only interested in getting the first triangle hit by the ray. - If two or more triangles are hit at the same distance, the one with the closest center to the ray origin is selected. Two triangles are considered to be hit at the same distance if their distances differ by less than ``100 * eps``, or ten times the ``epsilon`` keyword argument passed to :func:`rays_intersect_triangles`. + If two or more triangles are hit at the same distance, the one with the closest center to the ray origin is selected. Two triangles are considered to be hit at the same distance if their distances differ by less than ``100 * eps``, where ``eps`` is the machine epsilon of the floating point dtype. Args: ray_origins: An array of origin vertices. @@ -951,25 +1203,20 @@ def first_triangles_hit_by_rays( ray_directions = jnp.asarray(ray_directions) triangle_vertices = jnp.asarray(triangle_vertices) - if epsilon := kwargs.get("epsilon"): - epsilon = 10 * jnp.asarray(epsilon) - else: - dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) - epsilon = jnp.asarray(100 * jnp.finfo(dtype).eps) - - num_triangles = triangle_vertices.shape[-3] - if active_triangles is not None: active_triangles = jnp.asarray(active_triangles) - # Combine the batch dimensions + dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) + dist_tol = jnp.asarray(100 * jnp.finfo(dtype).eps) + + num_triangles = triangle_vertices.shape[-3] + batch = jnp.broadcast_shapes( ray_origins.shape[:-1], ray_directions.shape[:-1], triangle_vertices.shape[:-3], active_triangles.shape[:-1] if active_triangles is not None else (), ) - num_rays = math.prod(batch) ray_batch_size = ( ray_batch_size @@ -978,6 +1225,7 @@ def first_triangles_hit_by_rays( if batch_size is not None else num_rays ) + ray_batch_size = min(ray_batch_size, num_rays) tri_batch_size = ( tri_batch_size if tri_batch_size is not None @@ -985,10 +1233,7 @@ def first_triangles_hit_by_rays( if batch_size is not None else num_triangles ) - - ray_chunk_size = max(min(ray_batch_size, num_rays), 1) - tri_chunk_size = max(min(tri_batch_size, num_triangles), 1) - + tri_batch_size = min(tri_batch_size, num_triangles) if num_triangles == 0: # If there are no triangles, there are no hits return ( @@ -1000,232 +1245,58 @@ def first_triangles_hit_by_rays( ), ) - def reduce_fn( - left: tuple[ - Int[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *#batch"], - ], - right: tuple[ - Int[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *#batch"], - ], - ) -> tuple[ - Int[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *batch"], - Float[Array, " *#batch"], - ]: - left_indices, left_t, left_center_distances, eps = left - right_indices, right_t, right_center_distances, _ = right - cond: Array = jnp.where( - jnp.abs(left_t - right_t) < eps, - left_center_distances < right_center_distances, - left_t < right_t, - ) - t = jnp.where(cond, left_t, right_t) - indices = jnp.where(cond, left_indices, right_indices) - t = jnp.minimum(left_t, right_t) - center_distances = jnp.where( - cond, left_center_distances, right_center_distances - ) - is_finite = jnp.isfinite(t) - indices = jnp.where(is_finite, indices, -1) - t = jnp.where(is_finite, t, jnp.inf) - return indices, t, center_distances, eps - - def _process_ray_batch( - ray_origins_batch: Float[Array, "*#batch 3"], - ray_directions_batch: Float[Array, "*#batch 3"], - triangle_vertices_batch: Float[Array, "*#batch num_triangles 3 3"], - active_triangles_batch: Bool[Array, "*#batch num_triangles"] | None = None, - ) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"]]: - # Process one batch of rays through all triangles. - - def map_fn( - ray_origins: Float[Array, "*#batch 3"], - ray_directions: Float[Array, "*#batch 3"], - triangle_vertices: Float[Array, "*#batch num_triangles 3 3"], - active_triangles: Bool[Array, "*#batch num_triangles"] | None = None, - ) -> tuple[ - Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"] - ]: - t, hit = rays_intersect_triangles( - ray_origins[..., None, :], - ray_directions[..., None, :], - triangle_vertices, - **kwargs, - ) - if active_triangles is not None: - hit &= active_triangles - t = jnp.where(hit, t, jnp.inf) - indices = jnp.arange(triangle_vertices.shape[-3]) - indices = jnp.broadcast_to(indices, t.shape) - center_distances = jnp.linalg.norm( - triangle_vertices.mean(axis=-2) - ray_origins[..., None, :], axis=-1 - ) - center_distances = jnp.broadcast_to(center_distances, t.shape) - eps = jnp.broadcast_to(epsilon, t.shape) - return jax.lax.reduce( - (indices, t, center_distances, eps), - (-1, jnp.inf, jnp.inf, epsilon), - reduce_fn, - dimensions=(t.ndim - 1,), - )[:3] - - num_triangle_batches, rem_triangles = divmod(num_triangles, tri_chunk_size) - - def body_fun( - batch_index: Int[Array, ""], - carry: tuple[ - Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"] - ], - ) -> tuple[ - Int[Array, " *batch"], Float[Array, " *batch"], Float[Array, " *batch"] - ]: - start_index = batch_index * tri_chunk_size - batch_of_triangle_vertices = jax.lax.dynamic_slice_in_dim( - triangle_vertices_batch, start_index, tri_chunk_size, axis=-3 - ) - batch_of_active_triangles = ( - jax.lax.dynamic_slice_in_dim( - active_triangles_batch, start_index, tri_chunk_size, axis=-1 - ) - if active_triangles_batch is not None - else None - ) - indices, t, center_distances = map_fn( - ray_origins_batch, - ray_directions_batch, - batch_of_triangle_vertices, - batch_of_active_triangles, - ) - return reduce_fn( - (carry[0], carry[1], carry[2], epsilon), - (indices + start_index, t, center_distances, epsilon), - )[:3] - - init_val = ( - -jnp.ones(ray_origins_batch.shape[:-1], dtype=jnp.int32), - jnp.full( - ray_origins_batch.shape[:-1], - jnp.inf, - dtype=jnp.result_type( - ray_origins_batch, ray_directions_batch, triangle_vertices_batch - ), - ), - jnp.full( - ray_origins_batch.shape[:-1], - jnp.inf, - dtype=jnp.result_type( - ray_origins_batch, ray_directions_batch, triangle_vertices_batch - ), - ), - ) - - indices, t, center_distances = jax.lax.fori_loop( - 0, - num_triangle_batches, - body_fun, - init_val=init_val, - ) - - if rem_triangles > 0: - rem_indices, rem_t, rem_center_distances = map_fn( - ray_origins_batch, - ray_directions_batch, - triangle_vertices_batch[..., -rem_triangles:, :, :], - active_triangles_batch[..., -rem_triangles:] - if active_triangles_batch is not None - else None, - ) - indices, t, _ = reduce_fn( - (indices, t, center_distances, epsilon), - ( - rem_indices + num_triangle_batches * tri_chunk_size, - rem_t, - rem_center_distances, - epsilon, - ), - )[:3] - - return (indices, t) - - ray_origins = jnp.broadcast_to(ray_origins, (*batch, 3)).reshape(-1, 3) - ray_directions = jnp.broadcast_to(ray_directions, (*batch, 3)).reshape(-1, 3) - triangle_vertices = jnp.broadcast_to( - triangle_vertices, (*batch, num_triangles, 3, 3) - ).reshape(-1, num_triangles, 3, 3) - if active_triangles is not None: - active_triangles = jnp.broadcast_to( - active_triangles, (*batch, num_triangles) - ).reshape(-1, num_triangles) - - num_ray_batches, rem_rays = divmod(num_rays, ray_chunk_size) - - init_indices = -jnp.ones(num_rays, dtype=jnp.int32) - init_t = jnp.full( - num_rays, - jnp.inf, - dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), + xs = [] + argnames = [] + map_fn = partial( + _first_triangle_hit_by_ray, + batch_size=tri_batch_size, + dist_tol=dist_tol, + **kwargs, ) - - def ray_body_fun( - batch_index: Int[Array, ""], - carry: tuple[Int[Array, "*"], Float[Array, "*"]], - ) -> tuple[Int[Array, "*"], Float[Array, "*"]]: - start_index = batch_index * ray_chunk_size - ray_origins_batch = jax.lax.dynamic_slice_in_dim( - ray_origins, start_index, ray_chunk_size, axis=0 - ) - ray_directions_batch = jax.lax.dynamic_slice_in_dim( - ray_directions, start_index, ray_chunk_size, axis=0 - ) - triangle_vertices_batch = jax.lax.dynamic_slice_in_dim( - triangle_vertices, start_index, ray_chunk_size, axis=0 + if math.prod(ray_origins.shape[:-1]) > 1: + xs.append(ray_origins.reshape(-1, 3)) + argnames.append("ray_origin") + else: + map_fn = partial(map_fn, ray_origin=ray_origins.reshape(3)) + if math.prod(ray_directions.shape[:-1]) > 1: + xs.append(ray_directions.reshape(-1, 3)) + argnames.append("ray_direction") + else: + map_fn = partial(map_fn, ray_direction=ray_directions.reshape(3)) + if math.prod(triangle_vertices.shape[:-3]) > 1: + xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) + argnames.append("triangle_vertices") + else: + map_fn = partial( + map_fn, triangle_vertices=triangle_vertices.reshape(num_triangles, 3, 3) ) - active_triangles_batch = ( - jax.lax.dynamic_slice_in_dim( - active_triangles, start_index, ray_chunk_size, axis=0 - ) + if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: + xs.append(active_triangles.reshape(-1, num_triangles)) + argnames.append("active_triangles") + else: + map_fn = partial( + map_fn, + active_triangles=active_triangles.reshape(num_triangles) if active_triangles is not None - else None - ) - indices_batch, t_batch = _process_ray_batch( - ray_origins_batch, - ray_directions_batch, - triangle_vertices_batch, - active_triangles_batch, + else None, ) - indices = jax.lax.dynamic_update_slice(carry[0], indices_batch, (start_index,)) - t = jax.lax.dynamic_update_slice(carry[1], t_batch, (start_index,)) - return (indices, t) - indices, t = jax.lax.fori_loop( - 0, - num_ray_batches, - ray_body_fun, - init_val=(init_indices, init_t), - ) + def f(args): + return map_fn(**dict(zip(argnames, args, strict=True))) - if rem_rays > 0: - start_index = num_ray_batches * ray_chunk_size - ray_origins_batch = ray_origins[-rem_rays:, :] - ray_directions_batch = ray_directions[-rem_rays:, :] - triangle_vertices_batch = triangle_vertices[-rem_rays:, :, :, :] - active_triangles_batch = ( - active_triangles[-rem_rays:, :] if active_triangles is not None else None - ) - rem_indices, rem_t = _process_ray_batch( - ray_origins_batch, - ray_directions_batch, - triangle_vertices_batch, - active_triangles_batch, + if not xs: + indices, ts = _first_triangle_hit_by_ray( + ray_origins.reshape(3), + ray_directions.reshape(3), + triangle_vertices.reshape(num_triangles, 3, 3), + active_triangles.reshape(num_triangles) + if active_triangles is not None + else None, + batch_size=tri_batch_size, + dist_tol=dist_tol, + **kwargs, ) - indices = jax.lax.dynamic_update_slice(indices, rem_indices, (start_index,)) - t = jax.lax.dynamic_update_slice(t, rem_t, (start_index,)) + return indices.reshape(batch), ts.reshape(batch) - return (indices.reshape(batch), t.reshape(batch)) + indices, ts = jax.lax.map(f, tuple(xs), batch_size=ray_batch_size) + return indices.reshape(batch), ts.reshape(batch) From 4207324582ac1fdaf264c6c349d43900ccf5f87c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 21:11:07 +0200 Subject: [PATCH 08/13] fix(tests): atol --- differt/tests/rt/test_utils.py | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/differt/tests/rt/test_utils.py b/differt/tests/rt/test_utils.py index 2e5723b5..3576c624 100644 --- a/differt/tests/rt/test_utils.py +++ b/differt/tests/rt/test_utils.py @@ -571,4 +571,6 @@ def test_first_triangles_hit_by_rays( # TODO: fixme, we need to fix the index if two or more triangles are hit at the same t # chex.assert_trees_all_equal(got_indices, expected_indices) - chex.assert_trees_all_close(got_t, expected_t, rtol=1e-5) + dtype = got_t.dtype + dist_tol = 100 * jnp.finfo(dtype).eps + chex.assert_trees_all_close(got_t, expected_t, rtol=1e-5, atol=dist_tol) From 348d0c17472ac5dc64634a0f8b421d98314ca6d8 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 21:34:35 +0200 Subject: [PATCH 09/13] fix(docs): batch size is not data --- differt/src/differt/rt/_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index e35c8e5c..ef577325 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -663,14 +663,14 @@ def rays_intersect_any_triangle( The ray batch size is automatically adjusted to be the minimum of the number of rays and the specified ray batch size. - If :data:`None`, it defaults to :data:`batch_size`. + If :data:`None`, it defaults to ``batch_size``. tri_batch_size: The number of triangles to process in a single batch. This allows to make a trade-off between memory usage and performance. The triangle batch size is automatically adjusted to be the minimum of the number of triangles and the specified triangle batch size. - If :data:`None`, it defaults to :data:`batch_size`. + If :data:`None`, it defaults to ``batch_size``. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. From d0c96bb5af412e4c11c7376e350d9dc68faefa1b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 22:19:17 +0200 Subject: [PATCH 10/13] perf: trying to improve performance --- differt/src/differt/rt/_utils.py | 113 ++++--------------- differt/src/differt/scene/_triangle_scene.py | 2 +- 2 files changed, 26 insertions(+), 89 deletions(-) diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index ef577325..22fece61 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -9,7 +9,7 @@ import jax.numpy as jnp from jaxtyping import Array, ArrayLike, Bool, Float, Int -from differt.geometry import viewing_frustum +from differt.geometry import fibonacci_lattice, viewing_frustum from differt.utils import smoothing_function from differt_core.rt import CompleteGraph @@ -612,7 +612,7 @@ def rays_intersect_any_triangle( *, hit_tol: Float[ArrayLike, ""] | None = None, smoothing_factor: Float[ArrayLike, ""] | None = None, - batch_size: int | None = 1024, + batch_size: int | None = 4096, ray_batch_size: int | None = None, tri_batch_size: int | None = None, **kwargs: Any, @@ -783,29 +783,6 @@ def f(args): )(ray_origins, ray_directions, triangle_vertices, active_triangles) -def _fibonacci_lattice_chunk( - start_index, - chunk_size, - total_n, - frustum, -): - phi = 1.618033988749895 # golden ratio - i = jnp.arange(chunk_size) + start_index - - lat = jnp.arccos(1 - 2 * i / total_n) - lon = 2 * jnp.pi * i / phi - - pa = jnp.stack((lat, lon), axis=-1) - - if frustum is not None: - pa %= frustum[1, -2:] - frustum[0, -2:] - pa += frustum[0, -2:] - - from differt.geometry import spherical_to_cartesian - - return spherical_to_cartesian(pa) - - def _triangles_visible_from_vertex( vertex, triangle_vertices, @@ -820,68 +797,21 @@ def _triangles_visible_from_vertex( **kwargs, ): frustum = viewing_frustum(vertex, world_vertices, active_vertices=active_vertices) + ray_directions = fibonacci_lattice(num_rays, frustum=frustum) - def body_fn( - batch_index: Int[Array, ""], - visible_so_far: Bool[Array, " num_triangles"], - ) -> Bool[Array, " num_triangles"]: - start_index = batch_index * ray_batch_size - ray_directions = _fibonacci_lattice_chunk( - start_index, ray_batch_size, num_rays, frustum - ) - - indices, _ = first_triangles_hit_by_rays( - vertex, - ray_directions, - triangle_vertices, - active_triangles=active_triangles, - ray_batch_size=None, - tri_batch_size=tri_batch_size, - **kwargs, - ) - - valid_hits = indices >= 0 - safe_indices = jnp.where(valid_hits, indices, 0) - visible_in_batch = ( - jnp.zeros(num_triangles, dtype=bool).at[safe_indices].max(valid_hits) - ) - - return visible_so_far | visible_in_batch - - num_ray_batches, rem_rays = divmod(num_rays, ray_batch_size) - - visible = jax.lax.fori_loop( - 0, - num_ray_batches, - body_fn, - init_val=jnp.zeros(num_triangles, dtype=bool), + indices, _ = first_triangles_hit_by_rays( + vertex, + ray_directions, + triangle_vertices, + active_triangles=active_triangles, + ray_batch_size=ray_batch_size, + tri_batch_size=tri_batch_size, + **kwargs, ) - if rem_rays > 0: - start_index = num_ray_batches * ray_batch_size - ray_directions = _fibonacci_lattice_chunk( - start_index, rem_rays, num_rays, frustum - ) - indices, _ = first_triangles_hit_by_rays( - vertex, - ray_directions, - triangle_vertices, - active_triangles=active_triangles, - ray_batch_size=None, - tri_batch_size=tri_batch_size, - **kwargs, - ) - valid_hits = indices >= 0 - safe_indices = jnp.where(valid_hits, indices, 0) - visible_in_batch = ( - jnp.bincount( - safe_indices, - weights=valid_hits.astype(jnp.int32), - length=num_triangles, - ) - > 0 - ) - visible |= visible_in_batch + valid_hits = indices >= 0 + safe_indices = jnp.where(valid_hits, indices, 0) + visible = jnp.zeros(num_triangles, dtype=bool).at[safe_indices].max(valid_hits) return visible @@ -892,7 +822,7 @@ def triangles_visible_from_vertices( triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, num_rays: int = int(1e6), - batch_size: int | None = 1024, + batch_size: int | None = 4096, ray_batch_size: int | None = None, tri_batch_size: int | None = None, **kwargs: Any, @@ -1143,7 +1073,10 @@ def f(args): **kwargs, ).reshape((*batch, num_triangles)) - return jax.lax.map(f, tuple(xs)).reshape((*batch, num_triangles)) + return jax.lax.map(f, tuple(xs), batch_size=batch_size).reshape(( + *batch, + num_triangles, + )) @eqx.filter_jit @@ -1152,7 +1085,7 @@ def first_triangles_hit_by_rays( ray_directions: Float[ArrayLike, "*#batch 3"], triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, - batch_size: int | None = 1024, + batch_size: int | None = 4096, ray_batch_size: int | None = None, tri_batch_size: int | None = None, **kwargs: Any, @@ -1298,5 +1231,9 @@ def f(args): ) return indices.reshape(batch), ts.reshape(batch) - indices, ts = jax.lax.map(f, tuple(xs), batch_size=ray_batch_size) + if num_rays > ray_batch_size: + indices, ts = jax.lax.map(f, tuple(xs), batch_size=ray_batch_size) + else: + indices, ts = jax.vmap(f)(jnp.stack(xs, axis=1)) + return indices.reshape(batch), ts.reshape(batch) diff --git a/differt/src/differt/scene/_triangle_scene.py b/differt/src/differt/scene/_triangle_scene.py index ad6ebea6..728e5e7a 100644 --- a/differt/src/differt/scene/_triangle_scene.py +++ b/differt/src/differt/scene/_triangle_scene.py @@ -1031,7 +1031,7 @@ def compute_paths( max_dist: Float[ArrayLike, ""] = 1e-3, smoothing_factor: Float[ArrayLike, ""] | None = None, confidence_threshold: Float[ArrayLike, ""] = 0.5, - batch_size: int | None = 512, + batch_size: int | None = 4096, disconnect_inactive_triangles: bool = False, ) -> Paths[_M] | SizedIterator[Paths[_M]] | Iterator[Paths[_M]] | SBRPaths: """ From 40b3121201cf05c877e8f9dfe643029904cf963c Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Wed, 13 May 2026 23:25:34 +0200 Subject: [PATCH 11/13] perf(lib): better batch size default --- differt/src/differt/rt/_utils.py | 73 ++++++++++---------- differt/src/differt/scene/_triangle_scene.py | 2 +- 2 files changed, 36 insertions(+), 39 deletions(-) diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index 22fece61..17f4503c 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -439,15 +439,15 @@ def body_fn( def _first_triangle_hit_by_ray_batched( - ray_origin, - ray_direction, - triangle_vertices, - active_triangles, - triangle_indices, + ray_origin: Float[Array, "3"], + ray_direction: Float[Array, "3"], + triangle_vertices: Float[Array, "batch_size 3 3"], + active_triangles: Bool[Array, " batch_size"] | None, + triangle_indices: Int[Array, " batch_size"], *, - dist_tol, - **kwargs, -): + dist_tol: Float[Array, ""], + **kwargs: Any, +) -> tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]]: ts, hits = jax.vmap( partial(rays_intersect_triangles, **kwargs), in_axes=(None, None, 0), @@ -480,15 +480,15 @@ def _first_triangle_hit_by_ray_batched( def _first_triangle_hit_by_ray( - ray_origin, - ray_direction, - triangle_vertices, - active_triangles, + ray_origin: Float[Array, "3"], + ray_direction: Float[Array, "3"], + triangle_vertices: Float[Array, "num_triangles 3 3"], + active_triangles: Bool[Array, " num_triangles"] | None, *, - batch_size, - dist_tol, - **kwargs, -): + batch_size: int, + dist_tol: Float[Array, ""], + **kwargs: Any, +) -> tuple[Int[Array, ""], Float[Array, ""]]: def combine_best( left: tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]], right: tuple[Int[Array, ""], Float[Array, ""], Float[Array, ""]], @@ -612,7 +612,7 @@ def rays_intersect_any_triangle( *, hit_tol: Float[ArrayLike, ""] | None = None, smoothing_factor: Float[ArrayLike, ""] | None = None, - batch_size: int | None = 4096, + batch_size: int | None = 1024, ray_batch_size: int | None = None, tri_batch_size: int | None = None, **kwargs: Any, @@ -765,7 +765,7 @@ def rays_intersect_any_triangle( else: map_fn = partial(map_fn, active_triangles=active_triangles) - def f(args): + def f(args: tuple[Array, ...]) -> Bool[Array, ""] | Float[Array, ""]: return map_fn(**dict(zip(argnames, args, strict=True))) return jax.lax.map(f, tuple(xs), batch_size=ray_batch_size).reshape(batch) @@ -784,18 +784,18 @@ def f(args): def _triangles_visible_from_vertex( - vertex, - triangle_vertices, - active_triangles, + vertex: Float[Array, "3"], + triangle_vertices: Float[Array, "num_triangles 3 3"], + active_triangles: Bool[Array, " num_triangles"] | None, *, - num_rays, - num_triangles, - ray_batch_size, - tri_batch_size, - world_vertices, - active_vertices, - **kwargs, -): + num_rays: int, + num_triangles: int, + ray_batch_size: int | None, + tri_batch_size: int | None, + world_vertices: Float[Array, "num_world_vertices 3"], + active_vertices: Bool[Array, " num_world_vertices"] | None, + **kwargs: Any, +) -> Bool[Array, " num_triangles"]: frustum = viewing_frustum(vertex, world_vertices, active_vertices=active_vertices) ray_directions = fibonacci_lattice(num_rays, frustum=frustum) @@ -811,9 +811,7 @@ def _triangles_visible_from_vertex( valid_hits = indices >= 0 safe_indices = jnp.where(valid_hits, indices, 0) - visible = jnp.zeros(num_triangles, dtype=bool).at[safe_indices].max(valid_hits) - - return visible + return jnp.zeros(num_triangles, dtype=bool).at[safe_indices].max(valid_hits) @eqx.filter_jit @@ -822,7 +820,7 @@ def triangles_visible_from_vertices( triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, num_rays: int = int(1e6), - batch_size: int | None = 4096, + batch_size: int | None = 1024, ray_batch_size: int | None = None, tri_batch_size: int | None = None, **kwargs: Any, @@ -978,7 +976,6 @@ def triangles_visible_from_vertices( triangle_vertices.shape[:-3], active_triangles.shape[:-1] if active_triangles is not None else (), ) - num_vertices = math.prod(batch) ray_batch_size = ( ray_batch_size @@ -1051,7 +1048,7 @@ def triangles_visible_from_vertices( else None, ) - def f(args): + def f(args: tuple[Array, ...]) -> Bool[Array, "num_triangles"]: return map_fn(**dict(zip(argnames, args, strict=True))) if not xs: @@ -1085,7 +1082,7 @@ def first_triangles_hit_by_rays( ray_directions: Float[ArrayLike, "*#batch 3"], triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, - batch_size: int | None = 4096, + batch_size: int | None = 1024, ray_batch_size: int | None = None, tri_batch_size: int | None = None, **kwargs: Any, @@ -1214,7 +1211,7 @@ def first_triangles_hit_by_rays( else None, ) - def f(args): + def f(args: tuple[Array, ...]) -> tuple[Int[Array, ""], Float[Array, ""]]: return map_fn(**dict(zip(argnames, args, strict=True))) if not xs: @@ -1234,6 +1231,6 @@ def f(args): if num_rays > ray_batch_size: indices, ts = jax.lax.map(f, tuple(xs), batch_size=ray_batch_size) else: - indices, ts = jax.vmap(f)(jnp.stack(xs, axis=1)) + indices, ts = jax.vmap(f)(tuple(xs)) return indices.reshape(batch), ts.reshape(batch) diff --git a/differt/src/differt/scene/_triangle_scene.py b/differt/src/differt/scene/_triangle_scene.py index 728e5e7a..7169bdcf 100644 --- a/differt/src/differt/scene/_triangle_scene.py +++ b/differt/src/differt/scene/_triangle_scene.py @@ -1031,7 +1031,7 @@ def compute_paths( max_dist: Float[ArrayLike, ""] = 1e-3, smoothing_factor: Float[ArrayLike, ""] | None = None, confidence_threshold: Float[ArrayLike, ""] = 0.5, - batch_size: int | None = 4096, + batch_size: int | None = 1024, disconnect_inactive_triangles: bool = False, ) -> Paths[_M] | SizedIterator[Paths[_M]] | Iterator[Paths[_M]] | SBRPaths: """ From a02abdc13a79de116793e41ffdbd938320ef5555 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Thu, 14 May 2026 10:49:17 +0200 Subject: [PATCH 12/13] cleanup --- differt/src/differt/rt/_utils.py | 591 ++++++++++++++++--------------- 1 file changed, 303 insertions(+), 288 deletions(-) diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index 17f4503c..3a117e11 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -438,6 +438,229 @@ def body_fn( return intersect +@overload +def rays_intersect_any_triangle( + ray_origins: Float[ArrayLike, "*#batch 3"], + ray_directions: Float[ArrayLike, "*#batch 3"], + triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], + active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = ..., + *, + hit_tol: Float[ArrayLike, ""] | None = ..., + smoothing_factor: None = ..., + batch_size: int | None = ..., + ray_batch_size: int | None = ..., + tri_batch_size: int | None = ..., + **kwargs: Any, +) -> Bool[Array, " *batch"]: ... + + +@overload +def rays_intersect_any_triangle( + ray_origins: Float[ArrayLike, "*#batch 3"], + ray_directions: Float[ArrayLike, "*#batch 3"], + triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], + active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = ..., + *, + hit_tol: Float[ArrayLike, ""] | None = ..., + smoothing_factor: Float[ArrayLike, ""], + batch_size: int | None = ..., + ray_batch_size: int | None = ..., + tri_batch_size: int | None = ..., + **kwargs: Any, +) -> Float[Array, " *batch"]: ... + + +@eqx.filter_jit +def rays_intersect_any_triangle( + ray_origins: Float[ArrayLike, "*#batch 3"], + ray_directions: Float[ArrayLike, "*#batch 3"], + triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], + active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, + *, + hit_tol: Float[ArrayLike, ""] | None = None, + smoothing_factor: Float[ArrayLike, ""] | None = None, + batch_size: int | None = 1024, + ray_batch_size: int | None = None, + tri_batch_size: int | None = None, + **kwargs: Any, +) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: + """ + Return whether rays intersect any of the triangles using the Möller-Trumbore algorithm. + + This function should be used when allocating an array of size + ``*batch num_triangles 3`` (or bigger) is not possible, and you are only interested in + checking if at least one of the triangles is intersected. + + A triangle is considered to be intersected if + ``t < (1 - hit_tol) & hit`` evaluates to :data:`True`. + + Args: + ray_origins: An array of origin vertices. + ray_directions: An array of ray direction. The ray ends + should be equal to ``ray_origins + ray_directions``. + triangle_vertices: An array of triangle vertices. + active_triangles: An optional array of boolean values indicating + which triangles are active, i.e., should be considered for intersection. + + If not specified, all triangles are considered active. + hit_tol: The tolerance applied to check if a ray hits another object or not, + before it reaches the expected position, i.e., the 'interaction' object. + + Using a non-zero tolerance is required as it would otherwise trigger + false positives. + + If not specified, the default is ten times the epsilon value + of the currently used floating point dtype. + smoothing_factor: If set, hard conditions are replaced with smoothed ones, + as described in :cite:`fully-eucap2024`, and this argument parameterizes the slope + of the smoothing function. The second output value is now a real value + between 0 (:data:`False`) and 1 (:data:`True`). + + For more details, refer to :ref:`smoothing`. + batch_size: The default batch size used when either ``ray_batch_size`` or + ``tri_batch_size`` is not specified. This allows to make a trade-off between memory + usage and performance. + + If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are + used. Otherwise, this value is used as the default for whichever of + ``ray_batch_size`` and ``tri_batch_size`` are left unspecified. + ray_batch_size: The number of rays to process in a single batch. + This allows to make a trade-off between memory usage and performance. + + The ray batch size is automatically adjusted to be the minimum of the number of rays + and the specified ray batch size. + + If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all rays are processed in a single batch. + tri_batch_size: The number of triangles to process in a single batch. + This allows to make a trade-off between memory usage and performance. + + The triangle batch size is automatically adjusted to be the minimum of the number of + triangles and the specified triangle batch size. + + If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all triangles are processed in a single batch. + kwargs: Keyword arguments passed to + :func:`rays_intersect_triangles`. + + Returns: + For each ray, whether it intersects with any of the triangles. + """ + ray_origins = jnp.asarray(ray_origins) + ray_directions = jnp.asarray(ray_directions) + triangle_vertices = jnp.asarray(triangle_vertices) + + if hit_tol is None: + dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) + hit_tol = 10.0 * jnp.finfo(dtype).eps + + hit_threshold = 1.0 - jnp.asarray(hit_tol) + + num_triangles = triangle_vertices.shape[-3] + + if active_triangles is not None: + active_triangles = jnp.asarray(active_triangles) + + batch = jnp.broadcast_shapes( + ray_origins.shape[:-1], + ray_directions.shape[:-1], + triangle_vertices.shape[:-3], + active_triangles.shape[:-1] if active_triangles is not None else (), + ) + num_rays = math.prod(batch) + ray_batch_size = ( + ray_batch_size + if ray_batch_size is not None + else batch_size + if batch_size is not None + else num_rays + ) + ray_batch_size = min(ray_batch_size, num_rays) + tri_batch_size = ( + tri_batch_size + if tri_batch_size is not None + else batch_size + if batch_size is not None + else num_triangles + ) + tri_batch_size = min(tri_batch_size, num_triangles) + if num_triangles == 0: + # If there are no triangles, there are no intersections + return ( + jnp.zeros( + batch, + dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), + ) + if smoothing_factor is not None + else jnp.zeros(batch, dtype=bool) + ) + + if num_rays > ray_batch_size: + xs = [] + argnames = [] + map_fn = partial( + _ray_intersect_any_triangle, + hit_threshold=hit_threshold, + smoothing_factor=smoothing_factor, + batch_size=tri_batch_size, + **kwargs, + ) + if math.prod(ray_origins.shape[:-1]) > 1: + ray_origins = jnp.broadcast_to(ray_origins, (*batch, 3)) + xs.append(ray_origins.reshape(-1, 3)) + argnames.append("ray_origin") + else: + map_fn = partial(map_fn, ray_origin=ray_origins.reshape(3)) + if math.prod(ray_directions.shape[:-1]) > 1: + ray_directions = jnp.broadcast_to(ray_directions, (*batch, 3)) + xs.append(ray_directions.reshape(-1, 3)) + argnames.append("ray_direction") + else: + map_fn = partial(map_fn, ray_direction=ray_directions.reshape(3)) + if math.prod(triangle_vertices.shape[:-3]) > 1: + triangle_vertices = jnp.broadcast_to( + triangle_vertices, (*batch, num_triangles, 3, 3) + ) + xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) + argnames.append("triangle_vertices") + else: + map_fn = partial( + map_fn, triangle_vertices=triangle_vertices.reshape(num_triangles, 3, 3) + ) + if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: + active_triangles = jnp.broadcast_to( + active_triangles, (*batch, num_triangles) + ) + xs.append(active_triangles.reshape(-1, num_triangles)) + argnames.append("active_triangles") + else: + map_fn = partial( + map_fn, + active_triangles=active_triangles.reshape(num_triangles) + if active_triangles is not None + else None, + ) + + def f(*args: Array) -> Bool[Array, ""] | Float[Array, ""]: + return map_fn(**dict(zip(argnames, args, strict=True))) + + return jax.lax.map( + lambda x: f(*x), tuple(xs), batch_size=ray_batch_size + ).reshape(batch) + return jnp.vectorize( + partial( + _ray_intersect_any_triangle, + hit_threshold=hit_threshold, + smoothing_factor=smoothing_factor, + batch_size=tri_batch_size, + **kwargs, + ), + signature="(3),(3),(n,3,3),(n)->()" + if active_triangles is not None + else "(3),(3),(n,3,3),()->()", + )(ray_origins, ray_directions, triangle_vertices, active_triangles) + + def _first_triangle_hit_by_ray_batched( ray_origin: Float[Array, "3"], ray_direction: Float[Array, "3"], @@ -523,7 +746,7 @@ def body_fn( if active_triangles is not None else None ) - batch_of_indices = jnp.arange(batch_size, dtype=jnp.int32) + start_index + batch_of_indices = jnp.arange(batch_size, dtype=int) + start_index best_in_batch = _first_triangle_hit_by_ray_batched( ray_origin, @@ -541,7 +764,7 @@ def body_fn( num_batches, rem = divmod(num_triangles, batch_size) init_val = ( - jnp.array(-1, dtype=jnp.int32), + jnp.array(-1, dtype=int), jnp.array(jnp.inf, dtype=ray_origin.dtype), jnp.array(jnp.inf, dtype=ray_origin.dtype), ) @@ -562,7 +785,7 @@ def body_fn( ray_direction, triangle_vertices[-rem:, :], (active_triangles[-rem:] if active_triangles is not None else None), - jnp.arange(rem, dtype=jnp.int32) + start_index, + jnp.arange(rem, dtype=int) + start_index, dist_tol=dist_tol, **kwargs, ), @@ -571,61 +794,25 @@ def body_fn( return best[0], best[1] -@overload -def rays_intersect_any_triangle( - ray_origins: Float[ArrayLike, "*#batch 3"], - ray_directions: Float[ArrayLike, "*#batch 3"], - triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], - active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = ..., - *, - hit_tol: Float[ArrayLike, ""] | None = ..., - smoothing_factor: None = ..., - batch_size: int | None = ..., - ray_batch_size: int | None = ..., - tri_batch_size: int | None = ..., - **kwargs: Any, -) -> Bool[Array, " *batch"]: ... - - -@overload -def rays_intersect_any_triangle( - ray_origins: Float[ArrayLike, "*#batch 3"], - ray_directions: Float[ArrayLike, "*#batch 3"], - triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], - active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = ..., - *, - hit_tol: Float[ArrayLike, ""] | None = ..., - smoothing_factor: Float[ArrayLike, ""], - batch_size: int | None = ..., - ray_batch_size: int | None = ..., - tri_batch_size: int | None = ..., - **kwargs: Any, -) -> Float[Array, " *batch"]: ... - - @eqx.filter_jit -def rays_intersect_any_triangle( +def first_triangles_hit_by_rays( ray_origins: Float[ArrayLike, "*#batch 3"], ray_directions: Float[ArrayLike, "*#batch 3"], triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, - *, - hit_tol: Float[ArrayLike, ""] | None = None, - smoothing_factor: Float[ArrayLike, ""] | None = None, batch_size: int | None = 1024, ray_batch_size: int | None = None, tri_batch_size: int | None = None, **kwargs: Any, -) -> Bool[Array, " *batch"] | Float[Array, " *batch"]: +) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"]]: """ - Return whether rays intersect any of the triangles using the Möller-Trumbore algorithm. + Return the first triangle hit by each ray. This function should be used when allocating an array of size ``*batch num_triangles 3`` (or bigger) is not possible, and you are only interested in - checking if at least one of the triangles is intersected. + getting the first triangle hit by the ray. - A triangle is considered to be intersected if - ``t < (1 - hit_tol) & hit`` evaluates to :data:`True`. + If two or more triangles are hit at the same distance, the one with the closest center to the ray origin is selected. Two triangles are considered to be hit at the same distance if their distances differ by less than ``100 * eps``, where ``eps`` is the machine epsilon of the floating point dtype. Args: ray_origins: An array of origin vertices. @@ -636,20 +823,6 @@ def rays_intersect_any_triangle( which triangles are active, i.e., should be considered for intersection. If not specified, all triangles are considered active. - hit_tol: The tolerance applied to check if a ray hits another object or not, - before it reaches the expected position, i.e., the 'interaction' object. - - Using a non-zero tolerance is required as it would otherwise trigger - false positives. - - If not specified, the default is ten times the epsilon value - of the currently used floating point dtype. - smoothing_factor: If set, hard conditions are replaced with smoothed ones, - as described in :cite:`fully-eucap2024`, and this argument parameterizes the slope - of the smoothing function. The second output value is now a real value - between 0 (:data:`False`) and 1 (:data:`True`). - - For more details, refer to :ref:`smoothing`. batch_size: The default batch size used when either ``ray_batch_size`` or ``tri_batch_size`` is not specified. This allows to make a trade-off between memory usage and performance. @@ -658,40 +831,36 @@ def rays_intersect_any_triangle( used. Otherwise, this value is used as the default for whichever of ``ray_batch_size`` and ``tri_batch_size`` are left unspecified. ray_batch_size: The number of rays to process in a single batch. - This allows to make a trade-off between memory usage and performance. - - The ray batch size is automatically adjusted to be the minimum of the number of rays - and the specified ray batch size. + This allows to chunk rays and reduce peak memory usage. If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all rays are processed in a single batch. tri_batch_size: The number of triangles to process in a single batch. - This allows to make a trade-off between memory usage and performance. - - The triangle batch size is automatically adjusted to be the minimum of the number of - triangles and the specified triangle batch size. + This allows to chunk triangles and reduce peak memory usage. If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all triangles are processed in a single batch. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. Returns: - For each ray, whether it intersects with any of the triangles. + For each ray, return the index and to distance to the first triangle hit. + + If no triangle is hit, the index is set to ``-1`` and + the distance is set to :data:`inf`. """ ray_origins = jnp.asarray(ray_origins) ray_directions = jnp.asarray(ray_directions) triangle_vertices = jnp.asarray(triangle_vertices) - if hit_tol is None: - dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) - hit_tol = 10.0 * jnp.finfo(dtype).eps + if active_triangles is not None: + active_triangles = jnp.asarray(active_triangles) - hit_threshold = 1.0 - jnp.asarray(hit_tol) + dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) + dist_tol = jnp.asarray(100 * jnp.finfo(dtype).eps) num_triangles = triangle_vertices.shape[-3] - if active_triangles is not None: - active_triangles = jnp.asarray(active_triangles) - batch = jnp.broadcast_shapes( ray_origins.shape[:-1], ray_directions.shape[:-1], @@ -716,71 +885,75 @@ def rays_intersect_any_triangle( ) tri_batch_size = min(tri_batch_size, num_triangles) if num_triangles == 0: - # If there are no triangles, there are no intersections + # If there are no triangles, there are no hits return ( - jnp.zeros( + jnp.full(batch, -1, dtype=int), + jnp.full( batch, + jnp.inf, dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), - ) - if smoothing_factor is not None - else jnp.zeros(batch, dtype=bool) + ), ) - if num_rays > ray_batch_size: - xs = [] - argnames = [] + xs = [] + argnames = [] + map_fn = partial( + _first_triangle_hit_by_ray, + batch_size=tri_batch_size, + dist_tol=dist_tol, + **kwargs, + ) + if math.prod(ray_origins.shape[:-1]) > 1: + xs.append(ray_origins.reshape(-1, 3)) + argnames.append("ray_origin") + else: + map_fn = partial(map_fn, ray_origin=ray_origins.reshape(3)) + if math.prod(ray_directions.shape[:-1]) > 1: + xs.append(ray_directions.reshape(-1, 3)) + argnames.append("ray_direction") + else: + map_fn = partial(map_fn, ray_direction=ray_directions.reshape(3)) + if math.prod(triangle_vertices.shape[:-3]) > 1: + xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) + argnames.append("triangle_vertices") + else: map_fn = partial( - _ray_intersect_any_triangle, - hit_threshold=hit_threshold, - smoothing_factor=smoothing_factor, - batch_size=tri_batch_size, - **kwargs, + map_fn, triangle_vertices=triangle_vertices.reshape(num_triangles, 3, 3) + ) + if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: + xs.append(active_triangles.reshape(-1, num_triangles)) + argnames.append("active_triangles") + else: + map_fn = partial( + map_fn, + active_triangles=active_triangles.reshape(num_triangles) + if active_triangles is not None + else None, ) - if math.prod(ray_origins.shape[:-1]) > 1: - ray_origins = jnp.broadcast_to(ray_origins, (*batch, 3)) - xs.append(ray_origins.reshape(-1, 3)) - argnames.append("ray_origin") - else: - map_fn = partial(map_fn, ray_origin=ray_origins) - if math.prod(ray_directions.shape[:-1]) > 1: - ray_directions = jnp.broadcast_to(ray_directions, (*batch, 3)) - xs.append(ray_directions.reshape(-1, 3)) - argnames.append("ray_direction") - else: - map_fn = partial(map_fn, ray_direction=ray_directions) - if math.prod(triangle_vertices.shape[:-3]) > 1: - triangle_vertices = jnp.broadcast_to( - triangle_vertices, (*batch, num_triangles, 3, 3) - ) - xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) - argnames.append("triangle_vertices") - else: - map_fn = partial(map_fn, triangle_vertices=triangle_vertices) - if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: - active_triangles = jnp.broadcast_to( - active_triangles, (*batch, num_triangles) - ) - xs.append(active_triangles.reshape(-1, num_triangles)) - argnames.append("active_triangles") - else: - map_fn = partial(map_fn, active_triangles=active_triangles) - def f(args: tuple[Array, ...]) -> Bool[Array, ""] | Float[Array, ""]: - return map_fn(**dict(zip(argnames, args, strict=True))) + def f(*args: Array) -> tuple[Int[Array, ""], Float[Array, ""]]: + return map_fn(**dict(zip(argnames, args, strict=True))) - return jax.lax.map(f, tuple(xs), batch_size=ray_batch_size).reshape(batch) - return jnp.vectorize( - partial( - _ray_intersect_any_triangle, - hit_threshold=hit_threshold, - smoothing_factor=smoothing_factor, + if not xs: + indices, ts = _first_triangle_hit_by_ray( + ray_origins.reshape(3), + ray_directions.reshape(3), + triangle_vertices.reshape(num_triangles, 3, 3), + active_triangles.reshape(num_triangles) + if active_triangles is not None + else None, batch_size=tri_batch_size, + dist_tol=dist_tol, **kwargs, - ), - signature="(3),(3),(n,3,3),(n)->()" - if active_triangles is not None - else "(3),(3),(n,3,3),()->()", - )(ray_origins, ray_directions, triangle_vertices, active_triangles) + ) + return indices.reshape(batch), ts.reshape(batch) + + if num_rays > ray_batch_size: + indices, ts = jax.lax.map(lambda x: f(*x), tuple(xs), batch_size=ray_batch_size) + else: + indices, ts = jax.vmap(f)(*xs) + + return indices.reshape(batch), ts.reshape(batch) def _triangles_visible_from_vertex( @@ -863,6 +1036,7 @@ def triangles_visible_from_vertices( and the specified ray batch size. If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all rays are processed in a single batch. tri_batch_size: The number of triangles to process in a single batch. This allows to make a trade-off between memory usage and performance. @@ -870,6 +1044,7 @@ def triangles_visible_from_vertices( triangles and the specified triangle batch size. If :data:`None`, it defaults to ``batch_size``. + If ``batch_size`` is also :data:`None`, all triangles are processed in a single batch. kwargs: Keyword arguments passed to :func:`rays_intersect_triangles`. @@ -1074,163 +1249,3 @@ def f(args: tuple[Array, ...]) -> Bool[Array, "num_triangles"]: *batch, num_triangles, )) - - -@eqx.filter_jit -def first_triangles_hit_by_rays( - ray_origins: Float[ArrayLike, "*#batch 3"], - ray_directions: Float[ArrayLike, "*#batch 3"], - triangle_vertices: Float[ArrayLike, "*#batch num_triangles 3 3"], - active_triangles: Bool[ArrayLike, "*#batch num_triangles"] | None = None, - batch_size: int | None = 1024, - ray_batch_size: int | None = None, - tri_batch_size: int | None = None, - **kwargs: Any, -) -> tuple[Int[Array, " *batch"], Float[Array, " *batch"]]: - """ - Return the first triangle hit by each ray. - - This function should be used when allocating an array of size - ``*batch num_triangles 3`` (or bigger) is not possible, and you are only interested in - getting the first triangle hit by the ray. - - If two or more triangles are hit at the same distance, the one with the closest center to the ray origin is selected. Two triangles are considered to be hit at the same distance if their distances differ by less than ``100 * eps``, where ``eps`` is the machine epsilon of the floating point dtype. - - Args: - ray_origins: An array of origin vertices. - ray_directions: An array of ray direction. The ray ends - should be equal to ``ray_origins + ray_directions``. - triangle_vertices: An array of triangle vertices. - active_triangles: An optional array of boolean values indicating - which triangles are active, i.e., should be considered for intersection. - - If not specified, all triangles are considered active. - batch_size: The default batch size used when either ``ray_batch_size`` or - ``tri_batch_size`` is not specified. This allows to make a trade-off between memory - usage and performance. - - If :data:`None`, the provided ``ray_batch_size`` and ``tri_batch_size`` values are - used. Otherwise, this value is used as the default for whichever of - ``ray_batch_size`` and ``tri_batch_size`` are left unspecified. - ray_batch_size: The number of rays to process in a single batch. - This allows to chunk rays and reduce peak memory usage. - - If :data:`None`, it defaults to ``batch_size``. - tri_batch_size: The number of triangles to process in a single batch. - This allows to chunk triangles and reduce peak memory usage. - - If :data:`None`, the triangle batch size defaults to ``batch_size``. - kwargs: Keyword arguments passed to - :func:`rays_intersect_triangles`. - - Returns: - For each ray, return the index and to distance to the first triangle hit. - - If no triangle is hit, the index is set to ``-1`` and - the distance is set to :data:`inf`. - """ - ray_origins = jnp.asarray(ray_origins) - ray_directions = jnp.asarray(ray_directions) - triangle_vertices = jnp.asarray(triangle_vertices) - - if active_triangles is not None: - active_triangles = jnp.asarray(active_triangles) - - dtype = jnp.result_type(ray_origins, ray_directions, triangle_vertices) - dist_tol = jnp.asarray(100 * jnp.finfo(dtype).eps) - - num_triangles = triangle_vertices.shape[-3] - - batch = jnp.broadcast_shapes( - ray_origins.shape[:-1], - ray_directions.shape[:-1], - triangle_vertices.shape[:-3], - active_triangles.shape[:-1] if active_triangles is not None else (), - ) - num_rays = math.prod(batch) - ray_batch_size = ( - ray_batch_size - if ray_batch_size is not None - else batch_size - if batch_size is not None - else num_rays - ) - ray_batch_size = min(ray_batch_size, num_rays) - tri_batch_size = ( - tri_batch_size - if tri_batch_size is not None - else batch_size - if batch_size is not None - else num_triangles - ) - tri_batch_size = min(tri_batch_size, num_triangles) - if num_triangles == 0: - # If there are no triangles, there are no hits - return ( - jnp.full(batch, -1, dtype=jnp.int32), - jnp.full( - batch, - jnp.inf, - dtype=jnp.result_type(ray_origins, ray_directions, triangle_vertices), - ), - ) - - xs = [] - argnames = [] - map_fn = partial( - _first_triangle_hit_by_ray, - batch_size=tri_batch_size, - dist_tol=dist_tol, - **kwargs, - ) - if math.prod(ray_origins.shape[:-1]) > 1: - xs.append(ray_origins.reshape(-1, 3)) - argnames.append("ray_origin") - else: - map_fn = partial(map_fn, ray_origin=ray_origins.reshape(3)) - if math.prod(ray_directions.shape[:-1]) > 1: - xs.append(ray_directions.reshape(-1, 3)) - argnames.append("ray_direction") - else: - map_fn = partial(map_fn, ray_direction=ray_directions.reshape(3)) - if math.prod(triangle_vertices.shape[:-3]) > 1: - xs.append(triangle_vertices.reshape(-1, num_triangles, 3, 3)) - argnames.append("triangle_vertices") - else: - map_fn = partial( - map_fn, triangle_vertices=triangle_vertices.reshape(num_triangles, 3, 3) - ) - if active_triangles is not None and math.prod(active_triangles.shape[:-1]) > 1: - xs.append(active_triangles.reshape(-1, num_triangles)) - argnames.append("active_triangles") - else: - map_fn = partial( - map_fn, - active_triangles=active_triangles.reshape(num_triangles) - if active_triangles is not None - else None, - ) - - def f(args: tuple[Array, ...]) -> tuple[Int[Array, ""], Float[Array, ""]]: - return map_fn(**dict(zip(argnames, args, strict=True))) - - if not xs: - indices, ts = _first_triangle_hit_by_ray( - ray_origins.reshape(3), - ray_directions.reshape(3), - triangle_vertices.reshape(num_triangles, 3, 3), - active_triangles.reshape(num_triangles) - if active_triangles is not None - else None, - batch_size=tri_batch_size, - dist_tol=dist_tol, - **kwargs, - ) - return indices.reshape(batch), ts.reshape(batch) - - if num_rays > ray_batch_size: - indices, ts = jax.lax.map(f, tuple(xs), batch_size=ray_batch_size) - else: - indices, ts = jax.vmap(f)(tuple(xs)) - - return indices.reshape(batch), ts.reshape(batch) From 9cb877b47de2ce2273cb8c29f74bfa040c5807f3 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?J=C3=A9rome=20Eertmans?= Date: Thu, 14 May 2026 11:26:42 +0200 Subject: [PATCH 13/13] some fixes --- differt/src/differt/rt/_utils.py | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/differt/src/differt/rt/_utils.py b/differt/src/differt/rt/_utils.py index 3a117e11..7fd135f4 100644 --- a/differt/src/differt/rt/_utils.py +++ b/differt/src/differt/rt/_utils.py @@ -575,7 +575,7 @@ def rays_intersect_any_triangle( if batch_size is not None else num_rays ) - ray_batch_size = min(ray_batch_size, num_rays) + ray_batch_size = max(min(ray_batch_size, num_rays), 1) tri_batch_size = ( tri_batch_size if tri_batch_size is not None @@ -583,7 +583,7 @@ def rays_intersect_any_triangle( if batch_size is not None else num_triangles ) - tri_batch_size = min(tri_batch_size, num_triangles) + tri_batch_size = max(min(tri_batch_size, num_triangles), 1) if num_triangles == 0: # If there are no triangles, there are no intersections return ( @@ -844,7 +844,7 @@ def first_triangles_hit_by_rays( :func:`rays_intersect_triangles`. Returns: - For each ray, return the index and to distance to the first triangle hit. + For each ray, return the index and the distance to the first triangle hit. If no triangle is hit, the index is set to ``-1`` and the distance is set to :data:`inf`. @@ -875,7 +875,7 @@ def first_triangles_hit_by_rays( if batch_size is not None else num_rays ) - ray_batch_size = min(ray_batch_size, num_rays) + ray_batch_size = max(min(ray_batch_size, num_rays), 1) tri_batch_size = ( tri_batch_size if tri_batch_size is not None @@ -883,7 +883,7 @@ def first_triangles_hit_by_rays( if batch_size is not None else num_triangles ) - tri_batch_size = min(tri_batch_size, num_triangles) + tri_batch_size = max(min(tri_batch_size, num_triangles), 1) if num_triangles == 0: # If there are no triangles, there are no hits return ( @@ -1159,7 +1159,7 @@ def triangles_visible_from_vertices( if batch_size is not None else num_rays ) - ray_batch_size = min(ray_batch_size, num_rays) + ray_batch_size = max(min(ray_batch_size, num_rays), 1) tri_batch_size = ( tri_batch_size if tri_batch_size is not None @@ -1167,7 +1167,7 @@ def triangles_visible_from_vertices( if batch_size is not None else num_triangles ) - tri_batch_size = min(tri_batch_size, num_triangles) + tri_batch_size = max(min(tri_batch_size, num_triangles), 1) if num_triangles == 0: return jnp.zeros((*batch, 0), dtype=jnp.bool_)