Terms
Description
The ray-intersect-any-triangle test is one of the main performance bottlenecks in this library when applied to a large number of rays and triangles. This is because jax.jit cannot optimize the jax.vmap-ed version of this check: it allocates an intermediate array whose size scales with num_rays x num_triangles, which for large inputs is too large to fit into any reasonable amount of memory.
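To make the problem concrete, here is a minimal sketch of the all-pairs formulation. It is not the library's actual code: `ray_hits_triangle` and `naive_any_hit` are hypothetical names, and the per-pair test is a plain Moller-Trumbore check chosen for illustration. The nested jax.vmap produces a (num_rays, num_triangles) mask before the reduction. With 10^6 rays and 10^6 triangles, a boolean mask of that shape would already take about 10^12 bytes (roughly 1 TB), before counting any floating-point intermediates.

```python
import jax
import jax.numpy as jnp


def ray_hits_triangle(origin, direction, triangle, eps=1e-6):
    """Moller-Trumbore test for one ray and one triangle of shape (3, 3)."""
    v0, v1, v2 = triangle
    e1, e2 = v1 - v0, v2 - v0
    p = jnp.cross(direction, e2)
    det = jnp.dot(e1, p)
    # Guard the division so parallel rays do not produce infs or NaNs.
    inv_det = 1.0 / jnp.where(jnp.abs(det) < eps, 1.0, det)
    s = origin - v0
    u = jnp.dot(s, p) * inv_det
    q = jnp.cross(s, e1)
    v = jnp.dot(direction, q) * inv_det
    t = jnp.dot(e2, q) * inv_det
    return (jnp.abs(det) >= eps) & (u >= 0) & (v >= 0) & (u + v <= 1) & (t > eps)


@jax.jit
def naive_any_hit(origins, directions, triangles):
    """All-pairs version: builds a (num_rays, num_triangles) mask, then reduces."""
    per_triangle = jax.vmap(ray_hits_triangle, in_axes=(None, None, 0))
    all_pairs = jax.vmap(per_triangle, in_axes=(0, 0, None))
    hits = all_pairs(origins, directions, triangles)  # (num_rays, num_triangles)
    return hits.any(axis=1)
```

This is fine for small inputs, but the `hits` array (and the intermediates leading up to it) grow with the product of both input sizes.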
To work around this, the current implementation uses a jax.lax.scan-like approach that processes triangles (or rays) sequentially, in user-defined batch sizes. The main drawbacks are:
- The optimal batch size depends on many factors, including the number of rays, the number of triangles, and the available memory, making it difficult to choose good default values.
- jax.lax.scan-like solutions are orders of magnitude slower than their jax.vmap counterparts when array sizes are not prohibitively large.
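For readers unfamiliar with the workaround, the batched approach can be sketched roughly as follows. This is an illustration under assumptions, not the library's implementation: `batched_any_hit`, `ray_hits_triangle`, and the `batch_size` default are hypothetical, and triangles are padded with degenerate all-zero triangles so the array splits evenly into batches. Only a (num_rays, batch_size) mask exists at any one time, at the cost of a sequential loop over batches.

```python
from functools import partial

import jax
import jax.numpy as jnp


def ray_hits_triangle(origin, direction, triangle, eps=1e-6):
    """Moller-Trumbore test for one ray and one triangle of shape (3, 3)."""
    v0, v1, v2 = triangle
    e1, e2 = v1 - v0, v2 - v0
    p = jnp.cross(direction, e2)
    det = jnp.dot(e1, p)
    # Guard the division so parallel rays do not produce infs or NaNs.
    inv_det = 1.0 / jnp.where(jnp.abs(det) < eps, 1.0, det)
    s = origin - v0
    u = jnp.dot(s, p) * inv_det
    q = jnp.cross(s, e1)
    v = jnp.dot(direction, q) * inv_det
    t = jnp.dot(e2, q) * inv_det
    return (jnp.abs(det) >= eps) & (u >= 0) & (v >= 0) & (u + v <= 1) & (t > eps)


@partial(jax.jit, static_argnames=("batch_size",))
def batched_any_hit(origins, directions, triangles, batch_size=2):
    """Scan over triangle batches, carrying one boolean per ray."""
    num_triangles = triangles.shape[0]
    pad = (-num_triangles) % batch_size
    # Degenerate (all-zero) triangles have zero determinant and are never hit.
    padding = jnp.zeros((pad, 3, 3), dtype=triangles.dtype)
    batches = jnp.concatenate([triangles, padding]).reshape(-1, batch_size, 3, 3)

    per_triangle = jax.vmap(ray_hits_triangle, in_axes=(None, None, 0))
    pairwise = jax.vmap(per_triangle, in_axes=(0, 0, None))

    def step(hit_so_far, batch):
        hits = pairwise(origins, directions, batch)  # (num_rays, batch_size)
        return hit_so_far | hits.any(axis=1), None

    init = jnp.zeros(origins.shape[0], dtype=bool)
    result, _ = jax.lax.scan(step, init, batches)
    return result
```

The trade-off described above shows up directly in `batch_size`: larger values mean fewer sequential steps but a larger per-step mask, and the best value depends on the input sizes and the memory available.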
Ideally, we would develop a jax.vmap solution that jax.jit can optimize in a way that avoids allocating the problematic intermediate array altogether. I have already spent quite some time investigating this, so far without success.
If you have ideas for improving this or would like to contribute, please share your thoughts or open a PR!
Screenshots
No response
Additional information
Related links:
- lax.reduce NotImplementedError: Reduction computations can't close over Tracers (jax-ml/jax#30841)
- Ref could be interesting to look at: https://docs.jax.dev/en/latest/array_refs.html
Relevant functions:
- rays_intersect_any_triangle;
- triangles_visible_from_vertices;
- first_triangles_hit_by_rays.