Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 7 additions & 1 deletion pooltool/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -45,15 +45,21 @@

A ball with this motion state is in a pocket.
"""
airborne: int = 5
"""The airborne motion state label

A ball with this motion state is airborne.
"""

state_dict: dict[int, str] = {
0: "stationary",
1: "spinning",
2: "sliding",
3: "rolling",
4: "pocketed",
5: "airborne",
}

on_table = {stationary, spinning, sliding, rolling}
nontranslating = {stationary, spinning, pocketed}
energetic = {spinning, sliding, rolling}
energetic = {spinning, sliding, rolling, airborne}
2 changes: 2 additions & 0 deletions pooltool/events/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
ball_circular_cushion_collision,
ball_linear_cushion_collision,
ball_pocket_collision,
ball_table_collision,
null_event,
rolling_spinning_transition,
rolling_stationary_transition,
Expand Down Expand Up @@ -41,6 +42,7 @@
"ball_circular_cushion_collision",
"ball_pocket_collision",
"stick_ball_collision",
"ball_table_collision",
"spinning_stationary_transition",
"rolling_stationary_transition",
"rolling_spinning_transition",
Expand Down
5 changes: 5 additions & 0 deletions pooltool/events/datatypes.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,8 @@ class EventType(strenum.StrEnum):
the *point of no return*.
STICK_BALL:
A cue-stick ball collision.
BALL_TABLE:
A ball collision into the table surface.
SPINNING_STATIONARY:
A ball transition from spinning to stationary.
ROLLING_STATIONARY:
Expand All @@ -51,6 +53,7 @@ class EventType(strenum.StrEnum):
BALL_CIRCULAR_CUSHION = strenum.auto()
BALL_POCKET = strenum.auto()
STICK_BALL = strenum.auto()
BALL_TABLE = strenum.auto()
SPINNING_STATIONARY = strenum.auto()
ROLLING_STATIONARY = strenum.auto()
ROLLING_SPINNING = strenum.auto()
Expand All @@ -64,6 +67,7 @@ def is_collision(self) -> bool:
EventType.BALL_LINEAR_CUSHION,
EventType.BALL_POCKET,
EventType.STICK_BALL,
EventType.BALL_TABLE,
}

def is_transition(self) -> bool:
Expand All @@ -85,6 +89,7 @@ def has_ball(self) -> bool:
EventType.BALL_CIRCULAR_CUSHION,
EventType.BALL_POCKET,
EventType.STICK_BALL,
EventType.BALL_TABLE,
}
or self.is_transition()
)
Expand Down
17 changes: 17 additions & 0 deletions pooltool/events/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,23 @@ def stick_ball_collision(
)


def ball_table_collision(ball: Ball, time: float, set_initial: bool = False) -> Event:
"""Create a ball-table collision.

Note:
- Since information about the ball-table interaction is stored exclusively in
the ball (not the table), and since the table is a large composite of
individual objects which is costly to serialize, the table is not stored as an
agent of the returned event and is therefore not accepted as an argument of
this function.
"""
return Event(
event_type=EventType.BALL_TABLE,
agents=(Agent.from_object(ball, set_initial=set_initial),),
time=time,
)


def spinning_stationary_transition(
ball: Ball, time: float, set_initial: bool = False
) -> Event:
Expand Down
1 change: 1 addition & 0 deletions pooltool/events/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
EventType.BALL_CIRCULAR_CUSHION: {0},
EventType.BALL_POCKET: {0},
EventType.STICK_BALL: {1},
EventType.BALL_TABLE: {0},
EventType.SPINNING_STATIONARY: {0},
EventType.ROLLING_STATIONARY: {0},
EventType.ROLLING_SPINNING: {0},
Expand Down
2 changes: 2 additions & 0 deletions pooltool/physics/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
ball_transition_models,
)
from pooltool.physics.utils import (
get_airborne_time,
get_ball_energy,
get_roll_time,
get_slide_time,
Expand All @@ -61,6 +62,7 @@
"get_slide_time",
"get_roll_time",
"get_spin_time",
"get_airborne_time",
"get_ball_energy",
"ball_ball_models",
"BallBallFrictionModel",
Expand Down
26 changes: 26 additions & 0 deletions pooltool/physics/evolve/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def evolve_ball_motion(
if state == const.stationary or state == const.pocketed:
return rvw, state

if state == const.airborne:
return _evolve_airborne_state(rvw, g, t), const.airborne

if state == const.sliding:
dtau_E_slide = get_slide_time(rvw, R, u_s, g)

Expand Down Expand Up @@ -183,3 +186,26 @@ def _evolve_perpendicular_spin_state(
) -> NDArray[np.float64]:
rvw[2, 2] = _evolve_perpendicular_spin_component(rvw[2, 2], R, u_sp, g, t)
return rvw


@jit(nopython=True, cache=const.use_numba_cache)
def _evolve_airborne_state(
rvw: NDArray[np.float64], g: float, t: float
) -> NDArray[np.float64]:
"""Parabolic evolution under gravity. Angular velocity is conserved."""
if t == 0:
return rvw

r_0, v_0, w_0 = rvw

g_vec = np.array([0.0, 0.0, g], dtype=np.float64)

r = r_0 + v_0 * t - 0.5 * g_vec * t**2
v = v_0 - g_vec * t

new_rvw = np.empty((3, 3), dtype=np.float64)
new_rvw[0, :] = r
new_rvw[1, :] = v
new_rvw[2, :] = w_0

return new_rvw
22 changes: 22 additions & 0 deletions pooltool/physics/utils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,11 @@
import math

import numpy as np
from numba import jit
from numpy.typing import NDArray

import pooltool.constants as const
from pooltool.ptmath.roots import quadratic
from pooltool.ptmath.utils import coordinate_rotation, cross, norm3d, unit_vector


Expand Down Expand Up @@ -49,6 +52,25 @@ def get_u_vec(
return coordinate_rotation(unit_vector(rel_vel), -phi)


@jit(nopython=True, cache=const.use_numba_cache)
def get_airborne_time(rvw: NDArray[np.float64], R: float, g: float) -> float:
"""Time until an airborne ball's bottom touches the table plane (z = R).

Solves ``-0.5 * g * t**2 + v_z * t + (z - R) = 0`` and returns the later root
(the descending-leg intersection). Returns ``np.inf`` when gravity is zero, or
when the discriminant is negative.
"""
if g == 0.0:
return np.inf

t1, t2 = quadratic.solve(-0.5 * g, rvw[1, 2], rvw[0, 2] - R)

if math.isnan(t1):
return np.inf

return max(t1, t2)


@jit(nopython=True, cache=const.use_numba_cache)
def get_slide_time(rvw: NDArray[np.float64], R: float, u_s: float, g: float) -> float:
if u_s == 0.0:
Expand Down
145 changes: 145 additions & 0 deletions tests/physics/evolve/test_airborne.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,145 @@
import numpy as np

from pooltool.physics.evolve import _evolve_airborne_state


def test_xy_velocity_conserved():
"""Test that the x- and y-components of the velocity remain unchanged as time evolves."""
r0 = np.array([0.0, 0.0, 0.0], dtype=np.float64)
v0 = np.array([1.0, 2.0, 3.0], dtype=np.float64)
w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64)
rvw0 = np.array([r0, v0, w0], dtype=np.float64)

g = 9.81
t = 1.0

rvw = _evolve_airborne_state(rvw0.copy(), g, t)

# Check if vx and vy are unchanged
np.testing.assert_almost_equal(
rvw[1, 0], v0[0], err_msg="X velocity changed unexpectedly."
)
np.testing.assert_almost_equal(
rvw[1, 1], v0[1], err_msg="Y velocity changed unexpectedly."
)


def test_angular_velocity_conserved():
"""Test that the angular velocity (w) is conserved and remains unchanged as time evolves."""
r0 = np.array([0.0, 0.0, 0.0], dtype=np.float64)
v0 = np.array([1.0, 2.0, 3.0], dtype=np.float64)
w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64)
rvw0 = np.array([r0, v0, w0], dtype=np.float64)

g = 9.81
t = 2.0

rvw = _evolve_airborne_state(rvw0.copy(), g, t)

# Check if angular velocity is unchanged
np.testing.assert_array_almost_equal(
rvw[2], w0, err_msg="Angular velocity changed unexpectedly."
)


def test_xy_displacement_linear():
"""Test that the xy-displacement changes linearly with time.

This assumes there is no air friction.

Equations:
r_x(t) = r_x(0) + v_x(0)*t
r_y(t) = r_y(0) + v_y(0)*t
"""
r0 = np.array([0.0, 0.0, 10.0], dtype=np.float64)
v0 = np.array([3.0, 4.0, 5.0], dtype=np.float64)
w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64)
rvw0 = np.array([r0, v0, w0], dtype=np.float64)

g = 9.81

for t in [0.0, 1.0, 2.0, 3.0]:
rvw = _evolve_airborne_state(rvw0.copy(), g, t)
expected_x = r0[0] + v0[0] * t
expected_y = r0[1] + v0[1] * t
np.testing.assert_almost_equal(
rvw[0, 0], expected_x, err_msg="X displacement not linear in time."
)
np.testing.assert_almost_equal(
rvw[0, 1], expected_y, err_msg="Y displacement not linear in time."
)


def test_gravity_direction():
"""Test that gravity affects the motion in the z-direction only.

Equations:
v_z(t) = v_z(0) - g*t
r_z(t) = r_z(0) + v_z(0)*t - (1/2)*g*t^2
"""
r0 = np.array([0.0, 0.0, 10.0], dtype=np.float64)
v0 = np.array([3.0, 4.0, 5.0], dtype=np.float64)
w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64)
rvw = np.array([r0, v0, w0], dtype=np.float64)

g = 9.81
t = 1.0

rvw_t = _evolve_airborne_state(rvw.copy(), g, t)
expected_z = r0[2] + v0[2] * t - 0.5 * g * t**2
expected_vz = v0[2] - g * t

np.testing.assert_almost_equal(
rvw_t[0, 2],
expected_z,
err_msg="Z displacement does not match gravity equation.",
)
np.testing.assert_almost_equal(
rvw_t[1, 2], expected_vz, err_msg="Z velocity does not match gravity equation."
)


def test_z_displacement_parabolic():
"""Test that z-displacement is parabolic.

Equations:
z(t) = z(0) + v_z(0)*t - (g/2)*t^2
"""
r0 = np.array([0.0, 0.0, 10.0], dtype=np.float64)
v0 = np.array([3.0, 4.0, 5.0], dtype=np.float64)
w0 = np.array([0.1, 0.2, 0.3], dtype=np.float64)
rvw = np.array([r0, v0, w0], dtype=np.float64)

g = 9.81
times = np.array([0.0, 1.0, 2.0, 3.0], dtype=np.float64)
z_values = []

for t in times:
rvw_t = _evolve_airborne_state(rvw.copy(), g, t)
z_values.append(rvw_t[0, 2])

z_values = np.array(z_values)
# Fit a quadratic to the computed z values
coeffs = np.polyfit(times, z_values, 2)
# The true quadratic form is: z(t) = z0 + v0_z t - (g/2)*t^2
# Coefficients from polyfit are in order: a*t^2 + b*t + c
# We know a should be -g/2, b should be v0_z, and c should be z0.

np.testing.assert_almost_equal(
coeffs[0],
-g / 2,
decimal=5,
err_msg="Quadratic coefficient a does not match -g/2.",
)
np.testing.assert_almost_equal(
coeffs[1],
v0[2],
decimal=5,
err_msg="Quadratic coefficient b does not match initial v_z.",
)
np.testing.assert_almost_equal(
coeffs[2],
r0[2],
decimal=5,
err_msg="Quadratic coefficient c does not match initial z.",
)
Loading
Loading