Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion docs/resources/custom_physics.md
Original file line number Diff line number Diff line change
Expand Up @@ -201,7 +201,11 @@ dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False)

- `Dim.TWO` — your model is safe only when [](#pooltool.evolution.engine.SimulationEngine)'s `is_3d` is `False`. Use this if your model assumes balls are on the table surface (z=R, vz=0).
- `Dim.THREE` — your model is safe only when `is_3d` is `True`. Use this if your model assumes 3D ball state (e.g. it produces or handles airborne balls).
- `Dim.BOTH` — your model behaves identically in either mode. This is a strong promise: a `Dim.BOTH` model doesn't care or know whether it's handling a 2D/3D simulation.
- `Dim.BOTH` — your model is safe in either mode. It may still take different code paths depending on the input it receives (e.g. a branch on `state == const.airborne` is dead in 2D and live in 3D), as long as neither path is incorrect for the mode it runs under.

:::{note}
**Ball-table resolvers are an exception:** they do not declare a `dim` attribute. Ball-table events do not exist as a concept in 2D, so a 2D-vs-3D capability declaration is not meaningful for them.
:::

Great, now we are done with the boilerplate code. But `resolve` currently does *nothing*, it just returns what is handed to it. Let's change that.

Expand Down
6 changes: 3 additions & 3 deletions pooltool/ani/constants.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from pathlib import Path

import pooltool as pt
import pooltool
from pooltool.utils import panda_path

menu_text_scale = 0.07
Expand Down Expand Up @@ -41,9 +41,9 @@
"shadow_scale_amplitude": 0.4,
}

model_dir: Path = Path(pt.__file__).parent / "models"
model_dir: Path = Path(pooltool.__file__).parent / "models"

logo_dir = Path(pt.__file__).parent / "logo"
logo_dir = Path(pooltool.__file__).parent / "logo"
logo_paths = {
"default": panda_path(logo_dir / "logo.png"),
"small": panda_path(logo_dir / "logo_small.png"),
Expand Down
7 changes: 3 additions & 4 deletions pooltool/evolution/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,9 +4,8 @@

import attrs

from pooltool.evolution.event_based.config import DORMANT_IN_2D
from pooltool.evolution.event_based.detect import EventDetector
from pooltool.physics.dimensionality import Dim
from pooltool.physics.dimensionality import SKIP_DIMENSION, Dim
from pooltool.physics.resolve import Resolver


Expand Down Expand Up @@ -40,11 +39,11 @@ def _validate_dimensionality(self) -> None:
required = Dim.THREE if self.is_3d else Dim.TWO
for bundle in (self.resolver, self.detector):
for field in attrs.fields(type(bundle)):
if field.name in SKIP_DIMENSION:
continue
strategy = getattr(bundle, field.name)
if not attrs.has(type(strategy)):
continue
if not self.is_3d and field.name in DORMANT_IN_2D:
continue
if not hasattr(strategy, "dim"):
raise AttributeError(
f"{type(bundle).__name__}.{field.name} "
Expand Down
6 changes: 4 additions & 2 deletions pooltool/evolution/event_based/cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,9 @@ def create(cls, shot: System) -> TransitionCache:


def _next_transition(ball: Ball) -> Event:
if ball.state.s == const.stationary or ball.state.s == const.pocketed:
if ball.state.s in {const.stationary, const.pocketed, const.airborne}:
# Stationary and airborne states can only be changed via collisions, and
# pocketed states can never be changed.
return null_event(time=np.inf)

elif ball.state.s == const.spinning:
Expand Down Expand Up @@ -131,7 +133,7 @@ class CollisionCache:
event caching, see :class:`TransitionCache`.
"""

times: dict[EventType, dict[tuple[str, str], float]] = attrs.field(factory=dict)
times: dict[EventType, dict[tuple[str, ...], float]] = attrs.field(factory=dict)

@property
def size(self) -> int:
Expand Down
8 changes: 1 addition & 7 deletions pooltool/evolution/event_based/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,16 +6,10 @@
EventType.BALL_LINEAR_CUSHION,
EventType.BALL_CIRCULAR_CUSHION,
EventType.BALL_POCKET,
EventType.BALL_TABLE,
EventType.STICK_BALL,
EventType.SPINNING_STATIONARY,
EventType.ROLLING_STATIONARY,
EventType.ROLLING_SPINNING,
EventType.SLIDING_ROLLING,
}

DORMANT_IN_2D: frozenset[str] = frozenset({"ball_table"})
"""Resolver/EventDetector field names that are dormant in 2D mode because the detection
layer doesn't emit their associated event types. In 2D, the ``dim`` of these fields is
not validated against ``SimulationEngine.is_3d`` - any tag is safe because the strategy
will never be invoked. In 3D, they are validated normally so a misdeclared ball-table
strategy is still caught."""
6 changes: 6 additions & 0 deletions pooltool/evolution/event_based/detect/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@
BallPocketDetection,
BallPocketDetectionStrategy,
)
from pooltool.evolution.event_based.detect.ball_table import (
BallTableDetection,
BallTableDetectionStrategy,
)
from pooltool.evolution.event_based.detect.detector import EventDetector
from pooltool.evolution.event_based.detect.stick_ball import (
StickBallDetection,
Expand All @@ -28,6 +32,8 @@
"BallLCushionDetectionStrategy",
"BallPocketDetection",
"BallPocketDetectionStrategy",
"BallTableDetection",
"BallTableDetectionStrategy",
"StickBallDetection",
"StickBallDetectionStrategy",
]
47 changes: 47 additions & 0 deletions pooltool/evolution/event_based/detect/ball_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from __future__ import annotations

from typing import Protocol

import attrs

from pooltool.events import Event, EventType, ball_table_collision
from pooltool.evolution.event_based.cache import CollisionCache
from pooltool.physics.motion.solve import ball_table_collision_time
from pooltool.system.datatypes import System


class BallTableDetectionStrategy(Protocol):
"""Ball-table detection models must satisfy this protocol.

Unlike the other detection-strategy protocols, this one does not declare a
``dim`` attribute.
"""

def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ...


@attrs.define
class BallTableDetection:
"""Detects the next ball-table collision in the system."""

def get_next(self, shot: System, collision_cache: CollisionCache) -> Event:
cache = collision_cache.times.setdefault(EventType.BALL_TABLE, {})

for ball in shot.balls.values():
obj_ids = (ball.id,)
if obj_ids in cache:
continue
dtau_E = ball_table_collision_time(
rvw=ball.state.rvw,
s=ball.state.s,
g=ball.params.g,
R=ball.params.R,
)
cache[obj_ids] = shot.t + dtau_E

obj_ids = min(cache, key=lambda k: cache[k])

return ball_table_collision(
ball=shot.balls[obj_ids[0]],
time=cache[obj_ids],
)
19 changes: 18 additions & 1 deletion pooltool/evolution/event_based/detect/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
BallLCushionDetection,
)
from pooltool.evolution.event_based.detect.ball_pocket import BallPocketDetection
from pooltool.evolution.event_based.detect.ball_table import BallTableDetection
from pooltool.evolution.event_based.detect.stick_ball import StickBallDetection
from pooltool.physics.utils import get_ball_energy
from pooltool.system.datatypes import System
Expand All @@ -27,7 +28,7 @@ def _get_event_priority(event: Event, shot: System) -> tuple[int, float]:
Priority tiers:
- Tier 1: STICK_BALL (always first)
- Tier 2: Transitions and BALL_POCKET (can resolve without affecting others)
- Tier 3: BALL_BALL and ball-cushion collisions
- Tier 3: BALL_BALL, ball-cushion collisions, and BALL_TABLE

Args:
event: The event to compute priority for.
Expand Down Expand Up @@ -69,6 +70,17 @@ def _get_event_priority(event: Event, shot: System) -> tuple[int, float]:
energy = get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m)
return (3, energy)

# TODO: tier and energy choice for BALL_TABLE has not been well thought
# through or tested. Mirroring the cushion-collision semantics, but
# BALL_TABLE-vs-other ties only become real once 3D activation lands and
# airborne balls actually arise. Revisit once break / aerial trajectories
# exercise this path.
if event_type == EventType.BALL_TABLE:
ball_id = event.ids[0]
ball = shot.balls[ball_id]
energy = get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m)
return (3, energy)

return (99, 0.0)


Expand All @@ -95,6 +107,9 @@ class EventDetector:
Strategy for detecting the next ball-vs-circular-cushion-segment collision.
ball_pocket:
Strategy for detecting the next ball-pocket collision.
ball_table:
Strategy for detecting the next ball-table collision (airborne ball
landing on the table surface).
"""

stick_ball: StickBallDetection = attrs.field(factory=StickBallDetection)
Expand All @@ -106,6 +121,7 @@ class EventDetector:
factory=BallCCushionDetection
)
ball_pocket: BallPocketDetection = attrs.field(factory=BallPocketDetection)
ball_table: BallTableDetection = attrs.field(factory=BallTableDetection)

@classmethod
def default(cls) -> EventDetector:
Expand Down Expand Up @@ -143,6 +159,7 @@ def get_next_event(
candidates.append(self.ball_circular_cushion.get_next(shot, collision_cache))
candidates.append(self.ball_linear_cushion.get_next(shot, collision_cache))
candidates.append(self.ball_pocket.get_next(shot, collision_cache))
candidates.append(self.ball_table.get_next(shot, collision_cache))

min_time = min(event.time for event in candidates)

Expand Down
7 changes: 7 additions & 0 deletions pooltool/physics/dimensionality.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,3 +26,10 @@ class Dim(StrEnum):
TWO = auto()
THREE = auto()
BOTH = auto()


SKIP_DIMENSION: frozenset[str] = frozenset({"ball_table"})
"""Resolver/EventDetector field names whose strategies don't carry a ``dim``
attribute. ``SimulationEngine._validate_dimensionality`` skips these fields
entirely (in either mode). Used for slots whose events have no meaning in 2D
(currently just ``ball_table``: airborne balls only exist in 3D)."""
19 changes: 18 additions & 1 deletion pooltool/physics/motion/solve.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@
import pooltool.constants as const
import pooltool.physics.evolve as evolve
import pooltool.ptmath as ptmath
from pooltool.physics.utils import rel_velocity
from pooltool.physics.utils import get_airborne_time, rel_velocity
from pooltool.ptmath.roots import quartic
from pooltool.ptmath.roots.core import get_real_positive_smallest_root

Expand Down Expand Up @@ -421,3 +421,20 @@ def ball_pocket_collision_time(
)
)
)


@jit(nopython=True, cache=const.use_numba_cache)
def ball_table_collision_time(
rvw: NDArray[np.float64],
s: int,
g: float,
R: float,
) -> float:
"""Time until an airborne ball's bottom touches the table plane.

Returns ``np.inf`` if the ball is not airborne (no ball-table collision can
occur for any other motion state).
"""
if s != const.airborne:
return np.inf
return get_airborne_time(rvw=rvw, R=R, g=g)
9 changes: 5 additions & 4 deletions pooltool/physics/resolve/ball_table/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@

import pooltool.constants as const
from pooltool.objects.ball.datatypes import Ball
from pooltool.physics.dimensionality import Dim
from pooltool.physics.utils import on_table, rel_velocity
from pooltool.ptmath.utils import norm2d, norm3d

Expand Down Expand Up @@ -40,15 +39,17 @@ def final_ball_motion_state(rvw: NDArray[np.float64], R: float) -> int:


class _BaseStrategy(Protocol):
dim: Dim

def resolve(self, ball: Ball, inplace: bool = False) -> Ball: ...

def make_kiss(self, ball: Ball) -> Ball: ...


class BallTableCollisionStrategy(_BaseStrategy, Protocol):
"""Ball-table collision models must satisfy this protocol"""
"""Ball-table collision models must satisfy this protocol.

Unlike the other resolver-strategy protocols, this one does not declare a
``dim`` attribute.
"""

def solve(self, ball: Ball) -> Ball:
"""Resolves a ball-table collision"""
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,6 @@
import pooltool.physics as physics
import pooltool.ptmath as ptmath
from pooltool.objects.ball.datatypes import Ball, BallState
from pooltool.physics.dimensionality import Dim
from pooltool.physics.resolve.ball_table.core import (
CoreBallTableCollision,
bounce_height,
Expand Down Expand Up @@ -74,7 +73,6 @@ class FrictionalInelasticTable(CoreBallTableCollision):
model: BallTableModel = attrs.field(
default=BallTableModel.FRICTIONAL_INELASTIC, init=False, repr=False
)
dim: Dim = attrs.field(default=Dim.THREE, init=False, repr=False)

def solve(self, ball: Ball) -> Ball:
"""Resolves the collision."""
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import attrs

from pooltool.objects.ball.datatypes import Ball, BallState
from pooltool.physics.dimensionality import Dim
from pooltool.physics.resolve.ball_table.core import (
CoreBallTableCollision,
bounce_height,
Expand Down Expand Up @@ -36,7 +35,6 @@ class FrictionlessInelasticTable(CoreBallTableCollision):
model: BallTableModel = attrs.field(
default=BallTableModel.FRICTIONLESS_INELASTIC, init=False, repr=False
)
dim: Dim = attrs.field(default=Dim.THREE, init=False, repr=False)

def solve(self, ball: Ball) -> Ball:
"""Resolves the collision."""
Expand Down
56 changes: 56 additions & 0 deletions tests/evolution/event_based/test_ball_table.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,56 @@
import numpy as np
import pytest

import pooltool.constants as const
from pooltool.events import EventType
from pooltool.evolution.event_based.cache import CollisionCache
from pooltool.evolution.event_based.detect.ball_table import BallTableDetection
from pooltool.physics.utils import get_airborne_time
from pooltool.system.datatypes import System


@pytest.fixture
def system() -> System:
return System.example()


def test_no_airborne_balls_returns_inf_time(system: System):
"""In a default 2D scene no ball is airborne, so the emitted event has time=inf."""
event = BallTableDetection().get_next(system, CollisionCache())
assert event.event_type == EventType.BALL_TABLE
assert event.time == np.inf


def test_airborne_ball_returns_finite_time(system: System):
"""An airborne ball at apex over the table returns the physics-derived drop time."""
ball = next(iter(system.balls.values()))
ball.state.rvw[0, 2] = ball.params.R + 0.1
ball.state.rvw[1, 2] = 0.0
ball.state.s = const.airborne

event = BallTableDetection().get_next(system, CollisionCache())

expected = get_airborne_time(ball.state.rvw, ball.params.R, ball.params.g)
assert event.event_type == EventType.BALL_TABLE
assert event.time == pytest.approx(expected)


def test_returns_soonest_ball(system: System):
"""When multiple balls are airborne, the one with the shortest drop time wins."""
balls = list(system.balls.values())
assert len(balls) >= 2

high, low = balls[0], balls[1]

high.state.rvw[0, 2] = high.params.R + 0.5
high.state.rvw[1, 2] = 0.0
high.state.s = const.airborne

low.state.rvw[0, 2] = low.params.R + 0.05
low.state.rvw[1, 2] = 0.0
low.state.s = const.airborne

event = BallTableDetection().get_next(system, CollisionCache())

assert event.event_type == EventType.BALL_TABLE
assert event.ids[0] == low.id
Loading
Loading