diff --git a/README.md b/README.md index 9edb78f2..06305e68 100644 --- a/README.md +++ b/README.md @@ -6,7 +6,7 @@ ![GitHub Actions Workflow Status](https://img.shields.io/github/actions/workflow/status/ekiefl/pooltool/test.yml) ![PyPI - Version](https://img.shields.io/pypi/v/pooltool-billiards) -![PyPI - Python Version](https://img.shields.io/pypi/pyversions/pooltool-billiards) +[![Python Version](https://img.shields.io/python/required-version-toml?tomlFilePath=https%3A%2F%2Fraw.githubusercontent.com%2Fekiefl%2Fpooltool%2Fmain%2Fpyproject.toml)](https://pypi.org/project/pooltool-billiards/) [![codecov](https://codecov.io/gh/ekiefl/pooltool/graph/badge.svg?flag=service-no-ani)](https://codecov.io/gh/ekiefl/pooltool) [![Discord](https://img.shields.io/badge/Discord-Join%20Server-7289da?style=for-the-badge&logo=discord&logoColor=white)](https://discord.gg/8Y8qUgzZhz) diff --git a/docs/include_exclude.py b/docs/include_exclude.py index b6f3f514..98ca4012 100644 --- a/docs/include_exclude.py +++ b/docs/include_exclude.py @@ -29,7 +29,6 @@ "pooltool.ruleset.snooker", # API: pooltool.physics "pooltool.physics.evolve", - "pooltool.physics.motion", "pooltool.physics.resolve", ], "module": [ diff --git a/pooltool/evolution/event_based/detect/__init__.py b/pooltool/evolution/event_based/detect/__init__.py index c645826a..cf0232ab 100644 --- a/pooltool/evolution/event_based/detect/__init__.py +++ b/pooltool/evolution/event_based/detect/__init__.py @@ -3,14 +3,11 @@ get_next_ball_ball_3d_event, ) from pooltool.evolution.event_based.detect.ball_cushion import ( - get_next_ball_circular_cushion_2d_event, - get_next_ball_circular_cushion_3d_event, - get_next_ball_linear_cushion_2d_event, - get_next_ball_linear_cushion_3d_event, + get_next_ball_circular_cushion_event, + get_next_ball_linear_cushion_event, ) from pooltool.evolution.event_based.detect.ball_pocket import ( - get_next_ball_pocket_2d_event, - get_next_ball_pocket_3d_event, + get_next_ball_pocket_event, ) from pooltool.evolution.event_based.detect.ball_table import ( get_next_ball_table_event, @@ -24,12 +21,9 @@ "EventDetector", "get_next_ball_ball_2d_event", "get_next_ball_ball_3d_event", - "get_next_ball_circular_cushion_2d_event", - "get_next_ball_circular_cushion_3d_event", - "get_next_ball_linear_cushion_2d_event", - "get_next_ball_linear_cushion_3d_event", - "get_next_ball_pocket_2d_event", - "get_next_ball_pocket_3d_event", + "get_next_ball_circular_cushion_event", + "get_next_ball_linear_cushion_event", + "get_next_ball_pocket_event", "get_next_ball_table_event", "get_next_stick_ball_event", ] diff --git a/pooltool/evolution/event_based/detect/ball_ball.py b/pooltool/evolution/event_based/detect/ball_ball.py index 1796f7ea..99902120 100644 --- a/pooltool/evolution/event_based/detect/ball_ball.py +++ b/pooltool/evolution/event_based/detect/ball_ball.py @@ -3,15 +3,84 @@ from itertools import combinations import numpy as np +from numba import jit +from numpy.typing import NDArray import pooltool.constants as const import pooltool.ptmath as ptmath from pooltool.events import Event, EventType, ball_ball_collision, null_event from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.motion.solve import ball_ball_collision_time +from pooltool.physics.utils import get_u_vec +from pooltool.ptmath.roots import quartic +from pooltool.ptmath.roots.core import get_real_positive_smallest_root from pooltool.system.datatypes import System +@jit(nopython=True, cache=const.use_numba_cache) +def ball_ball_collision_time( + rvw1: NDArray[np.float64], + rvw2: NDArray[np.float64], + s1: int, + s2: int, + mu1: float, + mu2: float, + m1: float, + m2: float, + g1: float, + g2: float, + R: float, +) -> float: + """Get the time until collision between 2 balls.""" + c1x, c1y = rvw1[0, 0], rvw1[0, 1] + c2x, c2y = rvw2[0, 0], rvw2[0, 1] + + if s1 == const.spinning or s1 == const.pocketed or s1 == const.stationary: + a1x, a1y, b1x, b1y = 0, 0, 0, 0 + else: + phi1 = ptmath.angle(rvw1[1]) + v1 = ptmath.norm3d(rvw1[1]) + + u1 = get_u_vec(rvw1, R, phi1, s1) + + K1 = -0.5 * mu1 * g1 + cos_phi1 = np.cos(phi1) + sin_phi1 = np.sin(phi1) + + a1x = K1 * (u1[0] * cos_phi1 - u1[1] * sin_phi1) + a1y = K1 * (u1[0] * sin_phi1 + u1[1] * cos_phi1) + b1x = v1 * cos_phi1 + b1y = v1 * sin_phi1 + + if s2 == const.spinning or s2 == const.pocketed or s2 == const.stationary: + a2x, a2y, b2x, b2y = 0.0, 0.0, 0.0, 0.0 + else: + phi2 = ptmath.angle(rvw2[1]) + v2 = ptmath.norm3d(rvw2[1]) + + u2 = get_u_vec(rvw2, R, phi2, s2) + + K2 = -0.5 * mu2 * g2 + cos_phi2 = np.cos(phi2) + sin_phi2 = np.sin(phi2) + + a2x = K2 * (u2[0] * cos_phi2 - u2[1] * sin_phi2) + a2y = K2 * (u2[0] * sin_phi2 + u2[1] * cos_phi2) + b2x = v2 * cos_phi2 + b2y = v2 * sin_phi2 + + Ax, Ay = a2x - a1x, a2y - a1y + Bx, By = b2x - b1x, b2y - b1y + Cx, Cy = c2x - c1x, c2y - c1y + + a = Ax * Ax + Ay * Ay + b = 2 * Ax * Bx + 2 * Ay * By + c = Bx * Bx + 2 * Ax * Cx + 2 * Ay * Cy + By * By + d = 2 * Bx * Cx + 2 * By * Cy + e = Cx * Cx + Cy * Cy - 4 * R * R + + return get_real_positive_smallest_root(quartic.solve(a, b, c, d, e)) + + def get_next_ball_ball_2d_event(shot: System, collision_cache: CollisionCache) -> Event: """Detect the next ball-ball collision in 2D mode.""" cache = collision_cache.times.setdefault(EventType.BALL_BALL, {}) diff --git a/pooltool/evolution/event_based/detect/ball_cushion.py b/pooltool/evolution/event_based/detect/ball_cushion.py index 6eaf494e..9f673745 100644 --- a/pooltool/evolution/event_based/detect/ball_cushion.py +++ b/pooltool/evolution/event_based/detect/ball_cushion.py @@ -1,8 +1,12 @@ from __future__ import annotations import numpy as np +from numba import jit +from numpy.typing import NDArray import pooltool.constants as const +import pooltool.physics.evolve as evolve +import pooltool.ptmath as ptmath from pooltool.events import ( Event, EventType, @@ -11,14 +15,142 @@ null_event, ) from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.motion.solve import ( - ball_circular_cushion_collision_time, - ball_linear_cushion_collision_time, -) +from pooltool.physics.utils import get_u_vec +from pooltool.ptmath.roots import get_real_positive_smallest_root, quartic from pooltool.system.datatypes import System -def get_next_ball_linear_cushion_2d_event( +@jit(nopython=True, cache=const.use_numba_cache) +def ball_vertical_plane_collision_time( + rvw: NDArray[np.float64], + s: int, + lx: float, + ly: float, + l0: float, + p1: NDArray[np.float64], + p2: NDArray[np.float64], + direction: int, + mu: float, + m: float, + g: float, + R: float, +) -> float: + """Get time until collision between a ball and a vertical plane. + + For ball trajectories limited to the playing surface, this suffices for + detecting ball collisions with linear cushion segments. + + Note: + - This is broken for airborne balls. + """ + if s == const.spinning or s == const.pocketed or s == const.stationary: + return np.inf + + phi = ptmath.angle(rvw[1]) + v = ptmath.norm3d(rvw[1]) + + u = get_u_vec(rvw, R, phi, s) + + K = -0.5 * mu * g + cos_phi = np.cos(phi) + sin_phi = np.sin(phi) + + ax = K * (u[0] * cos_phi - u[1] * sin_phi) + ay = K * (u[0] * sin_phi + u[1] * cos_phi) + bx, by = v * cos_phi, v * sin_phi + cx, cy = rvw[0, 0], rvw[0, 1] + + A = lx * ax + ly * ay + B = lx * bx + ly * by + + if direction == 0: + C = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) + root1, root2 = ptmath.roots.quadratic.solve(A, B, C) + roots = [root1, root2] + elif direction == 1: + C = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) + root1, root2 = ptmath.roots.quadratic.solve(A, B, C) + roots = [root1, root2] + else: + C1 = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) + C2 = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) + root1, root2 = ptmath.roots.quadratic.solve(A, B, C1) + root3, root4 = ptmath.roots.quadratic.solve(A, B, C2) + roots = [root1, root2, root3, root4] + + min_time = np.inf + for root in roots: + if np.isnan(root): + continue + + if np.abs(root.imag) > const.EPS: + continue + + if root.real <= const.EPS: + continue + + rvw_dtau, _ = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root.real) + s_score = -np.dot(p1 - rvw_dtau[0], p2 - p1) / np.dot(p2 - p1, p2 - p1) + + if not (0 <= s_score <= 1): + continue + + if root.real < min_time: + min_time = root.real + + return min_time + + +@jit(nopython=True, cache=const.use_numba_cache) +def ball_vertical_cylinder_collision_time( + rvw: NDArray[np.float64], + s: int, + a: float, + b: float, + r: float, + mu: float, + m: float, + g: float, + R: float, +) -> float: + """Get the time until collision between a ball and a vertical cylinder. + + For ball trajectories limited to the playing surface, this suffices for + detecting ball collisions with circular cushion segments. + + Note: + - This is broken for airborne balls. + """ + + if s == const.spinning or s == const.pocketed or s == const.stationary: + return np.inf + + phi = ptmath.angle(rvw[1]) + v = ptmath.norm3d(rvw[1]) + + u = get_u_vec(rvw, R, phi, s) + + K = -0.5 * mu * g + cos_phi = np.cos(phi) + sin_phi = np.sin(phi) + + ax = K * (u[0] * cos_phi - u[1] * sin_phi) + ay = K * (u[0] * sin_phi + u[1] * cos_phi) + bx, by = v * cos_phi, v * sin_phi + cx, cy = rvw[0, 0], rvw[0, 1] + + A = 0.5 * (ax * ax + ay * ay) + B = ax * bx + ay * by + C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) + D = bx * (cx - a) + by * (cy - b) + E = 0.5 * (a * a + b * b + cx * cx + cy * cy - (r + R) * (r + R)) - ( + cx * a + cy * b + ) + + return get_real_positive_smallest_root(quartic.solve(A, B, C, D, E)) + + +def get_next_ball_linear_cushion_event( shot: System, collision_cache: CollisionCache ) -> Event: """Detect the next ball-vs-linear-cushion collision in 2D mode.""" @@ -41,20 +173,24 @@ def get_next_ball_linear_cushion_2d_event( cache[obj_ids] = np.inf continue - dtau_E = ball_linear_cushion_collision_time( - rvw=state.rvw, - s=state.s, - lx=cushion.lx, - ly=cushion.ly, - l0=cushion.l0, - p1=cushion.p1, - p2=cushion.p2, - direction=cushion.direction, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) + if ball.state.s == const.airborne: + # TODO + dtau_E = np.inf + else: + dtau_E = ball_vertical_plane_collision_time( + rvw=state.rvw, + s=state.s, + lx=cushion.lx, + ly=cushion.ly, + l0=cushion.l0, + p1=cushion.p1, + p2=cushion.p2, + direction=cushion.direction, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) cache[obj_ids] = shot.t + dtau_E @@ -67,14 +203,7 @@ def get_next_ball_linear_cushion_2d_event( ) -def get_next_ball_linear_cushion_3d_event( - shot: System, collision_cache: CollisionCache -) -> Event: - """3D ball-linear-cushion detection — not vendored yet; emits no event.""" - return null_event(np.inf) - - -def get_next_ball_circular_cushion_2d_event( +def get_next_ball_circular_cushion_event( shot: System, collision_cache: CollisionCache ) -> Event: """Detect the next ball-vs-circular-cushion collision in 2D mode.""" @@ -97,17 +226,22 @@ def get_next_ball_circular_cushion_2d_event( cache[obj_ids] = np.inf continue - dtau_E = ball_circular_cushion_collision_time( - rvw=state.rvw, - s=state.s, - a=cushion.a, - b=cushion.b, - r=cushion.radius, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) + if ball.state.s == const.airborne: + # TODO + dtau_E = np.inf + else: + dtau_E = ball_vertical_cylinder_collision_time( + rvw=state.rvw, + s=state.s, + a=cushion.a, + b=cushion.b, + r=cushion.radius, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) + cache[obj_ids] = shot.t + dtau_E ball_id, cushion_id = min(cache, key=lambda k: cache[k]) @@ -117,9 +251,3 @@ def get_next_ball_circular_cushion_2d_event( cushion=shot.table.cushion_segments.circular[cushion_id], time=cache[(ball_id, cushion_id)], ) - - -def get_next_ball_circular_cushion_3d_event( - shot: System, collision_cache: CollisionCache -) -> Event: - return null_event(np.inf) diff --git a/pooltool/evolution/event_based/detect/ball_pocket.py b/pooltool/evolution/event_based/detect/ball_pocket.py index 198f095b..cf70444a 100644 --- a/pooltool/evolution/event_based/detect/ball_pocket.py +++ b/pooltool/evolution/event_based/detect/ball_pocket.py @@ -1,72 +1,164 @@ from __future__ import annotations import numpy as np +from numba import jit +from numpy.typing import NDArray import pooltool.constants as const +import pooltool.ptmath as ptmath from pooltool.events import Event, EventType, ball_pocket_collision, null_event from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.motion.solve import ( - ball_pocket_collision_time, - ball_pocket_collision_time_airborne, -) +from pooltool.physics.utils import get_airborne_time, get_u_vec +from pooltool.ptmath.roots import get_real_positive_smallest_root, quadratic, quartic from pooltool.system.datatypes import System -def get_next_ball_pocket_2d_event( - shot: System, collision_cache: CollisionCache -) -> Event: - """Detect the next ball-pocket collision in 2D mode.""" - if not shot.table.has_pockets: - return null_event(np.inf) +@jit(nopython=True, cache=const.use_numba_cache) +def ball_pocket_collision_time( + rvw: NDArray[np.float64], + s: int, + a: float, + b: float, + r: float, + mu: float, + m: float, + g: float, + R: float, +) -> float: + """Get the time until collision between a ball and a pocket. + + Contains branching logic depending on whether the ball is airborne. + """ - cache = collision_cache.times.setdefault(EventType.BALL_POCKET, {}) + if s == const.spinning or s == const.pocketed or s == const.stationary: + return np.inf + + if s == const.airborne: + # Special treatment if ball is airborne + return ball_pocket_collision_time_if_airborne(rvw, a, b, r, g, R) + + phi = ptmath.angle(rvw[1]) + v = ptmath.norm3d(rvw[1]) + + u = get_u_vec(rvw, R, phi, s) + + K = -0.5 * mu * g + cos_phi = np.cos(phi) + sin_phi = np.sin(phi) + + ax = K * (u[0] * cos_phi - u[1] * sin_phi) + ay = K * (u[0] * sin_phi + u[1] * cos_phi) + bx, by = v * cos_phi, v * sin_phi + cx, cy = rvw[0, 0], rvw[0, 1] + + A = 0.5 * (ax * ax + ay * ay) + B = ax * bx + ay * by + C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) + D = bx * (cx - a) + by * (cy - b) + E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) + + return get_real_positive_smallest_root(quartic.solve(A, B, C, D, E)) + + +@jit(nopython=True, cache=const.use_numba_cache) +def ball_pocket_collision_time_if_airborne( + rvw: NDArray[np.float64], + a: float, + b: float, + r: float, + g: float, + R: float, +) -> float: + """Determine the ball-pocket collision time for an airborne ball. + + The behavior is somewhat complicated. Here is the procedure. + + Strategy 1: The xy-coordinates of where the ball lands are calculated. If that falls + within the pocket circle, a collision is returned. The collision time is chosen to + be just less than the collision time for the table collision, to guarantee temporal + precedence over the table collision. + + Strategy 2: Otherwise, the influx and outflux collision times are calculated between + the ball center and a vertical cylinder that extends from the pocket's circle. + Influx collision refers to the collision with the outside of the cylinder's wall. + The outflux collision refers to the collision with the inside of the cylinder's wall + and occurs later in time. Since there is no deceleration in the xy-plane for an + airborne ball, an outflux collision is expected, meaning we expect 2 finite roots. + (This is only violated if the ball starts inside the cylinder, which results in at + most an outflux collision). The strategy is to see what the ball height is at the + time of the influx collision (``h0``) and the outflux collision (``hf``), because + from these we can determine whether or not the ball is considered to enter the + pocket. The following logic is used: + + - ``h0 < R``: The ball passes through the playing surface plane before + intersecting the pocket cylinder, guaranteeing that a ball-table collision + occurs. Infinity is returned. + - ``hf <= (7/5)*R``: If the outflux height is less than ``(7/5)*R``, the ball + is considered to be pocketed. This threshold height implicitly models the + fact that high velocity balls that are slightly airborne collide with table + geometry at the back of the pocket, ricocheting the ball into the pocket. + The average of the influx and outflux collision times is returned. + - ``hf > (7/5)*R``: The ball is considered to fly over the pocket. Infinity is + returned. + """ + phi = ptmath.angle(rvw[1]) + v = ptmath.norm2d(rvw[1]) + cos_phi = np.cos(phi) + sin_phi = np.sin(phi) + bx, by = v * cos_phi, v * sin_phi - for ball in shot.balls.values(): - state = ball.state - params = ball.params + # Strategy 1: does the ball land inside the pocket circle? + airborne_time = get_airborne_time(rvw, R, g) + x = rvw[0, 0] + bx * airborne_time + y = rvw[0, 1] + by * airborne_time - for pocket in shot.table.pockets.values(): - obj_ids = (ball.id, pocket.id) + if (x - a) ** 2 + (y - b) ** 2 < r * r: + return float(airborne_time - const.EPS) - if obj_ids in cache: - continue + # Strategy 2: does the ball's xy trajectory cross the pocket cylinder? + cx, cy = rvw[0, 0], rvw[0, 1] - if ball.state.s in const.nontranslating: - cache[obj_ids] = np.inf - continue + # These match the non-airborne quartic coefficients, after setting ax=ay=0. + C = 0.5 * (bx * bx + by * by) + D = bx * (cx - a) + by * (cy - b) + E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) - dtau_E = ball_pocket_collision_time( - rvw=state.rvw, - s=state.s, - a=pocket.a, - b=pocket.b, - r=pocket.radius, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) - cache[obj_ids] = shot.t + dtau_E + roots = quadratic.solve(C, D, E) - ball_id, pocket_id = min(cache, key=lambda k: cache[k]) + atol = 1e-9 + if abs(roots[0].imag) > atol or abs(roots[1].imag) > atol: + return np.inf - return ball_pocket_collision( - ball=shot.balls[ball_id], - pocket=shot.table.pockets[pocket_id], - time=cache[(ball_id, pocket_id)], - ) + real_0 = roots[0].real + real_1 = roots[1].real + if real_0 <= real_1: + r1, r2 = real_0, real_1 + else: + r1, r2 = real_1, real_0 + if r1 < 0.0: + return np.inf -def get_next_ball_pocket_3d_event( - shot: System, collision_cache: CollisionCache -) -> Event: - """Detect the next ball-pocket collision in 3D mode. + v0z = rvw[1, 2] + z0 = rvw[0, 2] - Airborne balls use :func:`ball_pocket_collision_time_airborne`, which models the - pocket as a vertical cylinder and accounts for the parabolic z-trajectory. - Non-airborne, translating balls delegate to the same 2D detection routine as - :func:`get_next_ball_pocket_2d_event`. - """ + h0 = -0.5 * g * r1 * r1 + v0z * r1 + z0 + hf = -0.5 * g * r2 * r2 + v0z * r2 + z0 + + if h0 < R: + return np.inf + + if hf > 7.0 / 5.0 * R: + return np.inf + + return (r1 + r2) / 2.0 + + +def get_next_ball_pocket_event( + shot: System, + collision_cache: CollisionCache, +) -> Event: + """Detect the next ball-pocket collision.""" if not shot.table.has_pockets: return null_event(np.inf) @@ -86,27 +178,17 @@ def get_next_ball_pocket_3d_event( cache[obj_ids] = np.inf continue - if state.s == const.airborne: - dtau_E = ball_pocket_collision_time_airborne( - rvw=state.rvw, - a=pocket.a, - b=pocket.b, - r=pocket.radius, - g=params.g, - R=params.R, - ) - else: - dtau_E = ball_pocket_collision_time( - rvw=state.rvw, - s=state.s, - a=pocket.a, - b=pocket.b, - r=pocket.radius, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) + dtau_E = ball_pocket_collision_time( + rvw=state.rvw, + s=state.s, + a=pocket.a, + b=pocket.b, + r=pocket.radius, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) cache[obj_ids] = shot.t + dtau_E ball_id, pocket_id = min(cache, key=lambda k: cache[k]) diff --git a/pooltool/evolution/event_based/detect/ball_table.py b/pooltool/evolution/event_based/detect/ball_table.py index df74fb54..2ca964c1 100644 --- a/pooltool/evolution/event_based/detect/ball_table.py +++ b/pooltool/evolution/event_based/detect/ball_table.py @@ -1,11 +1,33 @@ from __future__ import annotations +import numpy as np +from numba import jit +from numpy.typing import NDArray + +import pooltool.constants as const from pooltool.events import Event, EventType, ball_table_collision from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.motion.solve import ball_table_collision_time +from pooltool.physics.utils import get_airborne_time from pooltool.system.datatypes import System +@jit(nopython=True, cache=const.use_numba_cache) +def ball_table_collision_time( + rvw: NDArray[np.float64], + s: int, + g: float, + R: float, +) -> float: + """Time until an airborne ball's bottom touches the table plane. + + Returns ``np.inf`` if the ball is not airborne (no ball-table collision can + occur for any other motion state). + """ + if s != const.airborne: + return np.inf + return get_airborne_time(rvw=rvw, R=R, g=g) + + def get_next_ball_table_event(shot: System, collision_cache: CollisionCache) -> Event: """Detect the next ball-table collision (airborne ball landing). diff --git a/pooltool/evolution/event_based/detect/detector.py b/pooltool/evolution/event_based/detect/detector.py index faf4072e..27cd68d1 100644 --- a/pooltool/evolution/event_based/detect/detector.py +++ b/pooltool/evolution/event_based/detect/detector.py @@ -11,14 +11,11 @@ get_next_ball_ball_3d_event, ) from pooltool.evolution.event_based.detect.ball_cushion import ( - get_next_ball_circular_cushion_2d_event, - get_next_ball_circular_cushion_3d_event, - get_next_ball_linear_cushion_2d_event, - get_next_ball_linear_cushion_3d_event, + get_next_ball_circular_cushion_event, + get_next_ball_linear_cushion_event, ) from pooltool.evolution.event_based.detect.ball_pocket import ( - get_next_ball_pocket_2d_event, - get_next_ball_pocket_3d_event, + get_next_ball_pocket_event, ) from pooltool.evolution.event_based.detect.ball_table import ( get_next_ball_table_event, @@ -152,26 +149,15 @@ def get_next_event( candidates.append(get_next_stick_ball_event(shot, collision_cache)) candidates.append(transition_cache.get_next()) + candidates.append(get_next_ball_linear_cushion_event(shot, collision_cache)) + candidates.append(get_next_ball_circular_cushion_event(shot, collision_cache)) + candidates.append(get_next_ball_pocket_event(shot, collision_cache)) if self.is_3d: candidates.append(get_next_ball_ball_3d_event(shot, collision_cache)) - candidates.append( - get_next_ball_circular_cushion_3d_event(shot, collision_cache) - ) - candidates.append( - get_next_ball_linear_cushion_3d_event(shot, collision_cache) - ) - candidates.append(get_next_ball_pocket_3d_event(shot, collision_cache)) candidates.append(get_next_ball_table_event(shot, collision_cache)) else: candidates.append(get_next_ball_ball_2d_event(shot, collision_cache)) - candidates.append( - get_next_ball_circular_cushion_2d_event(shot, collision_cache) - ) - candidates.append( - get_next_ball_linear_cushion_2d_event(shot, collision_cache) - ) - candidates.append(get_next_ball_pocket_2d_event(shot, collision_cache)) min_time = min(event.time for event in candidates) diff --git a/pooltool/physics/motion/__init__.py b/pooltool/physics/motion/__init__.py deleted file mode 100644 index e69de29b..00000000 diff --git a/pooltool/physics/motion/solve.py b/pooltool/physics/motion/solve.py deleted file mode 100644 index 53905724..00000000 --- a/pooltool/physics/motion/solve.py +++ /dev/null @@ -1,534 +0,0 @@ -from math import acos - -import numpy as np -from numba import jit -from numpy.typing import NDArray - -import pooltool.constants as const -import pooltool.physics.evolve as evolve -import pooltool.ptmath as ptmath -from pooltool.physics.utils import get_airborne_time, rel_velocity -from pooltool.ptmath.roots import quadratic, quartic -from pooltool.ptmath.roots.core import get_real_positive_smallest_root - - -@jit(nopython=True, cache=const.use_numba_cache) -def skip_ball_ball_collision( - rvw1: NDArray[np.float64], - rvw2: NDArray[np.float64], - s1: int, - s2: int, - R1: float, - R2: float, -) -> bool: - if (s1 == const.spinning or s1 == const.pocketed or s1 == const.stationary) and ( - s2 == const.spinning or s2 == const.pocketed or s2 == const.stationary - ): - # Neither balls are moving. No collision. - return True - - if s1 == const.pocketed or s2 == const.pocketed: - # One of the balls is pocketed - return True - - if s1 == const.rolling and s2 == const.rolling: - # Both balls are rolling (straight line trajectories). Here I am checking - # whether both dot products face away from the line connecting the two balls. If - # so, they are guaranteed not to collide - r12 = rvw2[0] - rvw1[0] - dot1 = r12[0] * rvw1[1, 0] + r12[1] * rvw1[1, 1] + r12[2] * rvw1[1, 2] - if dot1 <= 0: - dot2 = r12[0] * rvw2[1, 0] + r12[1] * rvw2[1, 1] + r12[2] * rvw2[1, 2] - if dot2 >= 0: - return True - - if s1 == const.rolling and (s2 == const.spinning or s2 == const.stationary): - # ball1 is rolling, which guarantees a straight-line trajectory. Some - # assumptions can be made based on this fact - r12 = rvw2[0] - rvw1[0] - - # ball2 is not moving, so we can pinpoint the range of angles ball1 must be - # headed in for a collision - d = ptmath.norm3d(r12) - unit_d = r12 / d - unit_v = ptmath.unit_vector(rvw1[1]) - - # Angles are in radians - # Calculate forwards and backwards angles, e.g. 10 and 350, take the min - angle = np.arccos(np.dot(unit_d, unit_v)) - max_hit_angle = 0.5 * np.pi - acos((R1 + R2) / d) - if angle > max_hit_angle: - return True - - if s2 == const.rolling and (s1 == const.spinning or s1 == const.stationary): - # ball2 is rolling, which guarantees a straight-line trajectory. Some - # assumptions can be made based on this fact - r21 = rvw1[0] - rvw2[0] - - # ball1 is not moving, so we can pinpoint the range of angles ball2 must be - # headed in for a collision - d = ptmath.norm3d(r21) - unit_d = r21 / d - unit_v = ptmath.unit_vector(rvw2[1]) - - # Angles are in radians - # Calculate forwards and backwards angles, e.g. 10 and 350, take the min - angle = np.arccos(np.dot(unit_d, unit_v)) - max_hit_angle = 0.5 * np.pi - acos((R1 + R2) / d) - if angle > max_hit_angle: - return True - - return False - - -@jit(nopython=True, cache=const.use_numba_cache) -def get_u( - rvw: NDArray[np.float64], R: float, phi: float, s: int -) -> NDArray[np.float64]: - if s == const.rolling: - return np.array([1, 0, 0], dtype=np.float64) - - rel_vel = rel_velocity(rvw, R) - if (rel_vel == 0).all(): - return np.array([1, 0, 0], dtype=np.float64) - - return ptmath.coordinate_rotation(ptmath.unit_vector(rel_vel), -phi) - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_ball_collision_coeffs( - rvw1: NDArray[np.float64], - rvw2: NDArray[np.float64], - s1: int, - s2: int, - mu1: float, - mu2: float, - m1: float, - m2: float, - g1: float, - g2: float, - R: float, -) -> tuple[float, float, float, float, float]: - """Get quartic coeffs required to determine the ball-ball collision time - - (just-in-time compiled) - """ - - c1x, c1y = rvw1[0, 0], rvw1[0, 1] - c2x, c2y = rvw2[0, 0], rvw2[0, 1] - - if s1 == const.spinning or s1 == const.pocketed or s1 == const.stationary: - a1x, a1y, b1x, b1y = 0, 0, 0, 0 - else: - phi1 = ptmath.angle(rvw1[1]) - v1 = ptmath.norm3d(rvw1[1]) - - u1 = get_u(rvw1, R, phi1, s1) - - K1 = -0.5 * mu1 * g1 - cos_phi1 = np.cos(phi1) - sin_phi1 = np.sin(phi1) - - a1x = K1 * (u1[0] * cos_phi1 - u1[1] * sin_phi1) - a1y = K1 * (u1[0] * sin_phi1 + u1[1] * cos_phi1) - b1x = v1 * cos_phi1 - b1y = v1 * sin_phi1 - - if s2 == const.spinning or s2 == const.pocketed or s2 == const.stationary: - a2x, a2y, b2x, b2y = 0.0, 0.0, 0.0, 0.0 - else: - phi2 = ptmath.angle(rvw2[1]) - v2 = ptmath.norm3d(rvw2[1]) - - u2 = get_u(rvw2, R, phi2, s2) - - K2 = -0.5 * mu2 * g2 - cos_phi2 = np.cos(phi2) - sin_phi2 = np.sin(phi2) - - a2x = K2 * (u2[0] * cos_phi2 - u2[1] * sin_phi2) - a2y = K2 * (u2[0] * sin_phi2 + u2[1] * cos_phi2) - b2x = v2 * cos_phi2 - b2y = v2 * sin_phi2 - - Ax, Ay = a2x - a1x, a2y - a1y - Bx, By = b2x - b1x, b2y - b1y - Cx, Cy = c2x - c1x, c2y - c1y - - a = Ax * Ax + Ay * Ay - b = 2 * Ax * Bx + 2 * Ay * By - c = Bx * Bx + 2 * Ax * Cx + 2 * Ay * Cy + By * By - d = 2 * Bx * Cx + 2 * By * Cy - e = Cx * Cx + Cy * Cy - 4 * R * R - - return a, b, c, d, e - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_ball_collision_time( - rvw1: NDArray[np.float64], - rvw2: NDArray[np.float64], - s1: int, - s2: int, - mu1: float, - mu2: float, - m1: float, - m2: float, - g1: float, - g2: float, - R: float, -) -> float: - """Get the time until collision between 2 balls.""" - return get_real_positive_smallest_root( - quartic.solve( - *ball_ball_collision_coeffs( - rvw1, - rvw2, - s1, - s2, - mu1, - mu2, - m1, - m2, - g1, - g2, - R, - ) - ) - ) - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_linear_cushion_collision_time( - rvw: NDArray[np.float64], - s: int, - lx: float, - ly: float, - l0: float, - p1: NDArray[np.float64], - p2: NDArray[np.float64], - direction: int, - mu: float, - m: float, - g: float, - R: float, -) -> float: - """Get time until collision between ball and linear cushion segment - - (just-in-time compiled) - """ - if s == const.spinning or s == const.pocketed or s == const.stationary: - return np.inf - - phi = ptmath.angle(rvw[1]) - v = ptmath.norm3d(rvw[1]) - - u = get_u(rvw, R, phi, s) - - K = -0.5 * mu * g - cos_phi = np.cos(phi) - sin_phi = np.sin(phi) - - ax = K * (u[0] * cos_phi - u[1] * sin_phi) - ay = K * (u[0] * sin_phi + u[1] * cos_phi) - bx, by = v * cos_phi, v * sin_phi - cx, cy = rvw[0, 0], rvw[0, 1] - - A = lx * ax + ly * ay - B = lx * bx + ly * by - - if direction == 0: - C = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) - root1, root2 = ptmath.roots.quadratic.solve(A, B, C) - roots = [root1, root2] - elif direction == 1: - C = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) - root1, root2 = ptmath.roots.quadratic.solve(A, B, C) - roots = [root1, root2] - else: - C1 = l0 + lx * cx + ly * cy + R * np.sqrt(lx * lx + ly * ly) - C2 = l0 + lx * cx + ly * cy - R * np.sqrt(lx * lx + ly * ly) - root1, root2 = ptmath.roots.quadratic.solve(A, B, C1) - root3, root4 = ptmath.roots.quadratic.solve(A, B, C2) - roots = [root1, root2, root3, root4] - - min_time = np.inf - for root in roots: - if np.isnan(root): - continue - - if np.abs(root.imag) > const.EPS: - continue - - if root.real <= const.EPS: - continue - - rvw_dtau, _ = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root.real) - s_score = -np.dot(p1 - rvw_dtau[0], p2 - p1) / np.dot(p2 - p1, p2 - p1) - - if not (0 <= s_score <= 1): - continue - - if root.real < min_time: - min_time = root.real - - return min_time - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_circular_cushion_collision_coeffs( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> tuple[float, float, float, float, float]: - """Get quartic coeffs required to determine the ball-circular-cushion collision time - - (just-in-time compiled) - """ - - if s == const.spinning or s == const.pocketed or s == const.stationary: - return np.inf, np.inf, np.inf, np.inf, np.inf - - phi = ptmath.angle(rvw[1]) - v = ptmath.norm3d(rvw[1]) - - u = get_u(rvw, R, phi, s) - - K = -0.5 * mu * g - cos_phi = np.cos(phi) - sin_phi = np.sin(phi) - - ax = K * (u[0] * cos_phi - u[1] * sin_phi) - ay = K * (u[0] * sin_phi + u[1] * cos_phi) - bx, by = v * cos_phi, v * sin_phi - cx, cy = rvw[0, 0], rvw[0, 1] - - A = 0.5 * (ax * ax + ay * ay) - B = ax * bx + ay * by - C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) - D = bx * (cx - a) + by * (cy - b) - E = 0.5 * (a * a + b * b + cx * cx + cy * cy - (r + R) * (r + R)) - ( - cx * a + cy * b - ) - - return A, B, C, D, E - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_circular_cushion_collision_time( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> float: - """Get the time until collision between a ball and a circular cushion segment.""" - return get_real_positive_smallest_root( - quartic.solve( - *ball_circular_cushion_collision_coeffs( - rvw, - s, - a, - b, - r, - mu, - m, - g, - R, - ) - ) - ) - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_pocket_collision_coeffs( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> tuple[float, float, float, float, float]: - """Get quartic coeffs required to determine the ball-pocket collision time - - (just-in-time compiled) - """ - - if s == const.spinning or s == const.pocketed or s == const.stationary: - return np.inf, np.inf, np.inf, np.inf, np.inf - - phi = ptmath.angle(rvw[1]) - v = ptmath.norm3d(rvw[1]) - - u = get_u(rvw, R, phi, s) - - K = -0.5 * mu * g - cos_phi = np.cos(phi) - sin_phi = np.sin(phi) - - ax = K * (u[0] * cos_phi - u[1] * sin_phi) - ay = K * (u[0] * sin_phi + u[1] * cos_phi) - bx, by = v * cos_phi, v * sin_phi - cx, cy = rvw[0, 0], rvw[0, 1] - - A = 0.5 * (ax * ax + ay * ay) - B = ax * bx + ay * by - C = ax * (cx - a) + ay * (cy - b) + 0.5 * (bx * bx + by * by) - D = bx * (cx - a) + by * (cy - b) - E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) - - return A, B, C, D, E - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_pocket_collision_time( - rvw: NDArray[np.float64], - s: int, - a: float, - b: float, - r: float, - mu: float, - m: float, - g: float, - R: float, -) -> float: - """Get the time until collision between a ball and a pocket.""" - return get_real_positive_smallest_root( - quartic.solve( - *ball_pocket_collision_coeffs( - rvw, - s, - a, - b, - r, - mu, - m, - g, - R, - ) - ) - ) - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_pocket_collision_time_airborne( - rvw: NDArray[np.float64], - a: float, - b: float, - r: float, - g: float, - R: float, -) -> float: - """Determine the ball-pocket collision time for an airborne ball. - - The behavior is somewhat complicated. Here is the procedure. - - Strategy 1: The xy-coordinates of where the ball lands are calculated. If that falls - within the pocket circle, a collision is returned. The collision time is chosen to - be just less than the collision time for the table collision, to guarantee temporal - precedence over the table collision. - - Strategy 2: Otherwise, the influx and outflux collision times are calculated between - the ball center and a vertical cylinder that extends from the pocket's circle. - Influx collision refers to the collision with the outside of the cylinder's wall. - The outflux collision refers to the collision with the inside of the cylinder's wall - and occurs later in time. Since there is no deceleration in the xy-plane for an - airborne ball, an outflux collision is expected, meaning we expect 2 finite roots. - (This is only violated if the ball starts inside the cylinder, which results in at - most an outflux collision). The strategy is to see what the ball height is at the - time of the influx collision (``h0``) and the outflux collision (``hf``), because - from these we can determine whether or not the ball is considered to enter the - pocket. The following logic is used: - - - ``h0 < R``: The ball passes through the playing surface plane before - intersecting the pocket cylinder, guaranteeing that a ball-table collision - occurs. Infinity is returned. - - ``hf <= (7/5)*R``: If the outflux height is less than ``(7/5)*R``, the ball - is considered to be pocketed. This threshold height implicitly models the - fact that high velocity balls that are slightly airborne collide with table - geometry at the back of the pocket, ricocheting the ball into the pocket. - The average of the influx and outflux collision times is returned. - - ``hf > (7/5)*R``: The ball is considered to fly over the pocket. Infinity is - returned. - """ - phi = ptmath.angle(rvw[1]) - v = ptmath.norm2d(rvw[1]) - cos_phi = np.cos(phi) - sin_phi = np.sin(phi) - bx, by = v * cos_phi, v * sin_phi - - # Strategy 1: does the ball land inside the pocket circle? - airborne_time = get_airborne_time(rvw, R, g) - x = rvw[0, 0] + bx * airborne_time - y = rvw[0, 1] + by * airborne_time - - if (x - a) ** 2 + (y - b) ** 2 < r * r: - return float(airborne_time - const.EPS) - - # Strategy 2: does the ball's xy trajectory cross the pocket cylinder? - cx, cy = rvw[0, 0], rvw[0, 1] - - # These match the non-airborne quartic coefficients, after setting ax=ay=0. - C = 0.5 * (bx * bx + by * by) - D = bx * (cx - a) + by * (cy - b) - E = 0.5 * (a * a + b * b + cx * cx + cy * cy - r * r) - (cx * a + cy * b) - - roots = quadratic.solve(C, D, E) - - atol = 1e-9 - if abs(roots[0].imag) > atol or abs(roots[1].imag) > atol: - return np.inf - - real_0 = roots[0].real - real_1 = roots[1].real - if real_0 <= real_1: - r1, r2 = real_0, real_1 - else: - r1, r2 = real_1, real_0 - - if r1 < 0.0: - return np.inf - - v0z = rvw[1, 2] - z0 = rvw[0, 2] - - h0 = -0.5 * g * r1 * r1 + v0z * r1 + z0 - hf = -0.5 * g * r2 * r2 + v0z * r2 + z0 - - if h0 < R: - return np.inf - - if hf > 7.0 / 5.0 * R: - return np.inf - - return (r1 + r2) / 2.0 - - -@jit(nopython=True, cache=const.use_numba_cache) -def ball_table_collision_time( - rvw: NDArray[np.float64], - s: int, - g: float, - R: float, -) -> float: - """Time until an airborne ball's bottom touches the table plane. - - Returns ``np.inf`` if the ball is not airborne (no ball-table collision can - occur for any other motion state). - """ - if s != const.airborne: - return np.inf - return get_airborne_time(rvw=rvw, R=R, g=g) diff --git a/pooltool/physics/utils.py b/pooltool/physics/utils.py index 1208f698..6ffdcd5f 100644 --- a/pooltool/physics/utils.py +++ b/pooltool/physics/utils.py @@ -37,15 +37,14 @@ def rel_velocity(rvw: NDArray[np.float64], R: float) -> NDArray[np.float64]: @jit(nopython=True, cache=const.use_numba_cache) def get_u_vec( - rvw: NDArray[np.float64], phi: float, R: float, s: int + rvw: NDArray[np.float64], R: float, phi: float, s: int ) -> NDArray[np.float64]: if s == const.rolling: - return np.array([1.0, 0.0, 0.0]) + return np.array([1, 0, 0], dtype=np.float64) rel_vel = rel_velocity(rvw, R) - - if (rel_vel == 0.0).all(): - return np.array([1.0, 0.0, 0.0]) + if (rel_vel == 0).all(): + return np.array([1, 0, 0], dtype=np.float64) return coordinate_rotation(unit_vector(rel_vel), -phi) diff --git a/pooltool/ptmath/roots/__init__.py b/pooltool/ptmath/roots/__init__.py index 3b07653d..7e7f65c5 100644 --- a/pooltool/ptmath/roots/__init__.py +++ b/pooltool/ptmath/roots/__init__.py @@ -1,7 +1,13 @@ import pooltool.ptmath.roots.quadratic as quadratic import pooltool.ptmath.roots.quartic as quartic +from pooltool.ptmath.roots.core import ( + get_real_positive_smallest_root, + get_real_positive_smallest_roots, +) __all__ = [ "quadratic", "quartic", + "get_real_positive_smallest_root", + "get_real_positive_smallest_roots", ] diff --git a/sandbox/airborne_demos.py b/sandbox/airborne_demos.py index a98e10d6..152a29c1 100644 --- a/sandbox/airborne_demos.py +++ b/sandbox/airborne_demos.py @@ -97,7 +97,7 @@ def airborne_pocket_collision() -> System: Exercises both Strategy 1 (landing directly inside the pocket) and Strategy 2 (xy trajectory crossing the pocket cylinder mid-fall) in - :func:`pooltool.physics.motion.solve.ball_pocket_collision_time_airborne`. + :func:`pooltool.evolution.event_based.detect.ball_pocket.ball_pocket_collision_time_if_airborne`. """ ball = Ball.create("cue") scale = 0.18 diff --git a/tests/evolution/event_based/test_ball_pocket.py b/tests/evolution/event_based/test_ball_pocket.py index b7bb33dd..7dd5ee53 100644 --- a/tests/evolution/event_based/test_ball_pocket.py +++ b/tests/evolution/event_based/test_ball_pocket.py @@ -6,14 +6,14 @@ from pooltool.events import EventType, filter_type from pooltool.evolution.event_based.cache import CollisionCache from pooltool.evolution.event_based.detect.ball_pocket import ( - get_next_ball_pocket_3d_event, + ball_pocket_collision_time_if_airborne, + get_next_ball_pocket_event, ) from pooltool.evolution.event_based.simulate import simulate from pooltool.objects.ball.datatypes import Ball from pooltool.objects.cue.datatypes import Cue from pooltool.objects.table.datatypes import Table from pooltool.objects.table.specs import TableType -from pooltool.physics.motion.solve import ball_pocket_collision_time_airborne from pooltool.physics.utils import get_airborne_time from pooltool.system.datatypes import System @@ -35,7 +35,7 @@ def test_vertical_drop_into_pocket_center(): a, b, r = 0.5, 0.5, 0.05 rvw = _airborne_rvw(0.5, 0.5, R_DEFAULT + 0.1) - t = ball_pocket_collision_time_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) + t = ball_pocket_collision_time_if_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) airborne_time = get_airborne_time(rvw, R_DEFAULT, G_DEFAULT) assert t == pytest.approx(airborne_time - const.EPS) @@ -51,7 +51,7 @@ def test_fast_low_traverse_returns_cylinder_midpoint(): a, b, r = 0.0, 0.0, 0.05 rvw = _airborne_rvw(-1.0, 0.0, 1.2 * R_DEFAULT, vx=100.0) - t = ball_pocket_collision_time_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) + t = ball_pocket_collision_time_if_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) # Trajectory crosses cylinder at t=0.0095 (influx) and t=0.0105 (outflux). assert t == pytest.approx(0.01, abs=1e-6) @@ -62,7 +62,7 @@ def test_fly_over_returns_inf(): a, b, r = 0.0, 0.0, 0.05 rvw = _airborne_rvw(-1.0, 0.0, 0.5, vx=1.0) - t = ball_pocket_collision_time_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) + t = ball_pocket_collision_time_if_airborne(rvw, a, b, r, G_DEFAULT, R_DEFAULT) assert t == np.inf @@ -96,5 +96,5 @@ def test_detector_skips_when_no_pockets(): shot = System(cue=Cue(cue_ball_id="cue"), table=table, balls=(ball,)) - event = get_next_ball_pocket_3d_event(shot, CollisionCache()) + event = get_next_ball_pocket_event(shot, CollisionCache()) assert event.time == np.inf diff --git a/tests/evolution/event_based/test_simulate.py b/tests/evolution/event_based/test_simulate.py index da9291f3..b1bd6e1c 100644 --- a/tests/evolution/event_based/test_simulate.py +++ b/tests/evolution/event_based/test_simulate.py @@ -13,11 +13,11 @@ EventDetector, get_next_ball_ball_2d_event, ) +from pooltool.evolution.event_based.detect.ball_ball import ball_ball_collision_time from pooltool.evolution.event_based.simulate import simulate from pooltool.objects import Ball, BilliardTableSpecs, Cue, Table from pooltool.objects.ball.params import BallParams from pooltool.objects.ball.sets import BallSet -from pooltool.physics.motion.solve import ball_ball_collision_time from pooltool.ptmath.roots import quadratic from pooltool.system import System from tests.evolution.event_based.test_data import TEST_DIR