diff --git a/pooltool/evolution/event_based/_utils.py b/pooltool/evolution/event_based/_utils.py index 851fa773..f3e1c574 100644 --- a/pooltool/evolution/event_based/_utils.py +++ b/pooltool/evolution/event_based/_utils.py @@ -5,17 +5,19 @@ def _system_has_energy(system: System) -> bool: - """Return True if any ball in the system has kinetic energy. + """Return True if any ball in the system has nonzero mechanical energy. - Cue energy (e.g. ``system.cue.V0 > 0``) does not count. + Energy includes linear and rotational kinetic energy plus gravitational + potential energy (with PE=0 defined at the on-table resting height, + ``z = R``). Cue energy (e.g. ``system.cue.V0 > 0``) does not count. """ return any( - bool( - get_ball_energy( - ball.state.rvw, - ball.params.R, - ball.params.m, - ) + get_ball_energy( + ball.state.rvw, + ball.params.R, + ball.params.m, + ball.params.g, ) + > 0.0 for ball in system.balls.values() ) diff --git a/pooltool/evolution/event_based/detect/detector.py b/pooltool/evolution/event_based/detect/detector.py index 57f0e655..faf4072e 100644 --- a/pooltool/evolution/event_based/detect/detector.py +++ b/pooltool/evolution/event_based/detect/detector.py @@ -60,13 +60,17 @@ def _get_event_priority(event: Event, shot: System) -> tuple[int, float]: if event_type == EventType.BALL_POCKET: ball_id = event.ids[0] ball = shot.balls[ball_id] - energy = get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + energy = get_ball_energy( + ball.state.rvw, ball.params.R, ball.params.m, ball.params.g + ) return (2, energy) if event_type.is_transition(): ball_id = event.ids[0] ball = shot.balls[ball_id] - energy = get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + energy = get_ball_energy( + ball.state.rvw, ball.params.R, ball.params.m, ball.params.g + ) return (2, energy) if event_type == EventType.BALL_BALL: @@ -79,7 +83,9 @@ def _get_event_priority(event: Event, shot: System) -> tuple[int, float]: if event_type in (EventType.BALL_LINEAR_CUSHION, EventType.BALL_CIRCULAR_CUSHION): ball_id = event.ids[0] ball = shot.balls[ball_id] - energy = get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + energy = get_ball_energy( + ball.state.rvw, ball.params.R, ball.params.m, ball.params.g + ) return (3, energy) # TODO: tier and energy choice for BALL_TABLE has not been well thought @@ -90,7 +96,9 @@ def _get_event_priority(event: Event, shot: System) -> tuple[int, float]: if event_type == EventType.BALL_TABLE: ball_id = event.ids[0] ball = shot.balls[ball_id] - energy = get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + energy = get_ball_energy( + ball.state.rvw, ball.params.R, ball.params.m, ball.params.g + ) return (3, energy) return (99, 0.0) diff --git a/pooltool/evolution/event_based/introspection.py b/pooltool/evolution/event_based/introspection.py index 0bd7fb14..e0607557 100644 --- a/pooltool/evolution/event_based/introspection.py +++ b/pooltool/evolution/event_based/introspection.py @@ -31,6 +31,7 @@ ball_circular_cushion_collision, ball_linear_cushion_collision, ball_pocket_collision, + ball_table_collision, stick_ball_collision, ) from pooltool.evolution.engine import SimulationEngine @@ -48,11 +49,6 @@ def _get_collision_events_from_cache( system: System, cache: CollisionCache ) -> list[Event]: - # TODO: BALL_TABLE entries in the cache are not reconstructed here. In 2D - # mode this is harmless (the detector doesn't populate the BALL_TABLE - # bucket). When 3D activation lands, prospective BALL_TABLE events will be - # silently missed from get_prospective_events(). Add a branch that builds - # ball_table_collision(ball, time) from each (ball_id,) key. events = [] if EventType.BALL_BALL in cache.times: @@ -109,6 +105,15 @@ def _get_collision_events_from_cache( ) ) + if EventType.BALL_TABLE in cache.times: + for (ball_id,), time in cache.times[EventType.BALL_TABLE].items(): + events.append( + ball_table_collision( + ball=system.balls[ball_id], + time=time, + ) + ) + return events diff --git a/pooltool/physics/resolve/ball_pocket/__init__.py b/pooltool/physics/resolve/ball_pocket/__init__.py index 1be56770..6df1a1fa 100644 --- a/pooltool/physics/resolve/ball_pocket/__init__.py +++ b/pooltool/physics/resolve/ball_pocket/__init__.py @@ -35,7 +35,7 @@ class CanonicalBallPocket: model: BallPocketModel = attrs.field( default=BallPocketModel.CANONICAL, init=False, repr=False ) - dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) + dim: Dim = attrs.field(default=Dim.BOTH, init=False, repr=False) def resolve( self, ball: Ball, pocket: Pocket, inplace: bool = False diff --git a/pooltool/physics/resolve/transition/__init__.py b/pooltool/physics/resolve/transition/__init__.py index 3bae61fa..0b095ca9 100644 --- a/pooltool/physics/resolve/transition/__init__.py +++ b/pooltool/physics/resolve/transition/__init__.py @@ -34,7 +34,7 @@ class CanonicalTransition: model: BallTransitionModel = attrs.field( default=BallTransitionModel.CANONICAL, init=False, repr=False ) - dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) + dim: Dim = attrs.field(default=Dim.BOTH, init=False, repr=False) def resolve(self, ball: Ball, transition: EventType, inplace: bool = False) -> Ball: if not inplace: diff --git a/pooltool/physics/utils.py b/pooltool/physics/utils.py index 3ee5b725..1208f698 100644 --- a/pooltool/physics/utils.py +++ b/pooltool/physics/utils.py @@ -96,19 +96,18 @@ def get_spin_time(rvw: NDArray[np.float64], R: float, u_sp: float, g: float) -> return np.abs(w[2]) * 2 / 5 * R / u_sp / g -def get_ball_energy(rvw: NDArray[np.float64], R: float, m: float) -> float: - """Get the energy of a ball +def get_ball_energy(rvw: NDArray[np.float64], R: float, m: float, g: float) -> float: + """Get the energy of a ball. - Currently calculating linear and rotational kinetic energy. Need to add potential - energy if z-axis is freed + Sum of linear kinetic, rotational kinetic, and gravitational potential energy. + Potential energy is defined relative to a ball at rest on the table (``z = R``), + so a ball sitting on the table contributes zero energy. """ - # Linear LKE = m * norm3d(rvw[1]) ** 2 / 2 - - # Rotational RKE = (2 / 5 * m * R**2) * norm3d(rvw[2]) ** 2 / 2 + PE = m * g * (rvw[0, 2] - R) - return LKE + RKE + return LKE + RKE + PE @jit(nopython=True, cache=const.use_numba_cache) diff --git a/sandbox/airborne_demos.py b/sandbox/airborne_demos.py index 18a61967..01cad57a 100644 --- a/sandbox/airborne_demos.py +++ b/sandbox/airborne_demos.py @@ -45,22 +45,6 @@ def _build_3d_engine() -> SimulationEngine: return SimulationEngine(resolver=resolver, is_3d=True) -def _empty_cue() -> Cue: - """A cue with ``V0=0`` so the simulator doesn't fire a stick strike at t=0. - - The default ``Cue`` constructor sets ``V0=2.0``; the stick-ball detector - fires whenever ``V0 > 0`` and ``_system_has_energy`` reports false. For - a ball that's airborne but momentarily at apex (vz=0), kinetic energy is - zero and the detector would trigger a strike on the handcrafted state. - Explicitly zero V0 to suppress that. - - TODO: Fix underlying problem in codebase - """ - cue = Cue(cue_ball_id="cue") - cue.V0 = 0 - return cue - - def drop() -> System: """Ball dropped from 0.3 m with a small horizontal nudge in +x.""" ball = Ball.create("cue", xy=(0.5, 0.5)) @@ -69,7 +53,7 @@ def drop() -> System: ball.state.s = const.airborne return System( - cue=_empty_cue(), + cue=Cue(cue_ball_id="cue"), table=Table.default(), balls=(ball,), ) @@ -83,7 +67,7 @@ def impulse_into() -> System: ball.state.s = const.airborne return System( - cue=_empty_cue(), + cue=Cue(cue_ball_id="cue"), table=Table.default(), balls=(ball,), ) diff --git a/tests/evolution/event_based/test_introspection.py b/tests/evolution/event_based/test_introspection.py index 8d354fc0..34cdac53 100644 --- a/tests/evolution/event_based/test_introspection.py +++ b/tests/evolution/event_based/test_introspection.py @@ -1,11 +1,16 @@ import tempfile from pathlib import Path +import numpy as np + +from pooltool.events import EventType, null_event +from pooltool.evolution.event_based.cache import CollisionCache, TransitionCache from pooltool.evolution.event_based.introspection import ( + SimulationSnapshot, SimulationSnapshotSequence, simulate_with_snapshots, ) -from pooltool.evolution.event_based.simulate import simulate +from pooltool.evolution.event_based.simulate import DEFAULT_ENGINE, simulate from pooltool.system.datatypes import System @@ -80,6 +85,31 @@ def test_post_resolve_of_n_equals_pre_evolve_of_n_plus_1(): assert post_resolve == pre_evolve_next +def test_get_prospective_events_includes_ball_table(): + """BALL_TABLE cache entries must surface in get_prospective_events. + + Regression: introspection's cache-to-events reconstruction was missing + the BALL_TABLE branch. Silent in 2D (which never populates the bucket); + surfaces once 3D activation lands. + """ + system = System.example() + cache = CollisionCache.create() + ball_id = next(iter(system.balls)) + cache.times[EventType.BALL_TABLE] = {(ball_id,): 0.123} + + snapshot = SimulationSnapshot( + step_number=0, + system=system, + next_event=null_event(np.inf), + collision_cache=cache, + transition_cache=TransitionCache.create(system), + engine=DEFAULT_ENGINE, + ) + + events = snapshot.get_prospective_events() + assert any(e.event_type == EventType.BALL_TABLE and e.time == 0.123 for e in events) + + def test_system_state_progression(): """Test the full progression: pre_evolve -> post_evolve -> post_resolve.""" system = System.example() diff --git a/tests/evolution/event_based/test_simulate.py b/tests/evolution/event_based/test_simulate.py index 3baf6bcb..da9291f3 100644 --- a/tests/evolution/event_based/test_simulate.py +++ b/tests/evolution/event_based/test_simulate.py @@ -508,6 +508,15 @@ def test_system_has_energy(): ball.state = ball.history[event_step] assert _system_has_energy(system) + # An airborne ball at apex (vz=0, KE=0) still has potential energy. + system = System.example() + ball = next(iter(system.balls.values())) + ball.state.rvw[0, 2] = 0.3 + ball.state.rvw[1] = 0.0 + ball.state.rvw[2] = 0.0 + ball.state.s = const.airborne + assert _system_has_energy(system) + def test_stick_ball_event_detection(): """Test that stick-ball events are properly detected as the first event diff --git a/tests/physics/resolve/ball_cushion/test_ball_cushion.py b/tests/physics/resolve/ball_cushion/test_ball_cushion.py index 488381ae..26de4c83 100644 --- a/tests/physics/resolve/ball_cushion/test_ball_cushion.py +++ b/tests/physics/resolve/ball_cushion/test_ball_cushion.py @@ -75,6 +75,7 @@ def test_energy( ball.state.rvw, ball.params.R, ball.params.m, + ball.params.g, ) # Resolve physics @@ -85,6 +86,7 @@ def test_energy( ball_after.state.rvw, ball_after.params.R, ball_after.params.m, + ball_after.params.g, ) assert np.isclose(initial_energy, final_energy) or final_energy <= initial_energy, (