From 6760d361fc4f9327e8ab1edbf1076488f500a403 Mon Sep 17 00:00:00 2001 From: Evan Kiefl Date: Tue, 19 May 2026 22:18:45 -0700 Subject: [PATCH 1/7] Migrate solve.py into respective detect/ modules --- docs/include_exclude.py | 1 - .../evolution/event_based/detect/ball_ball.py | 109 +++- .../event_based/detect/ball_cushion.py | 163 +++++- .../event_based/detect/ball_pocket.py | 177 +++++- .../event_based/detect/ball_table.py | 24 +- pooltool/physics/motion/__init__.py | 0 pooltool/physics/motion/solve.py | 534 ------------------ pooltool/physics/utils.py | 9 +- sandbox/airborne_demos.py | 2 +- .../evolution/event_based/test_ball_pocket.py | 2 +- tests/evolution/event_based/test_simulate.py | 2 +- 11 files changed, 470 insertions(+), 553 deletions(-) delete mode 100644 pooltool/physics/motion/__init__.py delete mode 100644 pooltool/physics/motion/solve.py diff --git a/docs/include_exclude.py b/docs/include_exclude.py index b6f3f514..98ca4012 100644 --- a/docs/include_exclude.py +++ b/docs/include_exclude.py @@ -29,7 +29,6 @@ "pooltool.ruleset.snooker", # API: pooltool.physics "pooltool.physics.evolve", - "pooltool.physics.motion", "pooltool.physics.resolve", ], "module": [ diff --git a/pooltool/evolution/event_based/detect/ball_ball.py b/pooltool/evolution/event_based/detect/ball_ball.py index 1796f7ea..1509360b 100644 --- a/pooltool/evolution/event_based/detect/ball_ball.py +++ b/pooltool/evolution/event_based/detect/ball_ball.py @@ -3,15 +3,122 @@ from itertools import combinations import numpy as np +from numba import jit +from numpy.typing import NDArray import pooltool.constants as const import pooltool.ptmath as ptmath from pooltool.events import Event, EventType, ball_ball_collision, null_event from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.motion.solve import ball_ball_collision_time +from pooltool.physics.utils import get_u_vec +from pooltool.ptmath.roots import quartic +from pooltool.ptmath.roots.core import get_real_positive_smallest_root from pooltool.system.datatypes import System +@jit(nopython=True, cache=const.use_numba_cache) +def ball_ball_collision_coeffs( + rvw1: NDArray[np.float64], + rvw2: NDArray[np.float64], + s1: int, + s2: int, + mu1: float, + mu2: float, + m1: float, + m2: float, + g1: float, + g2: float, + R: float, +) -> tuple[float, float, float, float, float]: + """Get quartic coeffs required to determine the ball-ball collision time + + (just-in-time compiled) + """ + + c1x, c1y = rvw1[0, 0], rvw1[0, 1] + c2x, c2y = rvw2[0, 0], rvw2[0, 1] + + if s1 == const.spinning or s1 == const.pocketed or s1 == const.stationary: + a1x, a1y, b1x, b1y = 0, 0, 0, 0 + else: + phi1 = ptmath.angle(rvw1[1]) + v1 = ptmath.norm3d(rvw1[1]) + + u1 = get_u_vec(rvw1, R, phi1, s1) + + K1 = -0.5 * mu1 * g1 + cos_phi1 = np.cos(phi1) + sin_phi1 = np.sin(phi1) + + a1x = K1 * (u1[0] * cos_phi1 - u1[1] * sin_phi1) + a1y = K1 * (u1[0] * sin_phi1 + u1[1] * cos_phi1) + b1x = v1 * cos_phi1 + b1y = v1 * sin_phi1 + + if s2 == const.spinning or s2 == const.pocketed or s2 == const.stationary: + a2x, a2y, b2x, b2y = 0.0, 0.0, 0.0, 0.0 + else: + phi2 = ptmath.angle(rvw2[1]) + v2 = ptmath.norm3d(rvw2[1]) + + u2 = get_u_vec(rvw2, R, phi2, s2) + + K2 = -0.5 * mu2 * g2 + cos_phi2 = np.cos(phi2) + sin_phi2 = np.sin(phi2) + + a2x = K2 * (u2[0] * cos_phi2 - u2[1] * sin_phi2) + a2y = K2 * (u2[0] * sin_phi2 + u2[1] * cos_phi2) + b2x = v2 * cos_phi2 + b2y = v2 * sin_phi2 + + Ax, Ay = a2x - a1x, a2y - a1y + Bx, By = b2x - b1x, b2y - b1y + Cx, Cy = c2x - c1x, c2y - c1y + + a = Ax * Ax + Ay * Ay + b = 2 * Ax * Bx + 2 * Ay * By + c = Bx * Bx + 2 * Ax * Cx + 2 * Ay * Cy + By * By + d = 2 * Bx * Cx + 2 * By * Cy + e = Cx * Cx + Cy * Cy - 4 * R * R + + return a, b, c, d, e + + +@jit(nopython=True, cache=const.use_numba_cache) +def ball_ball_collision_time( + rvw1: NDArray[np.float64], + rvw2: NDArray[np.float64], + s1: int, + s2: int, + mu1: float, + mu2: float, + m1: float, + m2: float, + g1: float, + g2: float, + R: float, +) -> float: + """Get the time until collision between 2 balls.""" + return get_real_positive_smallest_root( + quartic.solve( + *ball_ball_collision_coeffs( + rvw1, + rvw2, + s1, + s2, + mu1, + mu2, + m1, + m2, + g1, + g2, + R, + ) + ) + ) + + def get_next_ball_ball_2d_event(shot: System, collision_cache: CollisionCache) -> Event: """Detect the next ball-ball collision in 2D mode.""" cache = collision_cache.times.setdefault(EventType.BALL_BALL, {}) diff --git a/pooltool/evolution/event_based/detect/ball_cushion.py b/pooltool/evolution/event_based/detect/ball_cushion.py index 6eaf494e..9ce1c50b 100644 --- a/pooltool/evolution/event_based/detect/ball_cushion.py +++ b/pooltool/evolution/event_based/detect/ball_cushion.py @@ -1,8 +1,12 @@ from __future__ import annotations import numpy as np +from numba import jit +from numpy.typing import NDArray import pooltool.constants as const +import pooltool.physics.evolve as evolve +import pooltool.ptmath as ptmath from pooltool.events import ( Event, EventType, @@ -11,13 +15,164 @@ null_event, ) from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.motion.solve import ( - ball_circular_cushion_collision_time, - ball_linear_cushion_collision_time, -) +from pooltool.physics.utils import get_u_vec +from pooltool.ptmath.roots import quartic +from pooltool.ptmath.roots.core import get_real_positive_smallest_root from pooltool.system.datatypes import System +@jit(nopython=True, cache=const.use_numba_cache) +def ball_linear_cushion_collision_time( + rvw: NDArray[np.float64], + s: int, + lx: float, + ly: float, + l0: float, + p1: NDArray[np.float64], + p2: NDArray[np.float64], + direction: int, + mu: float, + m: float, + g: float, + R: float, +) -> float: + """Get time until collision between ball and linear cushion segment + + (just-in-time compiled) + """ + if s == const.spinning or s == const.pocketed or s == const.stationary: + return np.inf + + phi = ptmath.angle(rvw[1]) + v = ptmath.norm3d(rvw[1]) + + u = get_u_vec(rvw, R, phi, s) + + K = -0.5 * mu * g + cos_phi = np.cos(phi) + sin_phi = np.sin(phi) + + ax = K * (u[0] * cos_phi - u[1] * sin_phi) + ay = K * (u[0] * sin_phi + u[1] * cos_phi) + bx, by = v * cos_phi, v * sin_phi + cx, cy = rvw[0, 0], rvw[0, 1] + + A = lx * ax + ly * ay + B = lx * bx + ly * by + + if direction == 0: + C = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) + root1, root2 = ptmath.roots.quadratic.solve(A, B, C) + roots = [root1, root2] + elif direction == 1: + C = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) + root1, root2 = ptmath.roots.quadratic.solve(A, B, C) + roots = [root1, root2] + else: + C1 = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) + C2 = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) + root1, root2 = ptmath.roots.quadratic.solve(A, B, C1) + root3, root4 = ptmath.roots.quadratic.solve(A, B, C2) + roots = [root1, root2, root3, root4] + + min_time = np.inf + for root in roots: + if np.isnan(root): + continue + + if np.abs(root.imag) > const.EPS: + continue + + if root.real <= const.EPS: + continue + + rvw_dtau, _ = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root.real) + s_score = -np.dot(p1 - rvw_dtau[0], p2 - p1) / np.dot(p2 - p1, p2 - p1) + + if not (0 <= s_score <= 1): + continue + + if root.real < min_time: + min_time = root.real + + return min_time + + +@jit(nopython=True, cache=const.use_numba_cache) +def ball_circular_cushion_collision_coeffs( + rvw: NDArray[np.float64], + s: int, + a: float, + b: float, + r: float, + mu: float, + m: float, + g: float, + R: float, +) -> tuple[float, float, float, float, float]: + """Get quartic coeffs required to determine the ball-circular-cushion collision time + + (just-in-time compiled) + """ + + if s == const.spinning or s == const.pocketed or s == const.stationary: + return np.inf, np.inf, np.inf, np.inf, np.inf + + phi = ptmath.angle(rvw[1]) + v = ptmath.norm3d(rvw[1]) + + u = get_u_vec(rvw, R, phi, s) + + K = -0.5 * mu * g + cos_phi = np.cos(phi) + sin_phi = np.sin(phi) + + ax = K * (u[0] * cos_phi - u[1] * sin_phi) + ay = K * (u[0] * sin_phi + u[1] * cos_phi) + bx, by = v * cos_phi, v * sin_phi + cx, cy = rvw[0, 0], rvw[0, 1] + + A = 0.5 * (ax * ax + ay * ay) + B = ax * bx + ay * by + C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) + D = bx * (cx - a) + by * (cy - b) + E = 0.5 * (a * a + b * b + cx * cx + cy * cy - (r + R) * (r + R)) - ( + cx * a + cy * b + ) + + return A, B, C, D, E + + +@jit(nopython=True, cache=const.use_numba_cache) +def ball_circular_cushion_collision_time( + rvw: NDArray[np.float64], + s: int, + a: float, + b: float, + r: float, + mu: float, + m: float, + g: float, + R: float, +) -> float: + """Get the time until collision between a ball and a circular cushion segment.""" + return get_real_positive_smallest_root( + quartic.solve( + *ball_circular_cushion_collision_coeffs( + rvw, + s, + a, + b, + r, + mu, + m, + g, + R, + ) + ) + ) + + def get_next_ball_linear_cushion_2d_event( shot: System, collision_cache: CollisionCache ) -> Event: diff --git a/pooltool/evolution/event_based/detect/ball_pocket.py b/pooltool/evolution/event_based/detect/ball_pocket.py index 198f095b..276f689c 100644 --- a/pooltool/evolution/event_based/detect/ball_pocket.py +++ b/pooltool/evolution/event_based/detect/ball_pocket.py @@ -1,17 +1,186 @@ from __future__ import annotations import numpy as np +from numba import jit +from numpy.typing import NDArray import pooltool.constants as const +import pooltool.ptmath as ptmath from pooltool.events import Event, EventType, ball_pocket_collision, null_event from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.motion.solve import ( - ball_pocket_collision_time, - ball_pocket_collision_time_airborne, -) +from pooltool.physics.utils import get_airborne_time, get_u_vec +from pooltool.ptmath.roots import quadratic, quartic +from pooltool.ptmath.roots.core import get_real_positive_smallest_root from pooltool.system.datatypes import System +@jit(nopython=True, cache=const.use_numba_cache) +def ball_pocket_collision_coeffs( + rvw: NDArray[np.float64], + s: int, + a: float, + b: float, + r: float, + mu: float, + m: float, + g: float, + R: float, +) -> tuple[float, float, float, float, float]: + """Get quartic coeffs required to determine the ball-pocket collision time + + (just-in-time compiled) + """ + + if s == const.spinning or s == const.pocketed or s == const.stationary: + return np.inf, np.inf, np.inf, np.inf, np.inf + + phi = ptmath.angle(rvw[1]) + v = ptmath.norm3d(rvw[1]) + + u = get_u_vec(rvw, R, phi, s) + + K = -0.5 * mu * g + cos_phi = np.cos(phi) + sin_phi = np.sin(phi) + + ax = K * (u[0] * cos_phi - u[1] * sin_phi) + ay = K * (u[0] * sin_phi + u[1] * cos_phi) + bx, by = v * cos_phi, v * sin_phi + cx, cy = rvw[0, 0], rvw[0, 1] + + A = 0.5 * (ax * ax + ay * ay) + B = ax * bx + ay * by + C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) + D = bx * (cx - a) + by * (cy - b) + E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) + + return A, B, C, D, E + + +@jit(nopython=True, cache=const.use_numba_cache) +def ball_pocket_collision_time( + rvw: NDArray[np.float64], + s: int, + a: float, + b: float, + r: float, + mu: float, + m: float, + g: float, + R: float, +) -> float: + """Get the time until collision between a ball and a pocket.""" + return get_real_positive_smallest_root( + quartic.solve( + *ball_pocket_collision_coeffs( + rvw, + s, + a, + b, + r, + mu, + m, + g, + R, + ) + ) + ) + + +@jit(nopython=True, cache=const.use_numba_cache) +def ball_pocket_collision_time_airborne( + rvw: NDArray[np.float64], + a: float, + b: float, + r: float, + g: float, + R: float, +) -> float: + """Determine the ball-pocket collision time for an airborne ball. + + The behavior is somewhat complicated. Here is the procedure. + + Strategy 1: The xy-coordinates of where the ball lands are calculated. If that falls + within the pocket circle, a collision is returned. The collision time is chosen to + be just less than the collision time for the table collision, to guarantee temporal + precedence over the table collision. + + Strategy 2: Otherwise, the influx and outflux collision times are calculated between + the ball center and a vertical cylinder that extends from the pocket's circle. + Influx collision refers to the collision with the outside of the cylinder's wall. + The outflux collision refers to the collision with the inside of the cylinder's wall + and occurs later in time. Since there is no deceleration in the xy-plane for an + airborne ball, an outflux collision is expected, meaning we expect 2 finite roots. + (This is only violated if the ball starts inside the cylinder, which results in at + most an outflux collision). The strategy is to see what the ball height is at the + time of the influx collision (``h0``) and the outflux collision (``hf``), because + from these we can determine whether or not the ball is considered to enter the + pocket. The following logic is used: + + - ``h0 < R``: The ball passes through the playing surface plane before + intersecting the pocket cylinder, guaranteeing that a ball-table collision + occurs. Infinity is returned. + - ``hf <= (7/5)*R``: If the outflux height is less than ``(7/5)*R``, the ball + is considered to be pocketed. This threshold height implicitly models the + fact that high velocity balls that are slightly airborne collide with table + geometry at the back of the pocket, ricocheting the ball into the pocket. + The average of the influx and outflux collision times is returned. + - ``hf > (7/5)*R``: The ball is considered to fly over the pocket. Infinity is + returned. + """ + phi = ptmath.angle(rvw[1]) + v = ptmath.norm2d(rvw[1]) + cos_phi = np.cos(phi) + sin_phi = np.sin(phi) + bx, by = v * cos_phi, v * sin_phi + + # Strategy 1: does the ball land inside the pocket circle? + airborne_time = get_airborne_time(rvw, R, g) + x = rvw[0, 0] + bx * airborne_time + y = rvw[0, 1] + by * airborne_time + + if (x - a) ** 2 + (y - b) ** 2 < r * r: + return float(airborne_time - const.EPS) + + # Strategy 2: does the ball's xy trajectory cross the pocket cylinder? + cx, cy = rvw[0, 0], rvw[0, 1] + + # These match the non-airborne quartic coefficients, after setting ax=ay=0. + C = 0.5 * (bx * bx + by * by) + D = bx * (cx - a) + by * (cy - b) + E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) + + roots = quadratic.solve(C, D, E) + + atol = 1e-9 + if abs(roots[0].imag) > atol or abs(roots[1].imag) > atol: + return np.inf + + real_0 = roots[0].real + real_1 = roots[1].real + if real_0 <= real_1: + r1, r2 = real_0, real_1 + else: + r1, r2 = real_1, real_0 + + if r1 < 0.0: + return np.inf + + v0z = rvw[1, 2] + z0 = rvw[0, 2] + + h0 = -0.5 * g * r1 * r1 + v0z * r1 + z0 + hf = -0.5 * g * r2 * r2 + v0z * r2 + z0 + + if h0 < R: + return np.inf + + if hf > 7.0 / 5.0 * R: + return np.inf + + return (r1 + r2) / 2.0 + + def get_next_ball_pocket_2d_event( shot: System, collision_cache: CollisionCache ) -> Event: diff --git a/pooltool/evolution/event_based/detect/ball_table.py b/pooltool/evolution/event_based/detect/ball_table.py index df74fb54..2ca964c1 100644 --- a/pooltool/evolution/event_based/detect/ball_table.py +++ b/pooltool/evolution/event_based/detect/ball_table.py @@ -1,11 +1,33 @@ from __future__ import annotations +import numpy as np +from numba import jit +from numpy.typing import NDArray + +import pooltool.constants as const from pooltool.events import Event, EventType, ball_table_collision from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.motion.solve import ball_table_collision_time +from pooltool.physics.utils import get_airborne_time from pooltool.system.datatypes import System +@jit(nopython=True, cache=const.use_numba_cache) +def ball_table_collision_time( + rvw: NDArray[np.float64], + s: int, + g: float, + R: float, +) -> float: + """Time until an airborne ball's bottom touches the table plane. + + Returns ``np.inf`` if the ball is not airborne (no ball-table collision can + occur for any other motion state). + """ + if s != const.airborne: + return np.inf + return get_airborne_time(rvw=rvw, R=R, g=g) + + def get_next_ball_table_event(shot: System, collision_cache: CollisionCache) -> Event: """Detect the next ball-table collision (airborne ball landing). diff --git a/pooltool/physics/motion/__init__.py b/pooltool/physics/motion/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pooltool/physics/motion/solve.py b/pooltool/physics/motion/solve.py deleted file mode 100644 index 53905724..00000000 --- a/pooltool/physics/motion/solve.py +++ /dev/null @@ -1,534 +0,0 @@ -from math import acos - -import numpy as np -from numba import jit -from numpy.typing import NDArray - -import pooltool.constants as const -import pooltool.physics.evolve as evolve -import pooltool.ptmath as ptmath -from pooltool.physics.utils import get_airborne_time, rel_velocity -from pooltool.ptmath.roots import quadratic, quartic -from pooltool.ptmath.roots.core import get_real_positive_smallest_root - - -@jit(nopython=True, cache=const.use_numba_cache) -def skip_ball_ball_collision( - rvw1: NDArray[np.float64], - rvw2: NDArray[np.float64], - s1: int, - s2: int, - R1: float, - R2: float, -) -> bool: - if (s1 == const.spinning or s1 == const.pocketed or s1 == const.stationary) and ( - s2 == const.spinning or s2 == const.pocketed or s2 == const.stationary - ): - # Neither balls are moving. No collision. - return True - - if s1 == const.pocketed or s2 == const.pocketed: - # One of the balls is pocketed - return True - - if s1 == const.rolling and s2 == const.rolling: - # Both balls are rolling (straight line trajectories). Here I am checking - # whether both dot products face away from the line connecting the two balls. If - # so, they are guaranteed not to collide - r12 = rvw2[0] - rvw1[0] - dot1 = r12[0] * rvw1[1, 0] + r12[1] * rvw1[1, 1] + r12[2] * rvw1[1, 2] - if dot1 <= 0: - dot2 = r12[0] * rvw2[1, 0] + r12[1] * rvw2[1, 1] + r12[2] * rvw2[1, 2] - if dot2 >= 0: - return True - - if s1 == const.rolling and (s2 == const.spinning or s2 == const.stationary): - # ball1 is rolling, which guarantees a straight-line trajectory. Some - # assumptions can be made based on this fact - r12 = rvw2[0] - rvw1[0] - - # ball2 is not moving, so we can pinpoint the range of angles ball1 must be - # headed in for a collision - d = ptmath.norm3d(r12) - unit_d = r12 / d - unit_v = ptmath.unit_vector(rvw1[1]) - - # Angles are in radians - # Calculate forwards and backwards angles, e.g. 10 and 350, take the min - angle = np.arccos(np.dot(unit_d, unit_v)) - max_hit_angle = 0.5 * np.pi - acos((R1 + R2) / d) - if angle > max_hit_angle: - return True - - if s2 == const.rolling and (s1 == const.spinning or s1 == const.stationary): - # ball2 is rolling, which guarantees a straight-line trajectory. Some - # assumptions can be made based on this fact - r21 = rvw1[0] - rvw2[0] - - # ball1 is not moving, so we can pinpoint the range of angles ball2 must be - # headed in for a collision - d = ptmath.norm3d(r21) - unit_d = r21 / d - unit_v = ptmath.unit_vector(rvw2[1]) - - # Angles are in radians - # Calculate forwards and backwards angles, e.g. 10 and 350, take the min - angle = np.arccos(np.dot(unit_d, unit_v)) - max_hit_angle = 0.5 * np.pi - acos((R1 + R2) / d) - if angle > max_hit_angle: - return True - - return False - - -@jit(nopython=True, cache=const.use_numba_cache) -def get_u( - rvw: NDArray[np.float64], R: float, phi: float, s: int -) -> NDArray[np.float64]: - if s == const.rolling: - return np.array([1, 0, 0], dtype=np.float64) - - rel_vel = rel_velocity(rvw, R) - if (rel_vel == 0).all(): - return np.array([1, 0, 0], dtype=np.float64) - - return ptmath.coordinate_rotation(ptmath.unit_vector(rel_vel), -phi) - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_ball_collision_coeffs( - rvw1: NDArray[np.float64], - rvw2: NDArray[np.float64], - s1: int, - s2: int, - mu1: float, - mu2: float, - m1: float, - m2: float, - g1: float, - g2: float, - R: float, -) -> tuple[float, float, float, float, float]: - """Get quartic coeffs required to determine the ball-ball collision time - - (just-in-time compiled) - """ - - c1x, c1y = rvw1[0, 0], rvw1[0, 1] - c2x, c2y = rvw2[0, 0], rvw2[0, 1] - - if s1 == const.spinning or s1 == const.pocketed or s1 == const.stationary: - a1x, a1y, b1x, b1y = 0, 0, 0, 0 - else: - phi1 = ptmath.angle(rvw1[1]) - v1 = ptmath.norm3d(rvw1[1]) - - u1 = get_u(rvw1, R, phi1, s1) - - K1 = -0.5 * mu1 * g1 - cos_phi1 = np.cos(phi1) - sin_phi1 = np.sin(phi1) - - a1x = K1 * (u1[0] * cos_phi1 - u1[1] * sin_phi1) - a1y = K1 * (u1[0] * sin_phi1 + u1[1] * cos_phi1) - b1x = v1 * cos_phi1 - b1y = v1 * sin_phi1 - - if s2 == const.spinning or s2 == const.pocketed or s2 == const.stationary: - a2x, a2y, b2x, b2y = 0.0, 0.0, 0.0, 0.0 - else: - phi2 = ptmath.angle(rvw2[1]) - v2 = ptmath.norm3d(rvw2[1]) - - u2 = get_u(rvw2, R, phi2, s2) - - K2 = -0.5 * mu2 * g2 - cos_phi2 = np.cos(phi2) - sin_phi2 = np.sin(phi2) - - a2x = K2 * (u2[0] * cos_phi2 - u2[1] * sin_phi2) - a2y = K2 * (u2[0] * sin_phi2 + u2[1] * cos_phi2) - b2x = v2 * cos_phi2 - b2y = v2 * sin_phi2 - - Ax, Ay = a2x - a1x, a2y - a1y - Bx, By = b2x - b1x, b2y - b1y - Cx, Cy = c2x - c1x, c2y - c1y - - a = Ax * Ax + Ay * Ay - b = 2 * Ax * Bx + 2 * Ay * By - c = Bx * Bx + 2 * Ax * Cx + 2 * Ay * Cy + By * By - d = 2 * Bx * Cx + 2 * By * Cy - e = Cx * Cx + Cy * Cy - 4 * R * R - - return a, b, c, d, e - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_ball_collision_time( - rvw1: NDArray[np.float64], - rvw2: NDArray[np.float64], - s1: int, - s2: int, - mu1: float, - mu2: float, - m1: float, - m2: float, - g1: float, - g2: float, - R: float, -) -> float: - """Get the time until collision between 2 balls.""" - return get_real_positive_smallest_root( - quartic.solve( - *ball_ball_collision_coeffs( - rvw1, - rvw2, - s1, - s2, - mu1, - mu2, - m1, - m2, - g1, - g2, - R, - ) - ) - ) - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_linear_cushion_collision_time( - rvw: NDArray[np.float64], - s: int, - lx: float, - ly: float, - l0: float, - p1: NDArray[np.float64], - p2: NDArray[np.float64], - direction: int, - mu: float, - m: float, - g: float, - R: float, -) -> float: - """Get time until collision between ball and linear cushion segment - - (just-in-time compiled) - """ - if s == const.spinning or s == const.pocketed or s == const.stationary: - return np.inf - - phi = ptmath.angle(rvw[1]) - v = ptmath.norm3d(rvw[1]) - - u = get_u(rvw, R, phi, s) - - K = -0.5 * mu * g - cos_phi = np.cos(phi) - sin_phi = np.sin(phi) - - ax = K * (u[0] * cos_phi - u[1] * sin_phi) - ay = K * (u[0] * sin_phi + u[1] * cos_phi) - bx, by = v * cos_phi, v * sin_phi - cx, cy = rvw[0, 0], rvw[0, 1] - - A = lx * ax + ly * ay - B = lx * bx + ly * by - - if direction == 0: - C = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) - root1, root2 = ptmath.roots.quadratic.solve(A, B, C) - roots = [root1, root2] - elif direction == 1: - C = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) - root1, root2 = ptmath.roots.quadratic.solve(A, B, C) - roots = [root1, root2] - else: - C1 = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) - C2 = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) - root1, root2 = ptmath.roots.quadratic.solve(A, B, C1) - root3, root4 = ptmath.roots.quadratic.solve(A, B, C2) - roots = [root1, root2, root3, root4] - - min_time = np.inf - for root in roots: - if np.isnan(root): - continue - - if np.abs(root.imag) > const.EPS: - continue - - if root.real <= const.EPS: - continue - - rvw_dtau, _ = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root.real) - s_score = -np.dot(p1 - rvw_dtau[0], p2 - p1) / np.dot(p2 - p1, p2 - p1) - - if not (0 <= s_score <= 1): - continue - - if root.real < min_time: - min_time = root.real - - return min_time - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_circular_cushion_collision_coeffs( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> tuple[float, float, float, float, float]: - """Get quartic coeffs required to determine the ball-circular-cushion collision time - - (just-in-time compiled) - """ - - if s == const.spinning or s == const.pocketed or s == const.stationary: - return np.inf, np.inf, np.inf, np.inf, np.inf - - phi = ptmath.angle(rvw[1]) - v = ptmath.norm3d(rvw[1]) - - u = get_u(rvw, R, phi, s) - - K = -0.5 * mu * g - cos_phi = np.cos(phi) - sin_phi = np.sin(phi) - - ax = K * (u[0] * cos_phi - u[1] * sin_phi) - ay = K * (u[0] * sin_phi + u[1] * cos_phi) - bx, by = v * cos_phi, v * sin_phi - cx, cy = rvw[0, 0], rvw[0, 1] - - A = 0.5 * (ax * ax + ay * ay) - B = ax * bx + ay * by - C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) - D = bx * (cx - a) + by * (cy - b) - E = 0.5 * (a * a + b * b + cx * cx + cy * cy - (r + R) * (r + R)) - ( - cx * a + cy * b - ) - - return A, B, C, D, E - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_circular_cushion_collision_time( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> float: - """Get the time until collision between a ball and a circular cushion segment.""" - return get_real_positive_smallest_root( - quartic.solve( - *ball_circular_cushion_collision_coeffs( - rvw, - s, - a, - b, - r, - mu, - m, - g, - R, - ) - ) - ) - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_pocket_collision_coeffs( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> tuple[float, float, float, float, float]: - """Get quartic coeffs required to determine the ball-pocket collision time - - (just-in-time compiled) - """ - - if s == const.spinning or s == const.pocketed or s == const.stationary: - return np.inf, np.inf, np.inf, np.inf, np.inf - - phi = ptmath.angle(rvw[1]) - v = ptmath.norm3d(rvw[1]) - - u = get_u(rvw, R, phi, s) - - K = -0.5 * mu * g - cos_phi = np.cos(phi) - sin_phi = np.sin(phi) - - ax = K * (u[0] * cos_phi - u[1] * sin_phi) - ay = K * (u[0] * sin_phi + u[1] * cos_phi) - bx, by = v * cos_phi, v * sin_phi - cx, cy = rvw[0, 0], rvw[0, 1] - - A = 0.5 * (ax * ax + ay * ay) - B = ax * bx + ay * by - C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) - D = bx * (cx - a) + by * (cy - b) - E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) - - return A, B, C, D, E - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_pocket_collision_time( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> float: - """Get the time until collision between a ball and a pocket.""" - return get_real_positive_smallest_root( - quartic.solve( - *ball_pocket_collision_coeffs( - rvw, - s, - a, - b, - r, - mu, - m, - g, - R, - ) - ) - ) - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_pocket_collision_time_airborne( - rvw: NDArray[np.float64], - a: float, - b: float, - r: float, - g: float, - R: float, -) -> float: - """Determine the ball-pocket collision time for an airborne ball. - - The behavior is somewhat complicated. Here is the procedure. - - Strategy 1: The xy-coordinates of where the ball lands are calculated. If that falls - within the pocket circle, a collision is returned. The collision time is chosen to - be just less than the collision time for the table collision, to guarantee temporal - precedence over the table collision. - - Strategy 2: Otherwise, the influx and outflux collision times are calculated between - the ball center and a vertical cylinder that extends from the pocket's circle. - Influx collision refers to the collision with the outside of the cylinder's wall. - The outflux collision refers to the collision with the inside of the cylinder's wall - and occurs later in time. Since there is no deceleration in the xy-plane for an - airborne ball, an outflux collision is expected, meaning we expect 2 finite roots. - (This is only violated if the ball starts inside the cylinder, which results in at - most an outflux collision). The strategy is to see what the ball height is at the - time of the influx collision (``h0``) and the outflux collision (``hf``), because - from these we can determine whether or not the ball is considered to enter the - pocket. The following logic is used: - - - ``h0 < R``: The ball passes through the playing surface plane before - intersecting the pocket cylinder, guaranteeing that a ball-table collision - occurs. Infinity is returned. - - ``hf <= (7/5)*R``: If the outflux height is less than ``(7/5)*R``, the ball - is considered to be pocketed. This threshold height implicitly models the - fact that high velocity balls that are slightly airborne collide with table - geometry at the back of the pocket, ricocheting the ball into the pocket. - The average of the influx and outflux collision times is returned. - - ``hf > (7/5)*R``: The ball is considered to fly over the pocket. Infinity is - returned. - """ - phi = ptmath.angle(rvw[1]) - v = ptmath.norm2d(rvw[1]) - cos_phi = np.cos(phi) - sin_phi = np.sin(phi) - bx, by = v * cos_phi, v * sin_phi - - # Strategy 1: does the ball land inside the pocket circle? - airborne_time = get_airborne_time(rvw, R, g) - x = rvw[0, 0] + bx * airborne_time - y = rvw[0, 1] + by * airborne_time - - if (x - a) ** 2 + (y - b) ** 2 < r * r: - return float(airborne_time - const.EPS) - - # Strategy 2: does the ball's xy trajectory cross the pocket cylinder? - cx, cy = rvw[0, 0], rvw[0, 1] - - # These match the non-airborne quartic coefficients, after setting ax=ay=0. - C = 0.5 * (bx * bx + by * by) - D = bx * (cx - a) + by * (cy - b) - E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) - - roots = quadratic.solve(C, D, E) - - atol = 1e-9 - if abs(roots[0].imag) > atol or abs(roots[1].imag) > atol: - return np.inf - - real_0 = roots[0].real - real_1 = roots[1].real - if real_0 <= real_1: - r1, r2 = real_0, real_1 - else: - r1, r2 = real_1, real_0 - - if r1 < 0.0: - return np.inf - - v0z = rvw[1, 2] - z0 = rvw[0, 2] - - h0 = -0.5 * g * r1 * r1 + v0z * r1 + z0 - hf = -0.5 * g * r2 * r2 + v0z * r2 + z0 - - if h0 < R: - return np.inf - - if hf > 7.0 / 5.0 * R: - return np.inf - - return (r1 + r2) / 2.0 - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_table_collision_time( - rvw: NDArray[np.float64], - s: int, - g: float, - R: float, -) -> float: - """Time until an airborne ball's bottom touches the table plane. - - Returns ``np.inf`` if the ball is not airborne (no ball-table collision can - occur for any other motion state). - """ - if s != const.airborne: - return np.inf - return get_airborne_time(rvw=rvw, R=R, g=g) diff --git a/pooltool/physics/utils.py b/pooltool/physics/utils.py index 1208f698..6ffdcd5f 100644 --- a/pooltool/physics/utils.py +++ b/pooltool/physics/utils.py @@ -37,15 +37,14 @@ def rel_velocity(rvw: NDArray[np.float64], R: float) -> NDArray[np.float64]: @jit(nopython=True, cache=const.use_numba_cache) def get_u_vec( - rvw: NDArray[np.float64], phi: float, R: float, s: int + rvw: NDArray[np.float64], R: float, phi: float, s: int ) -> NDArray[np.float64]: if s == const.rolling: - return np.array([1.0, 0.0, 0.0]) + return np.array([1, 0, 0], dtype=np.float64) rel_vel = rel_velocity(rvw, R) - - if (rel_vel == 0.0).all(): - return np.array([1.0, 0.0, 0.0]) + if (rel_vel == 0).all(): + return np.array([1, 0, 0], dtype=np.float64) return coordinate_rotation(unit_vector(rel_vel), -phi) diff --git a/sandbox/airborne_demos.py b/sandbox/airborne_demos.py index a98e10d6..70a8147e 100644 --- a/sandbox/airborne_demos.py +++ b/sandbox/airborne_demos.py @@ -97,7 +97,7 @@ def airborne_pocket_collision() -> System: Exercises both Strategy 1 (landing directly inside the pocket) and Strategy 2 (xy trajectory crossing the pocket cylinder mid-fall) in - :func:`pooltool.physics.motion.solve.ball_pocket_collision_time_airborne`. + :func:`pooltool.evolution.event_based.detect.ball_pocket.ball_pocket_collision_time_airborne`. """ ball = Ball.create("cue") scale = 0.18 diff --git a/tests/evolution/event_based/test_ball_pocket.py b/tests/evolution/event_based/test_ball_pocket.py index b7bb33dd..9720a521 100644 --- a/tests/evolution/event_based/test_ball_pocket.py +++ b/tests/evolution/event_based/test_ball_pocket.py @@ -6,6 +6,7 @@ from pooltool.events import EventType, filter_type from pooltool.evolution.event_based.cache import CollisionCache from pooltool.evolution.event_based.detect.ball_pocket import ( + ball_pocket_collision_time_airborne, get_next_ball_pocket_3d_event, ) from pooltool.evolution.event_based.simulate import simulate @@ -13,7 +14,6 @@ from pooltool.objects.cue.datatypes import Cue from pooltool.objects.table.datatypes import Table from pooltool.objects.table.specs import TableType -from pooltool.physics.motion.solve import ball_pocket_collision_time_airborne from pooltool.physics.utils import get_airborne_time from pooltool.system.datatypes import System diff --git a/tests/evolution/event_based/test_simulate.py b/tests/evolution/event_based/test_simulate.py index da9291f3..b1bd6e1c 100644 --- a/tests/evolution/event_based/test_simulate.py +++ b/tests/evolution/event_based/test_simulate.py @@ -13,11 +13,11 @@ EventDetector, get_next_ball_ball_2d_event, ) +from pooltool.evolution.event_based.detect.ball_ball import ball_ball_collision_time from pooltool.evolution.event_based.simulate import simulate from pooltool.objects import Ball, BilliardTableSpecs, Cue, Table from pooltool.objects.ball.params import BallParams from pooltool.objects.ball.sets import BallSet -from pooltool.physics.motion.solve import ball_ball_collision_time from pooltool.ptmath.roots import quadratic from pooltool.system import System from tests.evolution.event_based.test_data import TEST_DIR From 88d9b2c6a507a6ed1170a94167f900fcee49a7cb Mon Sep 17 00:00:00 2001 From: Evan Kiefl Date: Tue, 19 May 2026 22:23:07 -0700 Subject: [PATCH 2/7] Surface get_real_positive_smallest_root in roots subpkg --- pooltool/evolution/event_based/detect/ball_cushion.py | 3 +-- pooltool/evolution/event_based/detect/ball_pocket.py | 3 +-- pooltool/ptmath/roots/__init__.py | 6 ++++++ 3 files changed, 8 insertions(+), 4 deletions(-) diff --git a/pooltool/evolution/event_based/detect/ball_cushion.py b/pooltool/evolution/event_based/detect/ball_cushion.py index 9ce1c50b..08ef6cc5 100644 --- a/pooltool/evolution/event_based/detect/ball_cushion.py +++ b/pooltool/evolution/event_based/detect/ball_cushion.py @@ -16,8 +16,7 @@ ) from pooltool.evolution.event_based.cache import CollisionCache from pooltool.physics.utils import get_u_vec -from pooltool.ptmath.roots import quartic -from pooltool.ptmath.roots.core import get_real_positive_smallest_root +from pooltool.ptmath.roots import get_real_positive_smallest_root, quartic from pooltool.system.datatypes import System diff --git a/pooltool/evolution/event_based/detect/ball_pocket.py b/pooltool/evolution/event_based/detect/ball_pocket.py index 276f689c..9cc134ca 100644 --- a/pooltool/evolution/event_based/detect/ball_pocket.py +++ b/pooltool/evolution/event_based/detect/ball_pocket.py @@ -9,8 +9,7 @@ from pooltool.events import Event, EventType, ball_pocket_collision, null_event from pooltool.evolution.event_based.cache import CollisionCache from pooltool.physics.utils import get_airborne_time, get_u_vec -from pooltool.ptmath.roots import quadratic, quartic -from pooltool.ptmath.roots.core import get_real_positive_smallest_root +from pooltool.ptmath.roots import get_real_positive_smallest_root, quadratic, quartic from pooltool.system.datatypes import System diff --git a/pooltool/ptmath/roots/__init__.py b/pooltool/ptmath/roots/__init__.py index 3b07653d..7e7f65c5 100644 --- a/pooltool/ptmath/roots/__init__.py +++ b/pooltool/ptmath/roots/__init__.py @@ -1,7 +1,13 @@ import pooltool.ptmath.roots.quadratic as quadratic import pooltool.ptmath.roots.quartic as quartic +from pooltool.ptmath.roots.core import ( + get_real_positive_smallest_root, + get_real_positive_smallest_roots, +) __all__ = [ "quadratic", "quartic", + "get_real_positive_smallest_root", + "get_real_positive_smallest_roots", ] From fac7c35121a283e5e2a464668393e70cee04472b Mon Sep 17 00:00:00 2001 From: Evan Kiefl Date: Tue, 19 May 2026 22:43:17 -0700 Subject: [PATCH 3/7] Clean up ball-pocket detection --- .../evolution/event_based/detect/__init__.py | 6 +- .../event_based/detect/ball_pocket.py | 118 +++--------------- .../evolution/event_based/detect/detector.py | 7 +- .../evolution/event_based/test_ball_pocket.py | 12 +- 4 files changed, 27 insertions(+), 116 deletions(-) diff --git a/pooltool/evolution/event_based/detect/__init__.py b/pooltool/evolution/event_based/detect/__init__.py index c645826a..b582406b 100644 --- a/pooltool/evolution/event_based/detect/__init__.py +++ b/pooltool/evolution/event_based/detect/__init__.py @@ -9,8 +9,7 @@ get_next_ball_linear_cushion_3d_event, ) from pooltool.evolution.event_based.detect.ball_pocket import ( - get_next_ball_pocket_2d_event, - get_next_ball_pocket_3d_event, + get_next_ball_pocket_event, ) from pooltool.evolution.event_based.detect.ball_table import ( get_next_ball_table_event, @@ -28,8 +27,7 @@ "get_next_ball_circular_cushion_3d_event", "get_next_ball_linear_cushion_2d_event", "get_next_ball_linear_cushion_3d_event", - "get_next_ball_pocket_2d_event", - "get_next_ball_pocket_3d_event", + "get_next_ball_pocket_event", "get_next_ball_table_event", "get_next_stick_ball_event", ] diff --git a/pooltool/evolution/event_based/detect/ball_pocket.py b/pooltool/evolution/event_based/detect/ball_pocket.py index 9cc134ca..cf70444a 100644 --- a/pooltool/evolution/event_based/detect/ball_pocket.py +++ b/pooltool/evolution/event_based/detect/ball_pocket.py @@ -14,7 +14,7 @@ @jit(nopython=True, cache=const.use_numba_cache) -def ball_pocket_collision_coeffs( +def ball_pocket_collision_time( rvw: NDArray[np.float64], s: int, a: float, @@ -24,14 +24,18 @@ def ball_pocket_collision_coeffs( m: float, g: float, R: float, -) -> tuple[float, float, float, float, float]: - """Get quartic coeffs required to determine the ball-pocket collision time +) -> float: + """Get the time until collision between a ball and a pocket. - (just-in-time compiled) + Contains branching logic depending on whether the ball is airborne. """ if s == const.spinning or s == const.pocketed or s == const.stationary: - return np.inf, np.inf, np.inf, np.inf, np.inf + return np.inf + + if s == const.airborne: + # Special treatment if ball is airborne + return ball_pocket_collision_time_if_airborne(rvw, a, b, r, g, R) phi = ptmath.angle(rvw[1]) v = ptmath.norm3d(rvw[1]) @@ -53,41 +57,11 @@ def ball_pocket_collision_coeffs( D = bx * (cx - a) + by * (cy - b) E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) - return A, B, C, D, E + return get_real_positive_smallest_root(quartic.solve(A, B, C, D, E)) @jit(nopython=True, cache=const.use_numba_cache) -def ball_pocket_collision_time( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> float: - """Get the time until collision between a ball and a pocket.""" - return get_real_positive_smallest_root( - quartic.solve( - *ball_pocket_collision_coeffs( - rvw, - s, - a, - b, - r, - mu, - m, - g, - R, - ) - ) - ) - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_pocket_collision_time_airborne( +def ball_pocket_collision_time_if_airborne( rvw: NDArray[np.float64], a: float, b: float, @@ -180,10 +154,11 @@ def ball_pocket_collision_time_airborne( return (r1 + r2) / 2.0 -def get_next_ball_pocket_2d_event( - shot: System, collision_cache: CollisionCache +def get_next_ball_pocket_event( + shot: System, + collision_cache: CollisionCache, ) -> Event: - """Detect the next ball-pocket collision in 2D mode.""" + """Detect the next ball-pocket collision.""" if not shot.table.has_pockets: return null_event(np.inf) @@ -199,7 +174,7 @@ def get_next_ball_pocket_2d_event( if obj_ids in cache: continue - if ball.state.s in const.nontranslating: + if state.s in const.nontranslating: cache[obj_ids] = np.inf continue @@ -223,64 +198,3 @@ def get_next_ball_pocket_2d_event( pocket=shot.table.pockets[pocket_id], time=cache[(ball_id, pocket_id)], ) - - -def get_next_ball_pocket_3d_event( - shot: System, collision_cache: CollisionCache -) -> Event: - """Detect the next ball-pocket collision in 3D mode. - - Airborne balls use :func:`ball_pocket_collision_time_airborne`, which models the - pocket as a vertical cylinder and accounts for the parabolic z-trajectory. - Non-airborne, translating balls delegate to the same 2D detection routine as - :func:`get_next_ball_pocket_2d_event`. - """ - if not shot.table.has_pockets: - return null_event(np.inf) - - cache = collision_cache.times.setdefault(EventType.BALL_POCKET, {}) - - for ball in shot.balls.values(): - state = ball.state - params = ball.params - - for pocket in shot.table.pockets.values(): - obj_ids = (ball.id, pocket.id) - - if obj_ids in cache: - continue - - if state.s in const.nontranslating: - cache[obj_ids] = np.inf - continue - - if state.s == const.airborne: - dtau_E = ball_pocket_collision_time_airborne( - rvw=state.rvw, - a=pocket.a, - b=pocket.b, - r=pocket.radius, - g=params.g, - R=params.R, - ) - else: - dtau_E = ball_pocket_collision_time( - rvw=state.rvw, - s=state.s, - a=pocket.a, - b=pocket.b, - r=pocket.radius, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) - cache[obj_ids] = shot.t + dtau_E - - ball_id, pocket_id = min(cache, key=lambda k: cache[k]) - - return ball_pocket_collision( - ball=shot.balls[ball_id], - pocket=shot.table.pockets[pocket_id], - time=cache[(ball_id, pocket_id)], - ) diff --git a/pooltool/evolution/event_based/detect/detector.py b/pooltool/evolution/event_based/detect/detector.py index faf4072e..e051be20 100644 --- a/pooltool/evolution/event_based/detect/detector.py +++ b/pooltool/evolution/event_based/detect/detector.py @@ -17,8 +17,7 @@ get_next_ball_linear_cushion_3d_event, ) from pooltool.evolution.event_based.detect.ball_pocket import ( - get_next_ball_pocket_2d_event, - get_next_ball_pocket_3d_event, + get_next_ball_pocket_event, ) from pooltool.evolution.event_based.detect.ball_table import ( get_next_ball_table_event, @@ -161,7 +160,7 @@ def get_next_event( candidates.append( get_next_ball_linear_cushion_3d_event(shot, collision_cache) ) - candidates.append(get_next_ball_pocket_3d_event(shot, collision_cache)) + candidates.append(get_next_ball_pocket_event(shot, collision_cache)) candidates.append(get_next_ball_table_event(shot, collision_cache)) else: candidates.append(get_next_ball_ball_2d_event(shot, collision_cache)) @@ -171,7 +170,7 @@ def get_next_event( candidates.append( get_next_ball_linear_cushion_2d_event(shot, collision_cache) ) - candidates.append(get_next_ball_pocket_2d_event(shot, collision_cache)) + candidates.append(get_next_ball_pocket_event(shot, collision_cache)) min_time = min(event.time for event in candidates) diff --git a/tests/evolution/event_based/test_ball_pocket.py b/tests/evolution/event_based/test_ball_pocket.py index 9720a521..7dd5ee53 100644 --- a/tests/evolution/event_based/test_ball_pocket.py +++ b/tests/evolution/event_based/test_ball_pocket.py @@ -6,8 +6,8 @@ from pooltool.events import EventType, filter_type from pooltool.evolution.event_based.cache import CollisionCache from pooltool.evolution.event_based.detect.ball_pocket import ( - ball_pocket_collision_time_airborne, - get_next_ball_pocket_3d_event, + ball_pocket_collision_time_if_airborne, + get_next_ball_pocket_event, ) from pooltool.evolution.event_based.simulate import simulate from pooltool.objects.ball.datatypes import Ball @@ -35,7 +35,7 @@ def test_vertical_drop_into_pocket_center(): a, b, r = 0.5, 0.5, 0.05 rvw = _airborne_rvw(0.5, 0.5, R_DEFAULT + 0.1) - t = ball_pocket_collision_time_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) + t = ball_pocket_collision_time_if_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) airborne_time = get_airborne_time(rvw, R_DEFAULT, G_DEFAULT) assert t == pytest.approx(airborne_time - const.EPS) @@ -51,7 +51,7 @@ def test_fast_low_traverse_returns_cylinder_midpoint(): a, b, r = 0.0, 0.0, 0.05 rvw = _airborne_rvw(-1.0, 0.0, 1.2 * R_DEFAULT, vx=100.0) - t = ball_pocket_collision_time_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) + t = ball_pocket_collision_time_if_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) # Trajectory crosses cylinder at t=0.0095 (influx) and t=0.0105 (outflux). assert t == pytest.approx(0.01, abs=1e-6) @@ -62,7 +62,7 @@ def test_fly_over_returns_inf(): a, b, r = 0.0, 0.0, 0.05 rvw = _airborne_rvw(-1.0, 0.0, 0.5, vx=1.0) - t = ball_pocket_collision_time_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) + t = ball_pocket_collision_time_if_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) assert t == np.inf @@ -96,5 +96,5 @@ def test_detector_skips_when_no_pockets(): shot = System(cue=Cue(cue_ball_id="cue"), table=table, balls=(ball,)) - event = get_next_ball_pocket_3d_event(shot, CollisionCache()) + event = get_next_ball_pocket_event(shot, CollisionCache()) assert event.time == np.inf From 83fa00eb5c46870b3c0a55e36f659ab42a482c30 Mon Sep 17 00:00:00 2001 From: Evan Kiefl Date: Tue, 19 May 2026 23:15:11 -0700 Subject: [PATCH 4/7] Simplify detection 2D/3D splittig logic --- .../evolution/event_based/detect/__init__.py | 12 +- .../evolution/event_based/detect/ball_ball.py | 46 +------ .../event_based/detect/ball_cushion.py | 126 +++++++----------- .../evolution/event_based/detect/detector.py | 23 +--- 4 files changed, 60 insertions(+), 147 deletions(-) diff --git a/pooltool/evolution/event_based/detect/__init__.py b/pooltool/evolution/event_based/detect/__init__.py index b582406b..cf0232ab 100644 --- a/pooltool/evolution/event_based/detect/__init__.py +++ b/pooltool/evolution/event_based/detect/__init__.py @@ -3,10 +3,8 @@ get_next_ball_ball_3d_event, ) from pooltool.evolution.event_based.detect.ball_cushion import ( - get_next_ball_circular_cushion_2d_event, - get_next_ball_circular_cushion_3d_event, - get_next_ball_linear_cushion_2d_event, - get_next_ball_linear_cushion_3d_event, + get_next_ball_circular_cushion_event, + get_next_ball_linear_cushion_event, ) from pooltool.evolution.event_based.detect.ball_pocket import ( get_next_ball_pocket_event, @@ -23,10 +21,8 @@ "EventDetector", "get_next_ball_ball_2d_event", "get_next_ball_ball_3d_event", - "get_next_ball_circular_cushion_2d_event", - "get_next_ball_circular_cushion_3d_event", - "get_next_ball_linear_cushion_2d_event", - "get_next_ball_linear_cushion_3d_event", + "get_next_ball_circular_cushion_event", + "get_next_ball_linear_cushion_event", "get_next_ball_pocket_event", "get_next_ball_table_event", "get_next_stick_ball_event", diff --git a/pooltool/evolution/event_based/detect/ball_ball.py b/pooltool/evolution/event_based/detect/ball_ball.py index 1509360b..99902120 100644 --- a/pooltool/evolution/event_based/detect/ball_ball.py +++ b/pooltool/evolution/event_based/detect/ball_ball.py @@ -17,7 +17,7 @@ @jit(nopython=True, cache=const.use_numba_cache) -def ball_ball_collision_coeffs( +def ball_ball_collision_time( rvw1: NDArray[np.float64], rvw2: NDArray[np.float64], s1: int, @@ -29,12 +29,8 @@ def ball_ball_collision_coeffs( g1: float, g2: float, R: float, -) -> tuple[float, float, float, float, float]: - """Get quartic coeffs required to determine the ball-ball collision time - - (just-in-time compiled) - """ - +) -> float: + """Get the time until collision between 2 balls.""" c1x, c1y = rvw1[0, 0], rvw1[0, 1] c2x, c2y = rvw2[0, 0], rvw2[0, 1] @@ -82,41 +78,7 @@ def ball_ball_collision_coeffs( d = 2 * Bx * Cx + 2 * By * Cy e = Cx * Cx + Cy * Cy - 4 * R * R - return a, b, c, d, e - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_ball_collision_time( - rvw1: NDArray[np.float64], - rvw2: NDArray[np.float64], - s1: int, - s2: int, - mu1: float, - mu2: float, - m1: float, - m2: float, - g1: float, - g2: float, - R: float, -) -> float: - """Get the time until collision between 2 balls.""" - return get_real_positive_smallest_root( - quartic.solve( - *ball_ball_collision_coeffs( - rvw1, - rvw2, - s1, - s2, - mu1, - mu2, - m1, - m2, - g1, - g2, - R, - ) - ) - ) + return get_real_positive_smallest_root(quartic.solve(a, b, c, d, e)) def get_next_ball_ball_2d_event(shot: System, collision_cache: CollisionCache) -> Event: diff --git a/pooltool/evolution/event_based/detect/ball_cushion.py b/pooltool/evolution/event_based/detect/ball_cushion.py index 08ef6cc5..ae0980bc 100644 --- a/pooltool/evolution/event_based/detect/ball_cushion.py +++ b/pooltool/evolution/event_based/detect/ball_cushion.py @@ -21,7 +21,7 @@ @jit(nopython=True, cache=const.use_numba_cache) -def ball_linear_cushion_collision_time( +def ball_vertical_plane_collision_time( rvw: NDArray[np.float64], s: int, lx: float, @@ -35,9 +35,10 @@ def ball_linear_cushion_collision_time( g: float, R: float, ) -> float: - """Get time until collision between ball and linear cushion segment + """Get time until collision between a ball and a vertical plane. - (just-in-time compiled) + For ball trajectories limited to the playing surface, this suffices for + detecting ball collisions with linear cushion segments. """ if s == const.spinning or s == const.pocketed or s == const.stationary: return np.inf @@ -98,7 +99,7 @@ def ball_linear_cushion_collision_time( @jit(nopython=True, cache=const.use_numba_cache) -def ball_circular_cushion_collision_coeffs( +def ball_vertical_cyclinder_collision_time( rvw: NDArray[np.float64], s: int, a: float, @@ -108,14 +109,15 @@ def ball_circular_cushion_collision_coeffs( m: float, g: float, R: float, -) -> tuple[float, float, float, float, float]: - """Get quartic coeffs required to determine the ball-circular-cushion collision time +) -> float: + """Get the time until collision between a ball and a vertical cylinder. - (just-in-time compiled) + For ball trajectories limited to the playing surface, this suffices for + detecting ball collisions with circular cushion segments. """ if s == const.spinning or s == const.pocketed or s == const.stationary: - return np.inf, np.inf, np.inf, np.inf, np.inf + return np.inf phi = ptmath.angle(rvw[1]) v = ptmath.norm3d(rvw[1]) @@ -139,40 +141,10 @@ def ball_circular_cushion_collision_coeffs( cx * a + cy * b ) - return A, B, C, D, E + return get_real_positive_smallest_root(quartic.solve(A, B, C, D, E)) -@jit(nopython=True, cache=const.use_numba_cache) -def ball_circular_cushion_collision_time( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> float: - """Get the time until collision between a ball and a circular cushion segment.""" - return get_real_positive_smallest_root( - quartic.solve( - *ball_circular_cushion_collision_coeffs( - rvw, - s, - a, - b, - r, - mu, - m, - g, - R, - ) - ) - ) - - -def get_next_ball_linear_cushion_2d_event( +def get_next_ball_linear_cushion_event( shot: System, collision_cache: CollisionCache ) -> Event: """Detect the next ball-vs-linear-cushion collision in 2D mode.""" @@ -195,20 +167,24 @@ def get_next_ball_linear_cushion_2d_event( cache[obj_ids] = np.inf continue - dtau_E = ball_linear_cushion_collision_time( - rvw=state.rvw, - s=state.s, - lx=cushion.lx, - ly=cushion.ly, - l0=cushion.l0, - p1=cushion.p1, - p2=cushion.p2, - direction=cushion.direction, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) + if ball.state.s == const.airborne: + # TODO + dtau_E = np.inf + else: + dtau_E = ball_vertical_plane_collision_time( + rvw=state.rvw, + s=state.s, + lx=cushion.lx, + ly=cushion.ly, + l0=cushion.l0, + p1=cushion.p1, + p2=cushion.p2, + direction=cushion.direction, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) cache[obj_ids] = shot.t + dtau_E @@ -221,14 +197,7 @@ def get_next_ball_linear_cushion_2d_event( ) -def get_next_ball_linear_cushion_3d_event( - shot: System, collision_cache: CollisionCache -) -> Event: - """3D ball-linear-cushion detection — not vendored yet; emits no event.""" - return null_event(np.inf) - - -def get_next_ball_circular_cushion_2d_event( +def get_next_ball_circular_cushion_event( shot: System, collision_cache: CollisionCache ) -> Event: """Detect the next ball-vs-circular-cushion collision in 2D mode.""" @@ -251,17 +220,22 @@ def get_next_ball_circular_cushion_2d_event( cache[obj_ids] = np.inf continue - dtau_E = ball_circular_cushion_collision_time( - rvw=state.rvw, - s=state.s, - a=cushion.a, - b=cushion.b, - r=cushion.radius, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) + if ball.state.s == const.airborne: + # TODO + dtau_E = np.inf + else: + dtau_E = ball_vertical_cyclinder_collision_time( + rvw=state.rvw, + s=state.s, + a=cushion.a, + b=cushion.b, + r=cushion.radius, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) + cache[obj_ids] = shot.t + dtau_E ball_id, cushion_id = min(cache, key=lambda k: cache[k]) @@ -271,9 +245,3 @@ def get_next_ball_circular_cushion_2d_event( cushion=shot.table.cushion_segments.circular[cushion_id], time=cache[(ball_id, cushion_id)], ) - - -def get_next_ball_circular_cushion_3d_event( - shot: System, collision_cache: CollisionCache -) -> Event: - return null_event(np.inf) diff --git a/pooltool/evolution/event_based/detect/detector.py b/pooltool/evolution/event_based/detect/detector.py index e051be20..27cd68d1 100644 --- a/pooltool/evolution/event_based/detect/detector.py +++ b/pooltool/evolution/event_based/detect/detector.py @@ -11,10 +11,8 @@ get_next_ball_ball_3d_event, ) from pooltool.evolution.event_based.detect.ball_cushion import ( - get_next_ball_circular_cushion_2d_event, - get_next_ball_circular_cushion_3d_event, - get_next_ball_linear_cushion_2d_event, - get_next_ball_linear_cushion_3d_event, + get_next_ball_circular_cushion_event, + get_next_ball_linear_cushion_event, ) from pooltool.evolution.event_based.detect.ball_pocket import ( get_next_ball_pocket_event, @@ -151,26 +149,15 @@ def get_next_event( candidates.append(get_next_stick_ball_event(shot, collision_cache)) candidates.append(transition_cache.get_next()) + candidates.append(get_next_ball_linear_cushion_event(shot, collision_cache)) + candidates.append(get_next_ball_circular_cushion_event(shot, collision_cache)) + candidates.append(get_next_ball_pocket_event(shot, collision_cache)) if self.is_3d: candidates.append(get_next_ball_ball_3d_event(shot, collision_cache)) - candidates.append( - get_next_ball_circular_cushion_3d_event(shot, collision_cache) - ) - candidates.append( - get_next_ball_linear_cushion_3d_event(shot, collision_cache) - ) - candidates.append(get_next_ball_pocket_event(shot, collision_cache)) candidates.append(get_next_ball_table_event(shot, collision_cache)) else: candidates.append(get_next_ball_ball_2d_event(shot, collision_cache)) - candidates.append( - get_next_ball_circular_cushion_2d_event(shot, collision_cache) - ) - candidates.append( - get_next_ball_linear_cushion_2d_event(shot, collision_cache) - ) - candidates.append(get_next_ball_pocket_event(shot, collision_cache)) min_time = min(event.time for event in candidates) From c12f2414bd0a5d7f4fbad9ae13db0b38dd86fd99 Mon Sep 17 00:00:00 2001 From: Evan Kiefl Date: Tue, 19 May 2026 23:23:35 -0700 Subject: [PATCH 5/7] Typo and additional docstring clarification --- pooltool/evolution/event_based/detect/ball_cushion.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/pooltool/evolution/event_based/detect/ball_cushion.py b/pooltool/evolution/event_based/detect/ball_cushion.py index ae0980bc..9f673745 100644 --- a/pooltool/evolution/event_based/detect/ball_cushion.py +++ b/pooltool/evolution/event_based/detect/ball_cushion.py @@ -39,6 +39,9 @@ def ball_vertical_plane_collision_time( For ball trajectories limited to the playing surface, this suffices for detecting ball collisions with linear cushion segments. + + Note: + - This is broken for airborne balls. """ if s == const.spinning or s == const.pocketed or s == const.stationary: return np.inf @@ -99,7 +102,7 @@ def ball_vertical_plane_collision_time( @jit(nopython=True, cache=const.use_numba_cache) -def ball_vertical_cyclinder_collision_time( +def ball_vertical_cylinder_collision_time( rvw: NDArray[np.float64], s: int, a: float, @@ -114,6 +117,9 @@ def ball_vertical_cyclinder_collision_time( For ball trajectories limited to the playing surface, this suffices for detecting ball collisions with circular cushion segments. + + Note: + - This is broken for airborne balls. """ if s == const.spinning or s == const.pocketed or s == const.stationary: @@ -224,7 +230,7 @@ def get_next_ball_circular_cushion_event( # TODO dtau_E = np.inf else: - dtau_E = ball_vertical_cyclinder_collision_time( + dtau_E = ball_vertical_cylinder_collision_time( rvw=state.rvw, s=state.s, a=cushion.a, From 9a505952ae5a1217e7010e924fc2b23df04cce7b Mon Sep 17 00:00:00 2001 From: Evan Kiefl Date: Tue, 19 May 2026 23:32:16 -0700 Subject: [PATCH 6/7] fix badge --- README.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/README.md b/README.md index 9edb78f2..06305e68 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ ![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/ekiefl/pooltool/test.yml) ![PyPI - Version](https://img.shields.io/pypi/v/pooltool-billiards) -![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pooltool-billiards) +[![Python Version](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2Fekiefl%2Fpooltool%2Fmain%2Fpyproject.toml)](https://pypi.org/project/pooltool-billiards/) [![codecov](https://codecov.io/gh/ekiefl/pooltool/graph/badge.svg?flag=service-no-ani)](https://codecov.io/gh/ekiefl/pooltool) [![Discord](https://img.shields.io/badge/Discord-Join%20Server-7289da?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/8Y8qUgzZhz) From 7b2985415505a3b9f1b4415ea93d15d9bc0992ff Mon Sep 17 00:00:00 2001 From: Evan Kiefl Date: Tue, 19 May 2026 23:36:58 -0700 Subject: [PATCH 7/7] Fix stale docstring xref to renamed airborne helper --- sandbox/airborne_demos.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sandbox/airborne_demos.py b/sandbox/airborne_demos.py index 70a8147e..152a29c1 100644 --- a/sandbox/airborne_demos.py +++ b/sandbox/airborne_demos.py @@ -97,7 +97,7 @@ def airborne_pocket_collision() -> System: Exercises both Strategy 1 (landing directly inside the pocket) and Strategy 2 (xy trajectory crossing the pocket cylinder mid-fall) in - :func:`pooltool.evolution.event_based.detect.ball_pocket.ball_pocket_collision_time_airborne`. + :func:`pooltool.evolution.event_based.detect.ball_pocket.ball_pocket_collision_time_if_airborne`. """ ball = Ball.create("cue") scale = 0.18