From 019700e7e91678f52b105dd4164234a2ebe398cd Mon Sep 17 00:00:00 2001 From: Evan Kiefl Date: Sun, 17 May 2026 16:12:05 -0700 Subject: [PATCH 1/2] Add airborne --- pooltool/constants.py | 8 +- pooltool/events/__init__.py | 2 + pooltool/events/datatypes.py | 5 + pooltool/events/factory.py | 17 +++ pooltool/events/utils.py | 1 + pooltool/physics/__init__.py | 2 + pooltool/physics/evolve/__init__.py | 26 +++++ pooltool/physics/utils.py | 28 +++++ tests/physics/evolve/test_airborne.py | 145 ++++++++++++++++++++++++++ tests/physics/test_utils.py | 71 ++++++++++++- 10 files changed, 303 insertions(+), 2 deletions(-) create mode 100644 tests/physics/evolve/test_airborne.py diff --git a/pooltool/constants.py b/pooltool/constants.py index bbc3d852..c71d475c 100644 --- a/pooltool/constants.py +++ b/pooltool/constants.py @@ -45,6 +45,11 @@ A ball with this motion state is in a pocket. """ +airborne: int = 5 +"""The airborne motion state label + +A ball with this motion state is airborne. +""" state_dict: dict[int, str] = { 0: "stationary", @@ -52,8 +57,9 @@ 2: "sliding", 3: "rolling", 4: "pocketed", + 5: "airborne", } on_table = {stationary, spinning, sliding, rolling} nontranslating = {stationary, spinning, pocketed} -energetic = {spinning, sliding, rolling} +energetic = {spinning, sliding, rolling, airborne} diff --git a/pooltool/events/__init__.py b/pooltool/events/__init__.py index 2545f958..1405b0b4 100644 --- a/pooltool/events/__init__.py +++ b/pooltool/events/__init__.py @@ -10,6 +10,7 @@ ball_circular_cushion_collision, ball_linear_cushion_collision, ball_pocket_collision, + ball_table_collision, null_event, rolling_spinning_transition, rolling_stationary_transition, @@ -41,6 +42,7 @@ "ball_circular_cushion_collision", "ball_pocket_collision", "stick_ball_collision", + "ball_table_collision", "spinning_stationary_transition", "rolling_stationary_transition", "rolling_spinning_transition", diff --git a/pooltool/events/datatypes.py b/pooltool/events/datatypes.py index 5653e3bb..77396ef3 100644 --- a/pooltool/events/datatypes.py +++ b/pooltool/events/datatypes.py @@ -35,6 +35,8 @@ class EventType(strenum.StrEnum): the *point of no return*. STICK_BALL: A cue-stick ball collision. + BALL_TABLE: + A ball collision into the table surface. SPINNING_STATIONARY: A ball transition from spinning to stationary. ROLLING_STATIONARY: @@ -51,6 +53,7 @@ class EventType(strenum.StrEnum): BALL_CIRCULAR_CUSHION = strenum.auto() BALL_POCKET = strenum.auto() STICK_BALL = strenum.auto() + BALL_TABLE = strenum.auto() SPINNING_STATIONARY = strenum.auto() ROLLING_STATIONARY = strenum.auto() ROLLING_SPINNING = strenum.auto() @@ -64,6 +67,7 @@ def is_collision(self) -> bool: EventType.BALL_LINEAR_CUSHION, EventType.BALL_POCKET, EventType.STICK_BALL, + EventType.BALL_TABLE, } def is_transition(self) -> bool: @@ -85,6 +89,7 @@ def has_ball(self) -> bool: EventType.BALL_CIRCULAR_CUSHION, EventType.BALL_POCKET, EventType.STICK_BALL, + EventType.BALL_TABLE, } or self.is_transition() ) diff --git a/pooltool/events/factory.py b/pooltool/events/factory.py index cef8db48..907975c9 100644 --- a/pooltool/events/factory.py +++ b/pooltool/events/factory.py @@ -99,6 +99,23 @@ def stick_ball_collision( ) +def ball_table_collision(ball: Ball, time: float, set_initial: bool = False) -> Event: + """Create a ball-table collision. + + Note: + - Since information about the ball-table interaction is stored exclusively in + the ball (not the table), and since the table is a large composite of + individual objects which is costly to serialize, the table is not stored as an + agent of the returned event and is therefore not accepted as an argument of + this function. + """ + return Event( + event_type=EventType.BALL_TABLE, + agents=(Agent.from_object(ball, set_initial=set_initial),), + time=time, + ) + + def spinning_stationary_transition( ball: Ball, time: float, set_initial: bool = False ) -> Event: diff --git a/pooltool/events/utils.py b/pooltool/events/utils.py index 7c3b3296..791a6244 100644 --- a/pooltool/events/utils.py +++ b/pooltool/events/utils.py @@ -6,6 +6,7 @@ EventType.BALL_CIRCULAR_CUSHION: {0}, EventType.BALL_POCKET: {0}, EventType.STICK_BALL: {1}, + EventType.BALL_TABLE: {0}, EventType.SPINNING_STATIONARY: {0}, EventType.ROLLING_STATIONARY: {0}, EventType.ROLLING_SPINNING: {0}, diff --git a/pooltool/physics/__init__.py b/pooltool/physics/__init__.py index 1a11e02e..303b6cf3 100644 --- a/pooltool/physics/__init__.py +++ b/pooltool/physics/__init__.py @@ -35,6 +35,7 @@ ball_transition_models, ) from pooltool.physics.utils import ( + get_airborne_time, get_ball_energy, get_roll_time, get_slide_time, @@ -61,6 +62,7 @@ "get_slide_time", "get_roll_time", "get_spin_time", + "get_airborne_time", "get_ball_energy", "ball_ball_models", "BallBallFrictionModel", diff --git a/pooltool/physics/evolve/__init__.py b/pooltool/physics/evolve/__init__.py index 32e0ed3c..ec1dbac7 100644 --- a/pooltool/physics/evolve/__init__.py +++ b/pooltool/physics/evolve/__init__.py @@ -47,6 +47,9 @@ def evolve_ball_motion( if state == const.stationary or state == const.pocketed: return rvw, state + if state == const.airborne: + return _evolve_airborne_state(rvw, g, t), const.airborne + if state == const.sliding: dtau_E_slide = get_slide_time(rvw, R, u_s, g) @@ -183,3 +186,26 @@ def _evolve_perpendicular_spin_state( ) -> NDArray[np.float64]: rvw[2, 2] = _evolve_perpendicular_spin_component(rvw[2, 2], R, u_sp, g, t) return rvw + + +@jit(nopython=True, cache=const.use_numba_cache) +def _evolve_airborne_state( + rvw: NDArray[np.float64], g: float, t: float +) -> NDArray[np.float64]: + """Parabolic evolution under gravity. Angular velocity is conserved.""" + if t == 0: + return rvw + + r_0, v_0, w_0 = rvw + + g_vec = np.array([0.0, 0.0, g], dtype=np.float64) + + r = r_0 + v_0 * t - 0.5 * g_vec * t**2 + v = v_0 - g_vec * t + + new_rvw = np.empty((3, 3), dtype=np.float64) + new_rvw[0, :] = r + new_rvw[1, :] = v + new_rvw[2, :] = w_0 + + return new_rvw diff --git a/pooltool/physics/utils.py b/pooltool/physics/utils.py index 1c86dba8..97c38843 100644 --- a/pooltool/physics/utils.py +++ b/pooltool/physics/utils.py @@ -49,6 +49,34 @@ def get_u_vec( return coordinate_rotation(unit_vector(rel_vel), -phi) +@jit(nopython=True, cache=const.use_numba_cache) +def get_airborne_time(rvw: NDArray[np.float64], R: float, g: float) -> float: + """Time until an airborne ball's bottom touches the table plane (z = R). + + Returns ``np.inf`` if no future intersection exists (either gravity is zero, or the + discriminant of the quadratic ``-0.5 * g * t**2 + v_z * t + (z - R) = 0`` is + negative). + """ + if g == 0.0: + return np.inf + + A = -0.5 * g + B = rvw[1, 2] + C = rvw[0, 2] - R + + D = B**2 - 4 * A * C + + if D < 0: + # Only consider real roots. + return np.inf + + # This is the only possible root assuming the ball starts above the table and + # acceleration due to gravity is towards table. + t_f = -(B + np.sqrt(D)) / (2 * A) + + return t_f + + @jit(nopython=True, cache=const.use_numba_cache) def get_slide_time(rvw: NDArray[np.float64], R: float, u_s: float, g: float) -> float: if u_s == 0.0: diff --git a/tests/physics/evolve/test_airborne.py b/tests/physics/evolve/test_airborne.py new file mode 100644 index 00000000..50909884 --- /dev/null +++ b/tests/physics/evolve/test_airborne.py @@ -0,0 +1,145 @@ +import numpy as np + +from pooltool.physics.evolve import _evolve_airborne_state + + +def test_xy_velocity_conserved(): + """Test that the x- and y-components of the velocity remain unchanged as time evolves.""" + r0 = np.array([0.0, 0.0, 0.0], dtype=np.float64) + v0 = np.array([1.0, 2.0, 3.0], dtype=np.float64) + w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64) + rvw0 = np.array([r0, v0, w0], dtype=np.float64) + + g = 9.81 + t = 1.0 + + rvw = _evolve_airborne_state(rvw0.copy(), g, t) + + # Check if vx and vy are unchanged + np.testing.assert_almost_equal( + rvw[1, 0], v0[0], err_msg="X velocity changed unexpectedly." + ) + np.testing.assert_almost_equal( + rvw[1, 1], v0[1], err_msg="Y velocity changed unexpectedly." + ) + + +def test_angular_velocity_conserved(): + """Test that the angular velocity (w) is conserved and remains unchanged as time evolves.""" + r0 = np.array([0.0, 0.0, 0.0], dtype=np.float64) + v0 = np.array([1.0, 2.0, 3.0], dtype=np.float64) + w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64) + rvw0 = np.array([r0, v0, w0], dtype=np.float64) + + g = 9.81 + t = 2.0 + + rvw = _evolve_airborne_state(rvw0.copy(), g, t) + + # Check if angular velocity is unchanged + np.testing.assert_array_almost_equal( + rvw[2], w0, err_msg="Angular velocity changed unexpectedly." + ) + + +def test_xy_displacement_linear(): + """Test that the xy-displacement changes linearly with time. + + This assumes there is no air friction. + + Equations: + r_x(t) = r_x(0) + v_x(0)*t + r_y(t) = r_y(0) + v_y(0)*t + """ + r0 = np.array([0.0, 0.0, 10.0], dtype=np.float64) + v0 = np.array([3.0, 4.0, 5.0], dtype=np.float64) + w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64) + rvw0 = np.array([r0, v0, w0], dtype=np.float64) + + g = 9.81 + + for t in [0.0, 1.0, 2.0, 3.0]: + rvw = _evolve_airborne_state(rvw0.copy(), g, t) + expected_x = r0[0] + v0[0] * t + expected_y = r0[1] + v0[1] * t + np.testing.assert_almost_equal( + rvw[0, 0], expected_x, err_msg="X displacement not linear in time." + ) + np.testing.assert_almost_equal( + rvw[0, 1], expected_y, err_msg="Y displacement not linear in time." + ) + + +def test_gravity_direction(): + """Test that gravity affects the motion in the z-direction only. + + Equations: + v_z(t) = v_z(0) - g*t + r_z(t) = r_z(0) + v_z(0)*t - (1/2)*g*t^2 + """ + r0 = np.array([0.0, 0.0, 10.0], dtype=np.float64) + v0 = np.array([3.0, 4.0, 5.0], dtype=np.float64) + w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64) + rvw = np.array([r0, v0, w0], dtype=np.float64) + + g = 9.81 + t = 1.0 + + rvw_t = _evolve_airborne_state(rvw.copy(), g, t) + expected_z = r0[2] + v0[2] * t - 0.5 * g * t**2 + expected_vz = v0[2] - g * t + + np.testing.assert_almost_equal( + rvw_t[0, 2], + expected_z, + err_msg="Z displacement does not match gravity equation.", + ) + np.testing.assert_almost_equal( + rvw_t[1, 2], expected_vz, err_msg="Z velocity does not match gravity equation." + ) + + +def test_z_displacement_parabolic(): + """Test that z-displacement is parabolic. + + Equations: + z(t) = z(0) + v_z(0)*t - (g/2)*t^2 + """ + r0 = np.array([0.0, 0.0, 10.0], dtype=np.float64) + v0 = np.array([3.0, 4.0, 5.0], dtype=np.float64) + w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64) + rvw = np.array([r0, v0, w0], dtype=np.float64) + + g = 9.81 + times = np.array([0.0, 1.0, 2.0, 3.0], dtype=np.float64) + z_values = [] + + for t in times: + rvw_t = _evolve_airborne_state(rvw.copy(), g, t) + z_values.append(rvw_t[0, 2]) + + z_values = np.array(z_values) + # Fit a quadratic to the computed z values + coeffs = np.polyfit(times, z_values, 2) + # The true quadratic form is: z(t) = z0 + v0_z t - (g/2)*t^2 + # Coefficients from polyfit are in order: a*t^2 + b*t + c + # We know a should be -g/2, b should be v0_z, and c should be z0. + + np.testing.assert_almost_equal( + coeffs[0], + -g / 2, + decimal=5, + err_msg="Quadratic coefficient a does not match -g/2.", + ) + np.testing.assert_almost_equal( + coeffs[1], + v0[2], + decimal=5, + err_msg="Quadratic coefficient b does not match initial v_z.", + ) + np.testing.assert_almost_equal( + coeffs[2], + r0[2], + decimal=5, + err_msg="Quadratic coefficient c does not match initial z.", + ) diff --git a/tests/physics/test_utils.py b/tests/physics/test_utils.py index 232e9e85..50442138 100644 --- a/tests/physics/test_utils.py +++ b/tests/physics/test_utils.py @@ -1,7 +1,11 @@ import numpy as np import pytest -from pooltool.physics.utils import surface_velocity, tangent_surface_velocity +from pooltool.physics.utils import ( + get_airborne_time, + surface_velocity, + tangent_surface_velocity, +) def test_surface_velocity_no_angular_velocity(): @@ -144,3 +148,68 @@ def test_smoke_test(v, w, d, expected_surface, expected_tangent): v_tangent = tangent_surface_velocity(rvw, d, R) assert np.isclose(v_tangent, expected_tangent).all() + + +@pytest.mark.parametrize( + "rvw,R,g,expected", + [ + # Zero gravity: no return to table. + ( + np.array( + [ + [0.0, 0.0, 1.1], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + dtype=np.float64, + ), + 0.1, + 0.0, + np.inf, + ), + # Drop from apex (v_z = 0): t = sqrt(2 * (z - R) / g). + ( + np.array( + [ + [0.0, 0.0, 1.1], + [0.0, 0.0, 0.0], + [0.0, 0.0, 0.0], + ], + dtype=np.float64, + ), + 0.1, + 10.0, + 0.4472135955, + ), + # xy-velocity does not affect time-to-table. + ( + np.array( + [ + [0.0, 0.0, 1.1], + [1.0, 1.0, 0.0], + [0.0, 0.0, 0.0], + ], + dtype=np.float64, + ), + 0.1, + 10.0, + 0.4472135955, + ), + # Ball at z=R with downward velocity: already touching, t = 0. + ( + np.array( + [ + [0.0, 0.0, 0.1], + [0.0, -1.0, 0.0], + [0.0, 0.0, 0.0], + ], + dtype=np.float64, + ), + 0.1, + 10.0, + 0.0, + ), + ], +) +def test_get_airborne_time(rvw, R, g, expected): + assert np.isclose(get_airborne_time(rvw, R, g), expected) From 355646a4904f6e2dc19b9492cd27d42afa25d245 Mon Sep 17 00:00:00 2001 From: Evan Kiefl Date: Sun, 17 May 2026 16:15:33 -0700 Subject: [PATCH 2/2] Use quadratic solver --- pooltool/physics/utils.py | 24 +++++++++--------------- 1 file changed, 9 insertions(+), 15 deletions(-) diff --git a/pooltool/physics/utils.py b/pooltool/physics/utils.py index 97c38843..125d3f7b 100644 --- a/pooltool/physics/utils.py +++ b/pooltool/physics/utils.py @@ -1,8 +1,11 @@ +import math + import numpy as np from numba import jit from numpy.typing import NDArray import pooltool.constants as const +from pooltool.ptmath.roots import quadratic from pooltool.ptmath.utils import coordinate_rotation, cross, norm3d, unit_vector @@ -53,28 +56,19 @@ def get_u_vec( def get_airborne_time(rvw: NDArray[np.float64], R: float, g: float) -> float: """Time until an airborne ball's bottom touches the table plane (z = R). - Returns ``np.inf`` if no future intersection exists (either gravity is zero, or the - discriminant of the quadratic ``-0.5 * g * t**2 + v_z * t + (z - R) = 0`` is - negative). + Solves ``-0.5 * g * t**2 + v_z * t + (z - R) = 0`` and returns the later root + (the descending-leg intersection). Returns ``np.inf`` when gravity is zero, or + when the discriminant is negative. """ if g == 0.0: return np.inf - A = -0.5 * g - B = rvw[1, 2] - C = rvw[0, 2] - R - - D = B**2 - 4 * A * C + t1, t2 = quadratic.solve(-0.5 * g, rvw[1, 2], rvw[0, 2] - R) - if D < 0: - # Only consider real roots. + if math.isnan(t1): return np.inf - # This is the only possible root assuming the ball starts above the table and - # acceleration due to gravity is towards table. - t_f = -(B + np.sqrt(D)) / (2 * A) - - return t_f + return max(t1, t2) @jit(nopython=True, cache=const.use_numba_cache)