Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 38 additions & 32 deletions pooltool/evolution/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,52 +11,58 @@

@attrs.define
class SimulationEngine:
"""A pluggable bundle of strategies used by the simulator.
"""A bundle of physics strategies used by the simulator.

Holds the strategies that define how a simulation is carried out: how events are
detected and how they are resolved. The simulator is handed an instance of this
class and routes work to its components.
Holds the resolver (pluggable per-event-type collision strategies) and the detector.
The simulator is handed an instance and routes work to its components.

Attributes:
is_3d:
Whether the simulation supports the airborne motion state and
ball-table events. Validated at construction against the
dimensionality capability (``dim``) of every bundled strategy in
``resolver``.
resolver:
The strategy responsible for resolving events.
Pluggable bundle of event-resolution strategies. Each strategy
declares a ``Dim`` capability (except ``ball_table``).
detector:
The strategy responsible for detecting the next event.
is_3d:
Whether the simulation supports the airborne motion state and ball-table
events. Validated at construction against the dimensionality capability
(``dim``) of every bundled strategy in ``resolver`` and ``detector``.
Canonical event detector. Not constructor-passable — built from
``is_3d`` automatically.
"""

resolver: Resolver = attrs.field(factory=Resolver.default)
detector: EventDetector = attrs.field(factory=EventDetector.default)
is_3d: bool = False
resolver: Resolver = attrs.field(factory=Resolver.default)
detector: EventDetector = attrs.field(init=False)

@detector.default # type: ignore
def _default_detector(self) -> EventDetector:
return EventDetector(is_3d=self.is_3d)

def __attrs_post_init__(self) -> None:
self._validate_dimensionality()

def _validate_dimensionality(self) -> None:
required = Dim.THREE if self.is_3d else Dim.TWO
for bundle in (self.resolver, self.detector):
for field in attrs.fields(type(bundle)):
if field.name in SKIP_DIMENSION:
continue
strategy = getattr(bundle, field.name)
if not attrs.has(type(strategy)):
continue
if not hasattr(strategy, "dim"):
raise AttributeError(
f"{type(bundle).__name__}.{field.name} "
f"({type(strategy).__name__}) is missing required "
f"'dim' attribute"
)
if strategy.dim not in (required, Dim.BOTH):
raise ValueError(
f"{type(bundle).__name__}.{field.name} "
f"({type(strategy).__name__}) has dim={strategy.dim}, "
f"incompatible with is_3d={self.is_3d}; "
f"expected {required} or {Dim.BOTH}"
)

for field in attrs.fields(type(self.resolver)):
if field.name in SKIP_DIMENSION:
continue
strategy = getattr(self.resolver, field.name)
if not attrs.has(type(strategy)):
continue
if not hasattr(strategy, "dim"):
raise AttributeError(
f"Resolver.{field.name} "
f"({type(strategy).__name__}) is missing required "
f"'dim' attribute"
)
if strategy.dim not in (required, Dim.BOTH):
raise ValueError(
f"Resolver.{field.name} "
f"({type(strategy).__name__}) has dim={strategy.dim}, "
f"incompatible with is_3d={self.is_3d}; "
f"expected {required} or {Dim.BOTH}"
)


__all__ = [
Expand Down
44 changes: 20 additions & 24 deletions pooltool/evolution/event_based/detect/__init__.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,35 @@
from pooltool.evolution.event_based.detect.ball_ball import (
BallBallDetection,
BallBallDetectionStrategy,
get_next_ball_ball_2d_event,
get_next_ball_ball_3d_event,
)
from pooltool.evolution.event_based.detect.ball_cushion import (
BallCCushionDetection,
BallCCushionDetectionStrategy,
BallLCushionDetection,
BallLCushionDetectionStrategy,
get_next_ball_circular_cushion_2d_event,
get_next_ball_circular_cushion_3d_event,
get_next_ball_linear_cushion_2d_event,
get_next_ball_linear_cushion_3d_event,
)
from pooltool.evolution.event_based.detect.ball_pocket import (
BallPocketDetection,
BallPocketDetectionStrategy,
get_next_ball_pocket_2d_event,
get_next_ball_pocket_3d_event,
)
from pooltool.evolution.event_based.detect.ball_table import (
BallTableDetection,
BallTableDetectionStrategy,
get_next_ball_table_event,
)
from pooltool.evolution.event_based.detect.detector import EventDetector
from pooltool.evolution.event_based.detect.stick_ball import (
StickBallDetection,
StickBallDetectionStrategy,
get_next_stick_ball_event,
)

__all__ = [
"EventDetector",
"BallBallDetection",
"BallBallDetectionStrategy",
"BallCCushionDetection",
"BallCCushionDetectionStrategy",
"BallLCushionDetection",
"BallLCushionDetectionStrategy",
"BallPocketDetection",
"BallPocketDetectionStrategy",
"BallTableDetection",
"BallTableDetectionStrategy",
"StickBallDetection",
"StickBallDetectionStrategy",
"get_next_ball_ball_2d_event",
"get_next_ball_ball_3d_event",
"get_next_ball_circular_cushion_2d_event",
"get_next_ball_circular_cushion_3d_event",
"get_next_ball_linear_cushion_2d_event",
"get_next_ball_linear_cushion_3d_event",
"get_next_ball_pocket_2d_event",
"get_next_ball_pocket_3d_event",
"get_next_ball_table_event",
"get_next_stick_ball_event",
]
146 changes: 67 additions & 79 deletions pooltool/evolution/event_based/detect/ball_ball.py
Original file line number Diff line number Diff line change
@@ -1,93 +1,81 @@
from __future__ import annotations

from itertools import combinations
from typing import Protocol

import attrs
import numpy as np

import pooltool.constants as const
import pooltool.ptmath as ptmath
from pooltool.events import Event, EventType, ball_ball_collision, null_event
from pooltool.evolution.event_based.cache import CollisionCache
from pooltool.physics.dimensionality import Dim
from pooltool.physics.motion.solve import ball_ball_collision_time
from pooltool.system.datatypes import System


class BallBallDetectionStrategy(Protocol):
"""Ball-ball detection models must satisfy this protocol."""

dim: Dim

def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ...


@attrs.define
class BallBallDetection:
"""Detects the next ball-ball collision in the system."""

dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False)

def get_next(self, shot: System, collision_cache: CollisionCache) -> Event:
cache = collision_cache.times.setdefault(EventType.BALL_BALL, {})

for ball1, ball2 in combinations(shot.balls.values(), 2):
ball_pair = (ball1.id, ball2.id)
if ball_pair in cache:
continue

ball1_state = ball1.state
ball1_params = ball1.params

ball2_state = ball2.state
ball2_params = ball2.params

if ball1_state.s == const.pocketed or ball2_state.s == const.pocketed:
cache[ball_pair] = np.inf
elif (
ball1_state.s in const.nontranslating
and ball2_state.s in const.nontranslating
):
cache[ball_pair] = np.inf
elif ptmath.is_overlapping(
ball1_state.rvw,
ball2_state.rvw,
ball1_params.R,
ball2_params.R,
):
cache[ball_pair] = shot.t
else:
dtau_E = ball_ball_collision_time(
rvw1=ball1_state.rvw,
rvw2=ball2_state.rvw,
s1=ball1_state.s,
s2=ball2_state.s,
mu1=(
ball1_params.u_s
if ball1_state.s == const.sliding
else ball1_params.u_r
),
mu2=(
ball2_params.u_s
if ball2_state.s == const.sliding
else ball2_params.u_r
),
m1=ball1_params.m,
m2=ball2_params.m,
g1=ball1_params.g,
g2=ball2_params.g,
R=ball1_params.R,
)
cache[ball_pair] = shot.t + dtau_E

if not cache:
return null_event(np.inf)

ball_pair = min(cache, key=lambda k: cache[k])

return ball_ball_collision(
ball1=shot.balls[ball_pair[0]],
ball2=shot.balls[ball_pair[1]],
time=cache[ball_pair],
)
def get_next_ball_ball_2d_event(shot: System, collision_cache: CollisionCache) -> Event:
"""Detect the next ball-ball collision in 2D mode."""
cache = collision_cache.times.setdefault(EventType.BALL_BALL, {})

for ball1, ball2 in combinations(shot.balls.values(), 2):
ball_pair = (ball1.id, ball2.id)
if ball_pair in cache:
continue

ball1_state = ball1.state
ball1_params = ball1.params

ball2_state = ball2.state
ball2_params = ball2.params

if ball1_state.s == const.pocketed or ball2_state.s == const.pocketed:
cache[ball_pair] = np.inf
elif (
ball1_state.s in const.nontranslating
and ball2_state.s in const.nontranslating
):
cache[ball_pair] = np.inf
elif ptmath.is_overlapping(
ball1_state.rvw,
ball2_state.rvw,
ball1_params.R,
ball2_params.R,
):
cache[ball_pair] = shot.t
else:
dtau_E = ball_ball_collision_time(
rvw1=ball1_state.rvw,
rvw2=ball2_state.rvw,
s1=ball1_state.s,
s2=ball2_state.s,
mu1=(
ball1_params.u_s
if ball1_state.s == const.sliding
else ball1_params.u_r
),
mu2=(
ball2_params.u_s
if ball2_state.s == const.sliding
else ball2_params.u_r
),
m1=ball1_params.m,
m2=ball2_params.m,
g1=ball1_params.g,
g2=ball2_params.g,
R=ball1_params.R,
)
cache[ball_pair] = shot.t + dtau_E

if not cache:
return null_event(np.inf)

ball_pair = min(cache, key=lambda k: cache[k])

return ball_ball_collision(
ball1=shot.balls[ball_pair[0]],
ball2=shot.balls[ball_pair[1]],
time=cache[ball_pair],
)


def get_next_ball_ball_3d_event(shot: System, collision_cache: CollisionCache) -> Event:
raise NotImplementedError("3D ball-ball detection has not been vendored yet")
Loading
Loading