diff --git a/pooltool/evolution/engine.py b/pooltool/evolution/engine.py index 4865b058..ca066071 100644 --- a/pooltool/evolution/engine.py +++ b/pooltool/evolution/engine.py @@ -11,52 +11,58 @@ @attrs.define class SimulationEngine: - """A pluggable bundle of strategies used by the simulator. + """A bundle of physics strategies used by the simulator. - Holds the strategies that define how a simulation is carried out: how events are - detected and how they are resolved. The simulator is handed an instance of this - class and routes work to its components. + Holds the resolver (pluggable per-event-type collision strategies) and the detector. + The simulator is handed an instance and routes work to its components. Attributes: + is_3d: + Whether the simulation supports the airborne motion state and + ball-table events. Validated at construction against the + dimensionality capability (``dim``) of every bundled strategy in + ``resolver``. resolver: - The strategy responsible for resolving events. + Pluggable bundle of event-resolution strategies. Each strategy + declares a ``Dim`` capability (except ``ball_table``). detector: - The strategy responsible for detecting the next event. - is_3d: - Whether the simulation supports the airborne motion state and ball-table - events. Validated at construction against the dimensionality capability - (``dim``) of every bundled strategy in ``resolver`` and ``detector``. + Canonical event detector. Not constructor-passable — built from + ``is_3d`` automatically. """ - resolver: Resolver = attrs.field(factory=Resolver.default) - detector: EventDetector = attrs.field(factory=EventDetector.default) is_3d: bool = False + resolver: Resolver = attrs.field(factory=Resolver.default) + detector: EventDetector = attrs.field(init=False) + + @detector.default # type: ignore + def _default_detector(self) -> EventDetector: + return EventDetector(is_3d=self.is_3d) def __attrs_post_init__(self) -> None: self._validate_dimensionality() def _validate_dimensionality(self) -> None: required = Dim.THREE if self.is_3d else Dim.TWO - for bundle in (self.resolver, self.detector): - for field in attrs.fields(type(bundle)): - if field.name in SKIP_DIMENSION: - continue - strategy = getattr(bundle, field.name) - if not attrs.has(type(strategy)): - continue - if not hasattr(strategy, "dim"): - raise AttributeError( - f"{type(bundle).__name__}.{field.name} " - f"({type(strategy).__name__}) is missing required " - f"'dim' attribute" - ) - if strategy.dim not in (required, Dim.BOTH): - raise ValueError( - f"{type(bundle).__name__}.{field.name} " - f"({type(strategy).__name__}) has dim={strategy.dim}, " - f"incompatible with is_3d={self.is_3d}; " - f"expected {required} or {Dim.BOTH}" - ) + + for field in attrs.fields(type(self.resolver)): + if field.name in SKIP_DIMENSION: + continue + strategy = getattr(self.resolver, field.name) + if not attrs.has(type(strategy)): + continue + if not hasattr(strategy, "dim"): + raise AttributeError( + f"Resolver.{field.name} " + f"({type(strategy).__name__}) is missing required " + f"'dim' attribute" + ) + if strategy.dim not in (required, Dim.BOTH): + raise ValueError( + f"Resolver.{field.name} " + f"({type(strategy).__name__}) has dim={strategy.dim}, " + f"incompatible with is_3d={self.is_3d}; " + f"expected {required} or {Dim.BOTH}" + ) __all__ = [ diff --git a/pooltool/evolution/event_based/detect/__init__.py b/pooltool/evolution/event_based/detect/__init__.py index 54942975..c645826a 100644 --- a/pooltool/evolution/event_based/detect/__init__.py +++ b/pooltool/evolution/event_based/detect/__init__.py @@ -1,39 +1,35 @@ from pooltool.evolution.event_based.detect.ball_ball import ( - BallBallDetection, - BallBallDetectionStrategy, + get_next_ball_ball_2d_event, + get_next_ball_ball_3d_event, ) from pooltool.evolution.event_based.detect.ball_cushion import ( - BallCCushionDetection, - BallCCushionDetectionStrategy, - BallLCushionDetection, - BallLCushionDetectionStrategy, + get_next_ball_circular_cushion_2d_event, + get_next_ball_circular_cushion_3d_event, + get_next_ball_linear_cushion_2d_event, + get_next_ball_linear_cushion_3d_event, ) from pooltool.evolution.event_based.detect.ball_pocket import ( - BallPocketDetection, - BallPocketDetectionStrategy, + get_next_ball_pocket_2d_event, + get_next_ball_pocket_3d_event, ) from pooltool.evolution.event_based.detect.ball_table import ( - BallTableDetection, - BallTableDetectionStrategy, + get_next_ball_table_event, ) from pooltool.evolution.event_based.detect.detector import EventDetector from pooltool.evolution.event_based.detect.stick_ball import ( - StickBallDetection, - StickBallDetectionStrategy, + get_next_stick_ball_event, ) __all__ = [ "EventDetector", - "BallBallDetection", - "BallBallDetectionStrategy", - "BallCCushionDetection", - "BallCCushionDetectionStrategy", - "BallLCushionDetection", - "BallLCushionDetectionStrategy", - "BallPocketDetection", - "BallPocketDetectionStrategy", - "BallTableDetection", - "BallTableDetectionStrategy", - "StickBallDetection", - "StickBallDetectionStrategy", + "get_next_ball_ball_2d_event", + "get_next_ball_ball_3d_event", + "get_next_ball_circular_cushion_2d_event", + "get_next_ball_circular_cushion_3d_event", + "get_next_ball_linear_cushion_2d_event", + "get_next_ball_linear_cushion_3d_event", + "get_next_ball_pocket_2d_event", + "get_next_ball_pocket_3d_event", + "get_next_ball_table_event", + "get_next_stick_ball_event", ] diff --git a/pooltool/evolution/event_based/detect/ball_ball.py b/pooltool/evolution/event_based/detect/ball_ball.py index 8813e5c3..2c3c550b 100644 --- a/pooltool/evolution/event_based/detect/ball_ball.py +++ b/pooltool/evolution/event_based/detect/ball_ball.py @@ -1,93 +1,81 @@ from __future__ import annotations from itertools import combinations -from typing import Protocol -import attrs import numpy as np import pooltool.constants as const import pooltool.ptmath as ptmath from pooltool.events import Event, EventType, ball_ball_collision, null_event from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.dimensionality import Dim from pooltool.physics.motion.solve import ball_ball_collision_time from pooltool.system.datatypes import System -class BallBallDetectionStrategy(Protocol): - """Ball-ball detection models must satisfy this protocol.""" - - dim: Dim - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... - - -@attrs.define -class BallBallDetection: - """Detects the next ball-ball collision in the system.""" - - dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: - cache = collision_cache.times.setdefault(EventType.BALL_BALL, {}) - - for ball1, ball2 in combinations(shot.balls.values(), 2): - ball_pair = (ball1.id, ball2.id) - if ball_pair in cache: - continue - - ball1_state = ball1.state - ball1_params = ball1.params - - ball2_state = ball2.state - ball2_params = ball2.params - - if ball1_state.s == const.pocketed or ball2_state.s == const.pocketed: - cache[ball_pair] = np.inf - elif ( - ball1_state.s in const.nontranslating - and ball2_state.s in const.nontranslating - ): - cache[ball_pair] = np.inf - elif ptmath.is_overlapping( - ball1_state.rvw, - ball2_state.rvw, - ball1_params.R, - ball2_params.R, - ): - cache[ball_pair] = shot.t - else: - dtau_E = ball_ball_collision_time( - rvw1=ball1_state.rvw, - rvw2=ball2_state.rvw, - s1=ball1_state.s, - s2=ball2_state.s, - mu1=( - ball1_params.u_s - if ball1_state.s == const.sliding - else ball1_params.u_r - ), - mu2=( - ball2_params.u_s - if ball2_state.s == const.sliding - else ball2_params.u_r - ), - m1=ball1_params.m, - m2=ball2_params.m, - g1=ball1_params.g, - g2=ball2_params.g, - R=ball1_params.R, - ) - cache[ball_pair] = shot.t + dtau_E - - if not cache: - return null_event(np.inf) - - ball_pair = min(cache, key=lambda k: cache[k]) - - return ball_ball_collision( - ball1=shot.balls[ball_pair[0]], - ball2=shot.balls[ball_pair[1]], - time=cache[ball_pair], - ) +def get_next_ball_ball_2d_event(shot: System, collision_cache: CollisionCache) -> Event: + """Detect the next ball-ball collision in 2D mode.""" + cache = collision_cache.times.setdefault(EventType.BALL_BALL, {}) + + for ball1, ball2 in combinations(shot.balls.values(), 2): + ball_pair = (ball1.id, ball2.id) + if ball_pair in cache: + continue + + ball1_state = ball1.state + ball1_params = ball1.params + + ball2_state = ball2.state + ball2_params = ball2.params + + if ball1_state.s == const.pocketed or ball2_state.s == const.pocketed: + cache[ball_pair] = np.inf + elif ( + ball1_state.s in const.nontranslating + and ball2_state.s in const.nontranslating + ): + cache[ball_pair] = np.inf + elif ptmath.is_overlapping( + ball1_state.rvw, + ball2_state.rvw, + ball1_params.R, + ball2_params.R, + ): + cache[ball_pair] = shot.t + else: + dtau_E = ball_ball_collision_time( + rvw1=ball1_state.rvw, + rvw2=ball2_state.rvw, + s1=ball1_state.s, + s2=ball2_state.s, + mu1=( + ball1_params.u_s + if ball1_state.s == const.sliding + else ball1_params.u_r + ), + mu2=( + ball2_params.u_s + if ball2_state.s == const.sliding + else ball2_params.u_r + ), + m1=ball1_params.m, + m2=ball2_params.m, + g1=ball1_params.g, + g2=ball2_params.g, + R=ball1_params.R, + ) + cache[ball_pair] = shot.t + dtau_E + + if not cache: + return null_event(np.inf) + + ball_pair = min(cache, key=lambda k: cache[k]) + + return ball_ball_collision( + ball1=shot.balls[ball_pair[0]], + ball2=shot.balls[ball_pair[1]], + time=cache[ball_pair], + ) + + +def get_next_ball_ball_3d_event(shot: System, collision_cache: CollisionCache) -> Event: + raise NotImplementedError("3D ball-ball detection has not been vendored yet") diff --git a/pooltool/evolution/event_based/detect/ball_cushion.py b/pooltool/evolution/event_based/detect/ball_cushion.py index 63954cc8..d67cbc65 100644 --- a/pooltool/evolution/event_based/detect/ball_cushion.py +++ b/pooltool/evolution/event_based/detect/ball_cushion.py @@ -1,8 +1,5 @@ from __future__ import annotations -from typing import Protocol - -import attrs import numpy as np import pooltool.constants as const @@ -14,7 +11,6 @@ null_event, ) from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.dimensionality import Dim from pooltool.physics.motion.solve import ( ball_circular_cushion_collision_time, ball_linear_cushion_collision_time, @@ -22,117 +18,111 @@ from pooltool.system.datatypes import System -class BallLCushionDetectionStrategy(Protocol): - """Ball-vs-linear-cushion-segment detection models must satisfy this protocol.""" - - dim: Dim - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... - - -@attrs.define -class BallLCushionDetection: - """Detects the next ball-vs-linear-cushion-segment collision in the system.""" - - dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: - if not shot.table.has_linear_cushions: - return null_event(np.inf) - - cache = collision_cache.times.setdefault(EventType.BALL_LINEAR_CUSHION, {}) - - for ball in shot.balls.values(): - state = ball.state - params = ball.params - - for cushion in shot.table.cushion_segments.linear.values(): - obj_ids = (ball.id, cushion.id) - - if obj_ids in cache: - continue - - if ball.state.s in const.nontranslating: - cache[obj_ids] = np.inf - continue - - dtau_E = ball_linear_cushion_collision_time( - rvw=state.rvw, - s=state.s, - lx=cushion.lx, - ly=cushion.ly, - l0=cushion.l0, - p1=cushion.p1, - p2=cushion.p2, - direction=cushion.direction, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) - - cache[obj_ids] = shot.t + dtau_E - - obj_ids = min(cache, key=lambda k: cache[k]) - - return ball_linear_cushion_collision( - ball=shot.balls[obj_ids[0]], - cushion=shot.table.cushion_segments.linear[obj_ids[1]], - time=cache[obj_ids], - ) - - -class BallCCushionDetectionStrategy(Protocol): - """Ball-vs-circular-cushion-segment detection models must satisfy this protocol.""" - - dim: Dim - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... - - -@attrs.define -class BallCCushionDetection: - """Detects the next ball-vs-circular-cushion-segment collision in the system.""" - - dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: - if not shot.table.has_circular_cushions: - return null_event(np.inf) - - cache = collision_cache.times.setdefault(EventType.BALL_CIRCULAR_CUSHION, {}) - - for ball in shot.balls.values(): - state = ball.state - params = ball.params - - for cushion in shot.table.cushion_segments.circular.values(): - obj_ids = (ball.id, cushion.id) - - if obj_ids in cache: - continue - - if ball.state.s in const.nontranslating: - cache[obj_ids] = np.inf - continue - - dtau_E = ball_circular_cushion_collision_time( - rvw=state.rvw, - s=state.s, - a=cushion.a, - b=cushion.b, - r=cushion.radius, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) - cache[obj_ids] = shot.t + dtau_E - - ball_id, cushion_id = min(cache, key=lambda k: cache[k]) - - return ball_circular_cushion_collision( - ball=shot.balls[ball_id], - cushion=shot.table.cushion_segments.circular[cushion_id], - time=cache[(ball_id, cushion_id)], - ) +def get_next_ball_linear_cushion_2d_event( + shot: System, collision_cache: CollisionCache +) -> Event: + """Detect the next ball-vs-linear-cushion collision in 2D mode.""" + if not shot.table.has_linear_cushions: + return null_event(np.inf) + + cache = collision_cache.times.setdefault(EventType.BALL_LINEAR_CUSHION, {}) + + for ball in shot.balls.values(): + state = ball.state + params = ball.params + + for cushion in shot.table.cushion_segments.linear.values(): + obj_ids = (ball.id, cushion.id) + + if obj_ids in cache: + continue + + if ball.state.s in const.nontranslating: + cache[obj_ids] = np.inf + continue + + dtau_E = ball_linear_cushion_collision_time( + rvw=state.rvw, + s=state.s, + lx=cushion.lx, + ly=cushion.ly, + l0=cushion.l0, + p1=cushion.p1, + p2=cushion.p2, + direction=cushion.direction, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) + + cache[obj_ids] = shot.t + dtau_E + + obj_ids = min(cache, key=lambda k: cache[k]) + + return ball_linear_cushion_collision( + ball=shot.balls[obj_ids[0]], + cushion=shot.table.cushion_segments.linear[obj_ids[1]], + time=cache[obj_ids], + ) + + +def get_next_ball_linear_cushion_3d_event( + shot: System, collision_cache: CollisionCache +) -> Event: + raise NotImplementedError( + "3D ball-linear-cushion detection has not been vendored yet" + ) + + +def get_next_ball_circular_cushion_2d_event( + shot: System, collision_cache: CollisionCache +) -> Event: + """Detect the next ball-vs-circular-cushion collision in 2D mode.""" + if not shot.table.has_circular_cushions: + return null_event(np.inf) + + cache = collision_cache.times.setdefault(EventType.BALL_CIRCULAR_CUSHION, {}) + + for ball in shot.balls.values(): + state = ball.state + params = ball.params + + for cushion in shot.table.cushion_segments.circular.values(): + obj_ids = (ball.id, cushion.id) + + if obj_ids in cache: + continue + + if ball.state.s in const.nontranslating: + cache[obj_ids] = np.inf + continue + + dtau_E = ball_circular_cushion_collision_time( + rvw=state.rvw, + s=state.s, + a=cushion.a, + b=cushion.b, + r=cushion.radius, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) + cache[obj_ids] = shot.t + dtau_E + + ball_id, cushion_id = min(cache, key=lambda k: cache[k]) + + return ball_circular_cushion_collision( + ball=shot.balls[ball_id], + cushion=shot.table.cushion_segments.circular[cushion_id], + time=cache[(ball_id, cushion_id)], + ) + + +def get_next_ball_circular_cushion_3d_event( + shot: System, collision_cache: CollisionCache +) -> Event: + raise NotImplementedError( + "3D ball-circular-cushion detection has not been vendored yet" + ) diff --git a/pooltool/evolution/event_based/detect/ball_pocket.py b/pooltool/evolution/event_based/detect/ball_pocket.py index 66ebcf4a..1424f873 100644 --- a/pooltool/evolution/event_based/detect/ball_pocket.py +++ b/pooltool/evolution/event_based/detect/ball_pocket.py @@ -1,69 +1,60 @@ from __future__ import annotations -from typing import Protocol - -import attrs import numpy as np import pooltool.constants as const from pooltool.events import Event, EventType, ball_pocket_collision, null_event from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.dimensionality import Dim from pooltool.physics.motion.solve import ball_pocket_collision_time from pooltool.system.datatypes import System -class BallPocketDetectionStrategy(Protocol): - """Ball-pocket detection models must satisfy this protocol.""" - - dim: Dim - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... - - -@attrs.define -class BallPocketDetection: - """Detects the next ball-pocket collision in the system.""" +def get_next_ball_pocket_2d_event( + shot: System, collision_cache: CollisionCache +) -> Event: + """Detect the next ball-pocket collision in 2D mode.""" + if not shot.table.has_pockets: + return null_event(np.inf) - dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) + cache = collision_cache.times.setdefault(EventType.BALL_POCKET, {}) - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: - if not shot.table.has_pockets: - return null_event(np.inf) + for ball in shot.balls.values(): + state = ball.state + params = ball.params - cache = collision_cache.times.setdefault(EventType.BALL_POCKET, {}) + for pocket in shot.table.pockets.values(): + obj_ids = (ball.id, pocket.id) - for ball in shot.balls.values(): - state = ball.state - params = ball.params + if obj_ids in cache: + continue - for pocket in shot.table.pockets.values(): - obj_ids = (ball.id, pocket.id) + if ball.state.s in const.nontranslating: + cache[obj_ids] = np.inf + continue - if obj_ids in cache: - continue + dtau_E = ball_pocket_collision_time( + rvw=state.rvw, + s=state.s, + a=pocket.a, + b=pocket.b, + r=pocket.radius, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) + cache[obj_ids] = shot.t + dtau_E - if ball.state.s in const.nontranslating: - cache[obj_ids] = np.inf - continue + ball_id, pocket_id = min(cache, key=lambda k: cache[k]) - dtau_E = ball_pocket_collision_time( - rvw=state.rvw, - s=state.s, - a=pocket.a, - b=pocket.b, - r=pocket.radius, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) - cache[obj_ids] = shot.t + dtau_E + return ball_pocket_collision( + ball=shot.balls[ball_id], + pocket=shot.table.pockets[pocket_id], + time=cache[(ball_id, pocket_id)], + ) - ball_id, pocket_id = min(cache, key=lambda k: cache[k]) - return ball_pocket_collision( - ball=shot.balls[ball_id], - pocket=shot.table.pockets[pocket_id], - time=cache[(ball_id, pocket_id)], - ) +def get_next_ball_pocket_3d_event( + shot: System, collision_cache: CollisionCache +) -> Event: + raise NotImplementedError("3D ball-pocket detection has not been vendored yet") diff --git a/pooltool/evolution/event_based/detect/ball_table.py b/pooltool/evolution/event_based/detect/ball_table.py index cba3ae2f..df74fb54 100644 --- a/pooltool/evolution/event_based/detect/ball_table.py +++ b/pooltool/evolution/event_based/detect/ball_table.py @@ -1,47 +1,34 @@ from __future__ import annotations -from typing import Protocol - -import attrs - from pooltool.events import Event, EventType, ball_table_collision from pooltool.evolution.event_based.cache import CollisionCache from pooltool.physics.motion.solve import ball_table_collision_time from pooltool.system.datatypes import System -class BallTableDetectionStrategy(Protocol): - """Ball-table detection models must satisfy this protocol. +def get_next_ball_table_event(shot: System, collision_cache: CollisionCache) -> Event: + """Detect the next ball-table collision (airborne ball landing). - Unlike the other detection-strategy protocols, this one does not declare a - ``dim`` attribute. + Only invoked when ``EventDetector.is_3d`` is True — ball-table events are + a 3D-only concept. """ + cache = collision_cache.times.setdefault(EventType.BALL_TABLE, {}) + + for ball in shot.balls.values(): + obj_ids = (ball.id,) + if obj_ids in cache: + continue + dtau_E = ball_table_collision_time( + rvw=ball.state.rvw, + s=ball.state.s, + g=ball.params.g, + R=ball.params.R, + ) + cache[obj_ids] = shot.t + dtau_E - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... - - -@attrs.define -class BallTableDetection: - """Detects the next ball-table collision in the system.""" - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: - cache = collision_cache.times.setdefault(EventType.BALL_TABLE, {}) - - for ball in shot.balls.values(): - obj_ids = (ball.id,) - if obj_ids in cache: - continue - dtau_E = ball_table_collision_time( - rvw=ball.state.rvw, - s=ball.state.s, - g=ball.params.g, - R=ball.params.R, - ) - cache[obj_ids] = shot.t + dtau_E - - obj_ids = min(cache, key=lambda k: cache[k]) + obj_ids = min(cache, key=lambda k: cache[k]) - return ball_table_collision( - ball=shot.balls[obj_ids[0]], - time=cache[obj_ids], - ) + return ball_table_collision( + ball=shot.balls[obj_ids[0]], + time=cache[obj_ids], + ) diff --git a/pooltool/evolution/event_based/detect/detector.py b/pooltool/evolution/event_based/detect/detector.py index ef1cf56e..57f0e655 100644 --- a/pooltool/evolution/event_based/detect/detector.py +++ b/pooltool/evolution/event_based/detect/detector.py @@ -6,14 +6,26 @@ import pooltool.ptmath as ptmath from pooltool.events import Event, EventType, null_event from pooltool.evolution.event_based.cache import CollisionCache, TransitionCache -from pooltool.evolution.event_based.detect.ball_ball import BallBallDetection +from pooltool.evolution.event_based.detect.ball_ball import ( + get_next_ball_ball_2d_event, + get_next_ball_ball_3d_event, +) from pooltool.evolution.event_based.detect.ball_cushion import ( - BallCCushionDetection, - BallLCushionDetection, + get_next_ball_circular_cushion_2d_event, + get_next_ball_circular_cushion_3d_event, + get_next_ball_linear_cushion_2d_event, + get_next_ball_linear_cushion_3d_event, +) +from pooltool.evolution.event_based.detect.ball_pocket import ( + get_next_ball_pocket_2d_event, + get_next_ball_pocket_3d_event, +) +from pooltool.evolution.event_based.detect.ball_table import ( + get_next_ball_table_event, +) +from pooltool.evolution.event_based.detect.stick_ball import ( + get_next_stick_ball_event, ) -from pooltool.evolution.event_based.detect.ball_pocket import BallPocketDetection -from pooltool.evolution.event_based.detect.ball_table import BallTableDetection -from pooltool.evolution.event_based.detect.stick_ball import StickBallDetection from pooltool.physics.utils import get_ball_energy from pooltool.system.datatypes import System @@ -86,42 +98,18 @@ def _get_event_priority(event: Event, shot: System) -> tuple[int, float]: @attrs.define class EventDetector: - """Bundles per-event-type detection strategies. + """Orchestrates per-event-type detection. - Fields are typed as the concrete strategy class rather than the corresponding - ``*DetectionStrategy`` protocol. This keeps cattrs structuring trivial — cattrs - can structure into a concrete attrs class natively but cannot resolve a Protocol - without a discriminator. When a second implementation is added for a given event - type, this field type should be widened (e.g. to a union of concrete classes, or - to the protocol with a registry + structure hook that dispatches on a tag, in - the style of :class:`pooltool.physics.Resolver`). + The 2D-vs-3D branching for forked event types happens here, in ``get_next_event``. + The per-event-type ``get_next_*_event`` functions are each mode-pure. Attributes: - stick_ball: - Strategy for detecting the next stick-ball collision. - ball_ball: - Strategy for detecting the next ball-ball collision. - ball_linear_cushion: - Strategy for detecting the next ball-vs-linear-cushion-segment collision. - ball_circular_cushion: - Strategy for detecting the next ball-vs-circular-cushion-segment collision. - ball_pocket: - Strategy for detecting the next ball-pocket collision. - ball_table: - Strategy for detecting the next ball-table collision (airborne ball - landing on the table surface). + is_3d: + Whether to dispatch to 3D detection variants. Set by ``SimulationEngine`` at + construction. """ - stick_ball: StickBallDetection = attrs.field(factory=StickBallDetection) - ball_ball: BallBallDetection = attrs.field(factory=BallBallDetection) - ball_linear_cushion: BallLCushionDetection = attrs.field( - factory=BallLCushionDetection - ) - ball_circular_cushion: BallCCushionDetection = attrs.field( - factory=BallCCushionDetection - ) - ball_pocket: BallPocketDetection = attrs.field(factory=BallPocketDetection) - ball_table: BallTableDetection = attrs.field(factory=BallTableDetection) + is_3d: bool = False @classmethod def default(cls) -> EventDetector: @@ -146,20 +134,36 @@ def get_next_event( candidates: list[Event] = [] - # Stick-ball collisions only occur at t=0 (shot initiation), so we skip this - # check after the first timestep as an optimization. Other collision types are - # always checked because they can occur at any time during simulation. Note: - # even at t=0, we still call the remaining detection strategies to fully - # populate the collision cache, which is needed by debug/introspection tools. + # Stick-ball collisions only occur at t=0 (shot initiation), so we skip + # this check after the first timestep as an optimization. Other collision + # types are always checked because they can occur at any time during + # simulation. Note: even at t=0, we still call the remaining detection + # functions to fully populate the collision cache, which is needed by + # debug/introspection tools. if shot.t == 0: - candidates.append(self.stick_ball.get_next(shot, collision_cache)) + candidates.append(get_next_stick_ball_event(shot, collision_cache)) candidates.append(transition_cache.get_next()) - candidates.append(self.ball_ball.get_next(shot, collision_cache)) - candidates.append(self.ball_circular_cushion.get_next(shot, collision_cache)) - candidates.append(self.ball_linear_cushion.get_next(shot, collision_cache)) - candidates.append(self.ball_pocket.get_next(shot, collision_cache)) - candidates.append(self.ball_table.get_next(shot, collision_cache)) + + if self.is_3d: + candidates.append(get_next_ball_ball_3d_event(shot, collision_cache)) + candidates.append( + get_next_ball_circular_cushion_3d_event(shot, collision_cache) + ) + candidates.append( + get_next_ball_linear_cushion_3d_event(shot, collision_cache) + ) + candidates.append(get_next_ball_pocket_3d_event(shot, collision_cache)) + candidates.append(get_next_ball_table_event(shot, collision_cache)) + else: + candidates.append(get_next_ball_ball_2d_event(shot, collision_cache)) + candidates.append( + get_next_ball_circular_cushion_2d_event(shot, collision_cache) + ) + candidates.append( + get_next_ball_linear_cushion_2d_event(shot, collision_cache) + ) + candidates.append(get_next_ball_pocket_2d_event(shot, collision_cache)) min_time = min(event.time for event in candidates) diff --git a/pooltool/evolution/event_based/detect/stick_ball.py b/pooltool/evolution/event_based/detect/stick_ball.py index b9c65603..c9ec5c42 100644 --- a/pooltool/evolution/event_based/detect/stick_ball.py +++ b/pooltool/evolution/event_based/detect/stick_ball.py @@ -1,54 +1,37 @@ from __future__ import annotations -from typing import Protocol - -import attrs import numpy as np from pooltool.events import Event, EventType, stick_ball_collision from pooltool.evolution.event_based._utils import _system_has_energy from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.physics.dimensionality import Dim from pooltool.system.datatypes import System -class StickBallDetectionStrategy(Protocol): - """Stick-ball detection models must satisfy this protocol.""" - - dim: Dim - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... +def get_next_stick_ball_event(shot: System, collision_cache: CollisionCache) -> Event: + """Detect the next stick-ball collision. - -@attrs.define -class StickBallDetection: - """Stick-ball collision detection. - - Stick-ball events fire only at t=0, when the system is at rest and a cue strike is - queued (V0 > 0). + Stick-ball events fire only at t=0, when the system is at rest and a cue + strike is queued (V0 > 0). """ + cache = collision_cache.times.setdefault(EventType.STICK_BALL, {}) - dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) - - def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: - cache = collision_cache.times.setdefault(EventType.STICK_BALL, {}) - - obj_ids = (shot.cue.id, shot.cue.cue_ball_id) - - if obj_ids in cache: - return stick_ball_collision( - stick=shot.cue, - ball=shot.balls[shot.cue.cue_ball_id], - time=cache[obj_ids], - ) - - if shot.t == 0 and not _system_has_energy(shot) and shot.cue.V0 > 0: - cache[obj_ids] = 0.0 - else: - cache[obj_ids] = np.inf + obj_ids = (shot.cue.id, shot.cue.cue_ball_id) + if obj_ids in cache: return stick_ball_collision( stick=shot.cue, ball=shot.balls[shot.cue.cue_ball_id], time=cache[obj_ids], ) + + if shot.t == 0 and not _system_has_energy(shot) and shot.cue.V0 > 0: + cache[obj_ids] = 0.0 + else: + cache[obj_ids] = np.inf + + return stick_ball_collision( + stick=shot.cue, + ball=shot.balls[shot.cue.cue_ball_id], + time=cache[obj_ids], + ) diff --git a/pooltool/evolution/event_based/introspection.py b/pooltool/evolution/event_based/introspection.py index a641c72e..0bd7fb14 100644 --- a/pooltool/evolution/event_based/introspection.py +++ b/pooltool/evolution/event_based/introspection.py @@ -48,6 +48,11 @@ def _get_collision_events_from_cache( system: System, cache: CollisionCache ) -> list[Event]: + # TODO: BALL_TABLE entries in the cache are not reconstructed here. In 2D + # mode this is harmless (the detector doesn't populate the BALL_TABLE + # bucket). When 3D activation lands, prospective BALL_TABLE events will be + # silently missed from get_prospective_events(). Add a branch that builds + # ball_table_collision(ball, time) from each (ball_id,) key. events = [] if EventType.BALL_BALL in cache.times: diff --git a/tests/evolution/event_based/test_ball_table.py b/tests/evolution/event_based/test_ball_table.py index f3a73cc5..b07bddb5 100644 --- a/tests/evolution/event_based/test_ball_table.py +++ b/tests/evolution/event_based/test_ball_table.py @@ -4,7 +4,9 @@ import pooltool.constants as const from pooltool.events import EventType from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.evolution.event_based.detect.ball_table import BallTableDetection +from pooltool.evolution.event_based.detect.ball_table import ( + get_next_ball_table_event, +) from pooltool.physics.utils import get_airborne_time from pooltool.system.datatypes import System @@ -16,7 +18,7 @@ def system() -> System: def test_no_airborne_balls_returns_inf_time(system: System): """In a default 2D scene no ball is airborne, so the emitted event has time=inf.""" - event = BallTableDetection().get_next(system, CollisionCache()) + event = get_next_ball_table_event(system, CollisionCache()) assert event.event_type == EventType.BALL_TABLE assert event.time == np.inf @@ -28,7 +30,7 @@ def test_airborne_ball_returns_finite_time(system: System): ball.state.rvw[1, 2] = 0.0 ball.state.s = const.airborne - event = BallTableDetection().get_next(system, CollisionCache()) + event = get_next_ball_table_event(system, CollisionCache()) expected = get_airborne_time(ball.state.rvw, ball.params.R, ball.params.g) assert event.event_type == EventType.BALL_TABLE @@ -50,7 +52,7 @@ def test_returns_soonest_ball(system: System): low.state.rvw[1, 2] = 0.0 low.state.s = const.airborne - event = BallTableDetection().get_next(system, CollisionCache()) + event = get_next_ball_table_event(system, CollisionCache()) assert event.event_type == EventType.BALL_TABLE assert event.ids[0] == low.id diff --git a/tests/evolution/event_based/test_simulate.py b/tests/evolution/event_based/test_simulate.py index ea627052..3baf6bcb 100644 --- a/tests/evolution/event_based/test_simulate.py +++ b/tests/evolution/event_based/test_simulate.py @@ -9,7 +9,10 @@ from pooltool.events import EventType, ball_ball_collision, ball_pocket_collision from pooltool.evolution.event_based._utils import _system_has_energy from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.evolution.event_based.detect import BallBallDetection, EventDetector +from pooltool.evolution.event_based.detect import ( + EventDetector, + get_next_ball_ball_2d_event, +) from pooltool.evolution.event_based.simulate import simulate from pooltool.objects import Ball, BilliardTableSpecs, Cue, Table from pooltool.objects.ball.params import BallParams @@ -449,7 +452,7 @@ def test_ball_ball_collision_for_intersecting_balls(): _assert_rolling(system.balls["cue"].state.rvw, system.balls["cue"].params.R) assert _DETECTOR.get_next_event(system).event_type == EventType.BALL_BALL - collision_event = BallBallDetection().get_next(system, CollisionCache()) + collision_event = get_next_ball_ball_2d_event(system, CollisionCache()) assert collision_event.time != np.inf assert collision_event.time == 0 diff --git a/tests/evolution/test_engine.py b/tests/evolution/test_engine.py index 03076132..83c9173e 100644 --- a/tests/evolution/test_engine.py +++ b/tests/evolution/test_engine.py @@ -2,31 +2,24 @@ import pytest from pooltool.evolution.engine import SimulationEngine -from pooltool.evolution.event_based.detect import EventDetector from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.resolver import Resolver -def _patch_all_dims( - resolver: Resolver, - detector: EventDetector, - dim_value: Dim, -) -> None: - """Set every strategy's dim to dim_value, in place.""" - for bundle in (resolver, detector): - for field in attrs.fields(type(bundle)): - strategy = getattr(bundle, field.name) - if hasattr(strategy, "dim"): - strategy.dim = dim_value +def _patch_resolver_dims(resolver: Resolver, dim_value: Dim) -> None: + """Set every resolver strategy's dim to dim_value, in place.""" + for field in attrs.fields(type(resolver)): + strategy = getattr(resolver, field.name) + if hasattr(strategy, "dim"): + strategy.dim = dim_value @pytest.fixture def engine_3d() -> SimulationEngine: - """A SimulationEngine constructed with every strategy patched to Dim.THREE.""" + """A SimulationEngine constructed with every resolver strategy patched to Dim.THREE.""" resolver = SimulationEngine().resolver - detector = SimulationEngine().detector - _patch_all_dims(resolver, detector, Dim.THREE) - return SimulationEngine(resolver=resolver, detector=detector, is_3d=True) + _patch_resolver_dims(resolver, Dim.THREE) + return SimulationEngine(resolver=resolver, is_3d=True) def test_default_engine_constructs(): @@ -36,6 +29,7 @@ def test_default_engine_constructs(): def test_3d_engine_constructs(engine_3d: SimulationEngine): assert engine_3d.is_3d is True + assert engine_3d.detector.is_3d is True def test_3d_engine_with_all_2d_strategies_raises(): @@ -44,7 +38,7 @@ def test_3d_engine_with_all_2d_strategies_raises(): def test_validation_error_identifies_offending_strategy(): - with pytest.raises(ValueError, match=r"Resolver\.|EventDetector\."): + with pytest.raises(ValueError, match=r"Resolver\."): SimulationEngine(is_3d=True) @@ -65,11 +59,7 @@ def test_3d_engine_rejects_one_dim_two_strategy(engine_3d: SimulationEngine): engine_3d.resolver.ball_ball.dim = Dim.TWO with pytest.raises(ValueError, match=r"ball_ball.*incompatible with is_3d=True"): - SimulationEngine( - resolver=engine_3d.resolver, - detector=engine_3d.detector, - is_3d=True, - ) + SimulationEngine(resolver=engine_3d.resolver, is_3d=True) def test_dim_both_strategy_accepted_in_2d(): @@ -80,21 +70,20 @@ def test_dim_both_strategy_accepted_in_2d(): def test_dim_both_strategy_accepted_in_3d(engine_3d: SimulationEngine): engine_3d.resolver.ball_ball.dim = Dim.BOTH - SimulationEngine( - resolver=engine_3d.resolver, - detector=engine_3d.detector, - is_3d=True, - ) + SimulationEngine(resolver=engine_3d.resolver, is_3d=True) def test_ball_table_exempt_from_dim_validation(): - """ball_table strategies don't carry a `dim` attribute. The validator - skips both Resolver.ball_table and EventDetector.ball_table fields in - either mode via SKIP_DIMENSION.""" + """Ball-table resolver strategies don't carry a `dim` attribute. The + validator skips this field in either mode via SKIP_DIMENSION.""" resolver = SimulationEngine().resolver - detector = SimulationEngine().detector assert not hasattr(resolver.ball_table, "dim") - assert not hasattr(detector.ball_table, "dim") - SimulationEngine(resolver=resolver, detector=detector, is_3d=False) + SimulationEngine(resolver=resolver, is_3d=False) + + +def test_detector_is_not_constructor_passable(): + """``detector`` is init=False on SimulationEngine.""" + with pytest.raises(TypeError): + SimulationEngine(detector="anything") # type: ignore[call-arg]