diff --git a/docs/examples/30_degree_rule.ipynb b/docs/examples/30_degree_rule.ipynb index 5f62e5f3..15949549 100644 --- a/docs/examples/30_degree_rule.ipynb +++ b/docs/examples/30_degree_rule.ipynb @@ -211,7 +211,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": null, "id": "d29a2bda", "metadata": { "execution": { @@ -221,25 +221,8 @@ "shell.execute_reply": "2026-03-15T06:51:53.253495Z" } }, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "System simulated: True\n" - ] - } - ], - "source": [ - "# Create a default physics engine, then overwrite ball-ball model with frictionless, elastic model.\n", - "engine = pt.physics.PhysicsEngine()\n", - "engine.resolver.ball_ball = pt.physics.ball_ball_models[pt.physics.BallBallModel.FRICTIONLESS_ELASTIC]()\n", - "\n", - "pt.simulate(system, engine=engine, inplace=True)\n", - "pt.continuize(system, dt=0.01, inplace=True)\n", - "\n", - "print(f\"System simulated: {system.simulated}\")" - ] + "outputs": [], + "source": "# Create a default simulation engine, then overwrite ball-ball model with frictionless, elastic model.\nengine = pt.evolution.SimulationEngine()\nengine.resolver.ball_ball = pt.physics.ball_ball_models[pt.physics.BallBallModel.FRICTIONLESS_ELASTIC]()\n\npt.simulate(system, engine=engine, inplace=True)\npt.continuize(system, dt=0.01, inplace=True)\n\nprint(f\"System simulated: {system.simulated}\")" }, { "cell_type": "markdown", @@ -1371,4 +1354,4 @@ }, "nbformat": 4, "nbformat_minor": 5 -} +} \ No newline at end of file diff --git a/docs/include_exclude.py b/docs/include_exclude.py index 3ef31beb..b6f3f514 100644 --- a/docs/include_exclude.py +++ b/docs/include_exclude.py @@ -29,6 +29,7 @@ "pooltool.ruleset.snooker", # API: pooltool.physics "pooltool.physics.evolve", + "pooltool.physics.motion", "pooltool.physics.resolve", ], "module": [ @@ -56,8 +57,10 @@ "pooltool.system.datatypes", # API: pooltool.game "pooltool.game.datatypes", + # API: pooltool.evolution + "pooltool.evolution.engine", # API: pooltool.physics - "pooltool.physics.engine", + "pooltool.physics.utils", # API: pooltool.physics.resolve "pooltool.physics.resolve.models", "pooltool.physics.resolve.resolver", diff --git a/docs/meta/developer_guide.md b/docs/meta/developer_guide.md index 16a6a6b2..91473ed6 100644 --- a/docs/meta/developer_guide.md +++ b/docs/meta/developer_guide.md @@ -163,7 +163,7 @@ To add an inlaid dropdown signature: Add optional text here ```{eval-rst} -.. autoclass:: pooltool.physics.PhysicsEngine +.. autoclass:: pooltool.evolution.SimulationEngine :noindex: ``` ::: diff --git a/docs/resources/custom_physics.md b/docs/resources/custom_physics.md index aefe37cf..2fae0b50 100644 --- a/docs/resources/custom_physics.md +++ b/docs/resources/custom_physics.md @@ -181,14 +181,28 @@ class CoreBallLCushionCollision(ABC): With these, we can draft our template by following the example code: [5ecddb2c0c010e3f058e666fd5a7fc1f10117638](https://github.com/ekiefl/pooltool/commit/5ecddb2c0c010e3f058e666fd5a7fc1f10117638) -It's just missing two things. First, the class must be an attrs class. Pooltool requires that all the resolver models are [attrs](https://www.attrs.org/en/stable/) classes. If you've never used attrs before, stick close to the example and you'll have no problems. Second, the class must have an attribute called `model`, and this attribute should be the Enum member that you had previously added to `pooltool/physics/resolve/models.py`. +It's just missing three things. First, the class must be an attrs class. Pooltool requires that all the resolver models are [attrs](https://www.attrs.org/en/stable/) classes. If you've never used attrs before, stick close to the example and you'll have no problems. Second, the class must have an attribute called `model`, and this attribute should be the Enum member that you had previously added to `pooltool/physics/resolve/models.py`. Third, the class must have an attribute called `dim` that declares the simulation dimensionality your model supports. -To apply these changes, follow the example code: [9c12a6efa2b9d201d8cedfc75b1a83b8134dd7ec](https://github.com/ekiefl/pooltool/commit/9c12a6efa2b9d201d8cedfc75b1a83b8134dd7ec). Since I added `UNREALISTIC` to the `BallLCushionModel` model, I added the following attribute to my class: +To apply the `model` requirement, follow the example code: [9c12a6efa2b9d201d8cedfc75b1a83b8134dd7ec](https://github.com/ekiefl/pooltool/commit/9c12a6efa2b9d201d8cedfc75b1a83b8134dd7ec). Since I added `UNREALISTIC` to the `BallLCushionModel` model, I added the following attribute to my class: ```python model: BallLCushionModel = attrs.field(default=BallLCushionModel.UNREALISTIC, init=False, repr=False) ``` +For the `dim` requirement, add a `dim` field declaring whether your model is safe in 2D, 3D, or both: + +```python +from pooltool.physics.dimensionality import Dim + +dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) +``` + +`Dim` is a capability declaration consumed at engine construction: + +- `Dim.TWO` — your model is safe only when [](#pooltool.evolution.engine.SimulationEngine)'s `is_3d` is `False`. Use this if your model assumes balls are on the table surface (z=R, vz=0). +- `Dim.THREE` — your model is safe only when `is_3d` is `True`. Use this if your model assumes 3D ball state (e.g. it produces or handles airborne balls). +- `Dim.BOTH` — your model behaves identically in either mode. This is a strong promise: a `Dim.BOTH` model doesn't care or know whether it's handling a 2D/3D simulation. + Great, now we are done with the boilerplate code. But `resolve` currently does *nothing*, it just returns what is handed to it. Let's change that. #### Implement the logic diff --git a/pooltool/ani/hud.py b/pooltool/ani/hud.py index 38a47bab..8d4777c1 100644 --- a/pooltool/ani/hud.py +++ b/pooltool/ani/hud.py @@ -20,7 +20,7 @@ from pooltool.ani.globals import Global from pooltool.objects.ball.datatypes import Ball, BallParams from pooltool.objects.cue.datatypes import Cue, CueSpecs -from pooltool.ptmath.utils import tip_center_offset +from pooltool.physics.utils import tip_center_offset from pooltool.ruleset.datatypes import BallInHandOptions from pooltool.utils import panda_path from pooltool.utils.strenum import StrEnum, auto diff --git a/pooltool/ani/modes/aim.py b/pooltool/ani/modes/aim.py index a0ff6778..477f4fc1 100644 --- a/pooltool/ani/modes/aim.py +++ b/pooltool/ani/modes/aim.py @@ -22,7 +22,8 @@ from pooltool.ani.mouse import MouseMode, mouse from pooltool.ani.scene import visual from pooltool.config import settings -from pooltool.ptmath.utils import norm2d, tip_contact_offset +from pooltool.physics.utils import tip_contact_offset +from pooltool.ptmath.utils import norm2d from pooltool.system.datatypes import multisystem diff --git a/pooltool/ani/modes/view.py b/pooltool/ani/modes/view.py index e345036e..31411a0f 100755 --- a/pooltool/ani/modes/view.py +++ b/pooltool/ani/modes/view.py @@ -21,7 +21,8 @@ from pooltool.ani.mouse import MouseMode, mouse from pooltool.ani.scene import visual from pooltool.config import settings -from pooltool.ptmath.utils import norm2d, tip_contact_offset +from pooltool.physics.utils import tip_contact_offset +from pooltool.ptmath.utils import norm2d from pooltool.system.datatypes import multisystem diff --git a/pooltool/evolution/__init__.py b/pooltool/evolution/__init__.py index 1791f112..3a98993b 100644 --- a/pooltool/evolution/__init__.py +++ b/pooltool/evolution/__init__.py @@ -1,9 +1,11 @@ """Shot evolution algorithm routines and utilities""" from pooltool.evolution.continuous import continuize, interpolate_ball_states +from pooltool.evolution.engine import SimulationEngine from pooltool.evolution.event_based.simulate import simulate __all__ = [ + "SimulationEngine", "continuize", "simulate", "interpolate_ball_states", diff --git a/pooltool/evolution/engine.py b/pooltool/evolution/engine.py new file mode 100644 index 00000000..09f8ec13 --- /dev/null +++ b/pooltool/evolution/engine.py @@ -0,0 +1,62 @@ +"""The simulation engine of pooltool""" + +from __future__ import annotations + +import attrs + +from pooltool.evolution.event_based.detect import EventDetector +from pooltool.physics.dimensionality import Dim +from pooltool.physics.resolve import Resolver + + +@attrs.define +class SimulationEngine: + """A pluggable bundle of strategies used by the simulator. + + Holds the strategies that define how a simulation is carried out: how events are + detected and how they are resolved. The simulator is handed an instance of this + class and routes work to its components. + + Attributes: + resolver: + The strategy responsible for resolving events. + detector: + The strategy responsible for detecting the next event. + is_3d: + Whether the simulation supports the airborne motion state and ball-table + events. Validated at construction against the dimensionality capability + (``dim``) of every bundled strategy in ``resolver`` and ``detector``. + """ + + resolver: Resolver = attrs.field(factory=Resolver.default) + detector: EventDetector = attrs.field(factory=EventDetector.default) + is_3d: bool = False + + def __attrs_post_init__(self) -> None: + self._validate_dimensionality() + + def _validate_dimensionality(self) -> None: + required = Dim.THREE if self.is_3d else Dim.TWO + for bundle in (self.resolver, self.detector): + for field in attrs.fields(type(bundle)): + strategy = getattr(bundle, field.name) + if not attrs.has(type(strategy)): + continue + if not hasattr(strategy, "dim"): + raise AttributeError( + f"{type(bundle).__name__}.{field.name} " + f"({type(strategy).__name__}) is missing required " + f"'dim' attribute" + ) + if strategy.dim not in (required, Dim.BOTH): + raise ValueError( + f"{type(bundle).__name__}.{field.name} " + f"({type(strategy).__name__}) has dim={strategy.dim}, " + f"incompatible with is_3d={self.is_3d}; " + f"expected {required} or {Dim.BOTH}" + ) + + +__all__ = [ + "SimulationEngine", +] diff --git a/pooltool/evolution/event_based/_utils.py b/pooltool/evolution/event_based/_utils.py new file mode 100644 index 00000000..851fa773 --- /dev/null +++ b/pooltool/evolution/event_based/_utils.py @@ -0,0 +1,21 @@ +from __future__ import annotations + +from pooltool.physics.utils import get_ball_energy +from pooltool.system.datatypes import System + + +def _system_has_energy(system: System) -> bool: + """Return True if any ball in the system has kinetic energy. + + Cue energy (e.g. ``system.cue.V0 > 0``) does not count. + """ + return any( + bool( + get_ball_energy( + ball.state.rvw, + ball.params.R, + ball.params.m, + ) + ) + for ball in system.balls.values() + ) diff --git a/pooltool/evolution/event_based/cache.py b/pooltool/evolution/event_based/cache.py index 1123d9d5..b0dba780 100644 --- a/pooltool/evolution/event_based/cache.py +++ b/pooltool/evolution/event_based/cache.py @@ -6,7 +6,6 @@ import numpy as np import pooltool.constants as const -import pooltool.ptmath as ptmath from pooltool.events import ( AgentType, Event, @@ -19,6 +18,7 @@ ) from pooltool.events.utils import event_type_to_ball_indices from pooltool.objects.ball.datatypes import Ball +from pooltool.physics.utils import get_roll_time, get_slide_time, get_spin_time from pooltool.serialize import SerializeFormat, conversion from pooltool.system.datatypes import System @@ -75,18 +75,16 @@ def _next_transition(ball: Ball) -> Event: return null_event(time=np.inf) elif ball.state.s == const.spinning: - dtau_E = ptmath.get_spin_time( + dtau_E = get_spin_time( ball.state.rvw, ball.params.R, ball.params.u_sp, ball.params.g ) return spinning_stationary_transition(ball, ball.state.t + dtau_E) elif ball.state.s == const.rolling: - dtau_E_spin = ptmath.get_spin_time( + dtau_E_spin = get_spin_time( ball.state.rvw, ball.params.R, ball.params.u_sp, ball.params.g ) - dtau_E_roll = ptmath.get_roll_time( - ball.state.rvw, ball.params.u_r, ball.params.g - ) + dtau_E_roll = get_roll_time(ball.state.rvw, ball.params.u_r, ball.params.g) if dtau_E_spin > dtau_E_roll: return rolling_spinning_transition(ball, ball.state.t + dtau_E_roll) @@ -94,7 +92,7 @@ def _next_transition(ball: Ball) -> Event: return rolling_stationary_transition(ball, ball.state.t + dtau_E_roll) elif ball.state.s == const.sliding: - dtau_E = ptmath.get_slide_time( + dtau_E = get_slide_time( ball.state.rvw, ball.params.R, ball.params.u_s, ball.params.g ) return sliding_rolling_transition(ball, ball.state.t + dtau_E) diff --git a/pooltool/evolution/event_based/detect/__init__.py b/pooltool/evolution/event_based/detect/__init__.py new file mode 100644 index 00000000..310131f5 --- /dev/null +++ b/pooltool/evolution/event_based/detect/__init__.py @@ -0,0 +1,33 @@ +from pooltool.evolution.event_based.detect.ball_ball import ( + BallBallDetection, + BallBallDetectionStrategy, +) +from pooltool.evolution.event_based.detect.ball_cushion import ( + BallCCushionDetection, + BallCCushionDetectionStrategy, + BallLCushionDetection, + BallLCushionDetectionStrategy, +) +from pooltool.evolution.event_based.detect.ball_pocket import ( + BallPocketDetection, + BallPocketDetectionStrategy, +) +from pooltool.evolution.event_based.detect.detector import EventDetector +from pooltool.evolution.event_based.detect.stick_ball import ( + StickBallDetection, + StickBallDetectionStrategy, +) + +__all__ = [ + "EventDetector", + "BallBallDetection", + "BallBallDetectionStrategy", + "BallCCushionDetection", + "BallCCushionDetectionStrategy", + "BallLCushionDetection", + "BallLCushionDetectionStrategy", + "BallPocketDetection", + "BallPocketDetectionStrategy", + "StickBallDetection", + "StickBallDetectionStrategy", +] diff --git a/pooltool/evolution/event_based/detect/ball_ball.py b/pooltool/evolution/event_based/detect/ball_ball.py new file mode 100644 index 00000000..8813e5c3 --- /dev/null +++ b/pooltool/evolution/event_based/detect/ball_ball.py @@ -0,0 +1,93 @@ +from __future__ import annotations + +from itertools import combinations +from typing import Protocol + +import attrs +import numpy as np + +import pooltool.constants as const +import pooltool.ptmath as ptmath +from pooltool.events import Event, EventType, ball_ball_collision, null_event +from pooltool.evolution.event_based.cache import CollisionCache +from pooltool.physics.dimensionality import Dim +from pooltool.physics.motion.solve import ball_ball_collision_time +from pooltool.system.datatypes import System + + +class BallBallDetectionStrategy(Protocol): + """Ball-ball detection models must satisfy this protocol.""" + + dim: Dim + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... + + +@attrs.define +class BallBallDetection: + """Detects the next ball-ball collision in the system.""" + + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: + cache = collision_cache.times.setdefault(EventType.BALL_BALL, {}) + + for ball1, ball2 in combinations(shot.balls.values(), 2): + ball_pair = (ball1.id, ball2.id) + if ball_pair in cache: + continue + + ball1_state = ball1.state + ball1_params = ball1.params + + ball2_state = ball2.state + ball2_params = ball2.params + + if ball1_state.s == const.pocketed or ball2_state.s == const.pocketed: + cache[ball_pair] = np.inf + elif ( + ball1_state.s in const.nontranslating + and ball2_state.s in const.nontranslating + ): + cache[ball_pair] = np.inf + elif ptmath.is_overlapping( + ball1_state.rvw, + ball2_state.rvw, + ball1_params.R, + ball2_params.R, + ): + cache[ball_pair] = shot.t + else: + dtau_E = ball_ball_collision_time( + rvw1=ball1_state.rvw, + rvw2=ball2_state.rvw, + s1=ball1_state.s, + s2=ball2_state.s, + mu1=( + ball1_params.u_s + if ball1_state.s == const.sliding + else ball1_params.u_r + ), + mu2=( + ball2_params.u_s + if ball2_state.s == const.sliding + else ball2_params.u_r + ), + m1=ball1_params.m, + m2=ball2_params.m, + g1=ball1_params.g, + g2=ball2_params.g, + R=ball1_params.R, + ) + cache[ball_pair] = shot.t + dtau_E + + if not cache: + return null_event(np.inf) + + ball_pair = min(cache, key=lambda k: cache[k]) + + return ball_ball_collision( + ball1=shot.balls[ball_pair[0]], + ball2=shot.balls[ball_pair[1]], + time=cache[ball_pair], + ) diff --git a/pooltool/evolution/event_based/detect/ball_cushion.py b/pooltool/evolution/event_based/detect/ball_cushion.py new file mode 100644 index 00000000..63954cc8 --- /dev/null +++ b/pooltool/evolution/event_based/detect/ball_cushion.py @@ -0,0 +1,138 @@ +from __future__ import annotations + +from typing import Protocol + +import attrs +import numpy as np + +import pooltool.constants as const +from pooltool.events import ( + Event, + EventType, + ball_circular_cushion_collision, + ball_linear_cushion_collision, + null_event, +) +from pooltool.evolution.event_based.cache import CollisionCache +from pooltool.physics.dimensionality import Dim +from pooltool.physics.motion.solve import ( + ball_circular_cushion_collision_time, + ball_linear_cushion_collision_time, +) +from pooltool.system.datatypes import System + + +class BallLCushionDetectionStrategy(Protocol): + """Ball-vs-linear-cushion-segment detection models must satisfy this protocol.""" + + dim: Dim + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... + + +@attrs.define +class BallLCushionDetection: + """Detects the next ball-vs-linear-cushion-segment collision in the system.""" + + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: + if not shot.table.has_linear_cushions: + return null_event(np.inf) + + cache = collision_cache.times.setdefault(EventType.BALL_LINEAR_CUSHION, {}) + + for ball in shot.balls.values(): + state = ball.state + params = ball.params + + for cushion in shot.table.cushion_segments.linear.values(): + obj_ids = (ball.id, cushion.id) + + if obj_ids in cache: + continue + + if ball.state.s in const.nontranslating: + cache[obj_ids] = np.inf + continue + + dtau_E = ball_linear_cushion_collision_time( + rvw=state.rvw, + s=state.s, + lx=cushion.lx, + ly=cushion.ly, + l0=cushion.l0, + p1=cushion.p1, + p2=cushion.p2, + direction=cushion.direction, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) + + cache[obj_ids] = shot.t + dtau_E + + obj_ids = min(cache, key=lambda k: cache[k]) + + return ball_linear_cushion_collision( + ball=shot.balls[obj_ids[0]], + cushion=shot.table.cushion_segments.linear[obj_ids[1]], + time=cache[obj_ids], + ) + + +class BallCCushionDetectionStrategy(Protocol): + """Ball-vs-circular-cushion-segment detection models must satisfy this protocol.""" + + dim: Dim + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... + + +@attrs.define +class BallCCushionDetection: + """Detects the next ball-vs-circular-cushion-segment collision in the system.""" + + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: + if not shot.table.has_circular_cushions: + return null_event(np.inf) + + cache = collision_cache.times.setdefault(EventType.BALL_CIRCULAR_CUSHION, {}) + + for ball in shot.balls.values(): + state = ball.state + params = ball.params + + for cushion in shot.table.cushion_segments.circular.values(): + obj_ids = (ball.id, cushion.id) + + if obj_ids in cache: + continue + + if ball.state.s in const.nontranslating: + cache[obj_ids] = np.inf + continue + + dtau_E = ball_circular_cushion_collision_time( + rvw=state.rvw, + s=state.s, + a=cushion.a, + b=cushion.b, + r=cushion.radius, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) + cache[obj_ids] = shot.t + dtau_E + + ball_id, cushion_id = min(cache, key=lambda k: cache[k]) + + return ball_circular_cushion_collision( + ball=shot.balls[ball_id], + cushion=shot.table.cushion_segments.circular[cushion_id], + time=cache[(ball_id, cushion_id)], + ) diff --git a/pooltool/evolution/event_based/detect/ball_pocket.py b/pooltool/evolution/event_based/detect/ball_pocket.py new file mode 100644 index 00000000..66ebcf4a --- /dev/null +++ b/pooltool/evolution/event_based/detect/ball_pocket.py @@ -0,0 +1,69 @@ +from __future__ import annotations + +from typing import Protocol + +import attrs +import numpy as np + +import pooltool.constants as const +from pooltool.events import Event, EventType, ball_pocket_collision, null_event +from pooltool.evolution.event_based.cache import CollisionCache +from pooltool.physics.dimensionality import Dim +from pooltool.physics.motion.solve import ball_pocket_collision_time +from pooltool.system.datatypes import System + + +class BallPocketDetectionStrategy(Protocol): + """Ball-pocket detection models must satisfy this protocol.""" + + dim: Dim + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... + + +@attrs.define +class BallPocketDetection: + """Detects the next ball-pocket collision in the system.""" + + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: + if not shot.table.has_pockets: + return null_event(np.inf) + + cache = collision_cache.times.setdefault(EventType.BALL_POCKET, {}) + + for ball in shot.balls.values(): + state = ball.state + params = ball.params + + for pocket in shot.table.pockets.values(): + obj_ids = (ball.id, pocket.id) + + if obj_ids in cache: + continue + + if ball.state.s in const.nontranslating: + cache[obj_ids] = np.inf + continue + + dtau_E = ball_pocket_collision_time( + rvw=state.rvw, + s=state.s, + a=pocket.a, + b=pocket.b, + r=pocket.radius, + mu=(params.u_s if state.s == const.sliding else params.u_r), + m=params.m, + g=params.g, + R=params.R, + ) + cache[obj_ids] = shot.t + dtau_E + + ball_id, pocket_id = min(cache, key=lambda k: cache[k]) + + return ball_pocket_collision( + ball=shot.balls[ball_id], + pocket=shot.table.pockets[pocket_id], + time=cache[(ball_id, pocket_id)], + ) diff --git a/pooltool/evolution/event_based/detect/detector.py b/pooltool/evolution/event_based/detect/detector.py new file mode 100644 index 00000000..082825d7 --- /dev/null +++ b/pooltool/evolution/event_based/detect/detector.py @@ -0,0 +1,161 @@ +from __future__ import annotations + +import attrs +import numpy as np + +import pooltool.ptmath as ptmath +from pooltool.events import Event, EventType, null_event +from pooltool.evolution.event_based.cache import CollisionCache, TransitionCache +from pooltool.evolution.event_based.detect.ball_ball import BallBallDetection +from pooltool.evolution.event_based.detect.ball_cushion import ( + BallCCushionDetection, + BallLCushionDetection, +) +from pooltool.evolution.event_based.detect.ball_pocket import BallPocketDetection +from pooltool.evolution.event_based.detect.stick_ball import StickBallDetection +from pooltool.physics.utils import get_ball_energy +from pooltool.system.datatypes import System + + +def _get_event_priority(event: Event, shot: System) -> tuple[int, float]: + """Compute priority for an event to resolve ties among simultaneous events. + + Returns a tuple (tier, energy) where: + - Lower tier = higher priority + - Higher energy = higher priority within the same tier + + Priority tiers: + - Tier 1: STICK_BALL (always first) + - Tier 2: Transitions and BALL_POCKET (can resolve without affecting others) + - Tier 3: BALL_BALL and ball-cushion collisions + + Args: + event: The event to compute priority for. + shot: The system state at the time the event was detected. + + Returns: + A tuple of (tier, energy) for sorting. + """ + event_type = event.event_type + + if event_type == EventType.NONE: + return (99, 0.0) + + if event_type == EventType.STICK_BALL: + return (1, shot.cue.V0**2) + + if event_type == EventType.BALL_POCKET: + ball_id = event.ids[0] + ball = shot.balls[ball_id] + energy = get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + return (2, energy) + + if event_type.is_transition(): + ball_id = event.ids[0] + ball = shot.balls[ball_id] + energy = get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + return (2, energy) + + if event_type == EventType.BALL_BALL: + ball1_id, ball2_id = event.ids + v1 = shot.balls[ball1_id].state.rvw[1] + v2 = shot.balls[ball2_id].state.rvw[1] + energy = ptmath.squared_norm3d(v1 - v2) + return (3, energy) + + if event_type in (EventType.BALL_LINEAR_CUSHION, EventType.BALL_CIRCULAR_CUSHION): + ball_id = event.ids[0] + ball = shot.balls[ball_id] + energy = get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) + return (3, energy) + + return (99, 0.0) + + +@attrs.define +class EventDetector: + """Bundles per-event-type detection strategies. + + Fields are typed as the concrete strategy class rather than the corresponding + ``*DetectionStrategy`` protocol. This keeps cattrs structuring trivial — cattrs + can structure into a concrete attrs class natively but cannot resolve a Protocol + without a discriminator. When a second implementation is added for a given event + type, this field type should be widened (e.g. to a union of concrete classes, or + to the protocol with a registry + structure hook that dispatches on a tag, in + the style of :class:`pooltool.physics.Resolver`). + + Attributes: + stick_ball: + Strategy for detecting the next stick-ball collision. + ball_ball: + Strategy for detecting the next ball-ball collision. + ball_linear_cushion: + Strategy for detecting the next ball-vs-linear-cushion-segment collision. + ball_circular_cushion: + Strategy for detecting the next ball-vs-circular-cushion-segment collision. + ball_pocket: + Strategy for detecting the next ball-pocket collision. + """ + + stick_ball: StickBallDetection = attrs.field(factory=StickBallDetection) + ball_ball: BallBallDetection = attrs.field(factory=BallBallDetection) + ball_linear_cushion: BallLCushionDetection = attrs.field( + factory=BallLCushionDetection + ) + ball_circular_cushion: BallCCushionDetection = attrs.field( + factory=BallCCushionDetection + ) + ball_pocket: BallPocketDetection = attrs.field(factory=BallPocketDetection) + + @classmethod + def default(cls) -> EventDetector: + return cls() + + def get_next_event( + self, + shot: System, + *, + transition_cache: TransitionCache | None = None, + collision_cache: CollisionCache | None = None, + ) -> Event: + """Return the soonest event across all event types. + + If multiple events occur at the same time, ties are broken by + :func:`_get_event_priority`. + """ + if transition_cache is None: + transition_cache = TransitionCache.create(shot) + if collision_cache is None: + collision_cache = CollisionCache.create() + + candidates: list[Event] = [] + + # Stick-ball collisions only occur at t=0 (shot initiation), so we skip this + # check after the first timestep as an optimization. Other collision types are + # always checked because they can occur at any time during simulation. Note: + # even at t=0, we still call the remaining detection strategies to fully + # populate the collision cache, which is needed by debug/introspection tools. + if shot.t == 0: + candidates.append(self.stick_ball.get_next(shot, collision_cache)) + + candidates.append(transition_cache.get_next()) + candidates.append(self.ball_ball.get_next(shot, collision_cache)) + candidates.append(self.ball_circular_cushion.get_next(shot, collision_cache)) + candidates.append(self.ball_linear_cushion.get_next(shot, collision_cache)) + candidates.append(self.ball_pocket.get_next(shot, collision_cache)) + + min_time = min(event.time for event in candidates) + + if min_time == np.inf: + return null_event(time=np.inf) + + simultaneous = [e for e in candidates if e.time == min_time] + + if len(simultaneous) == 1: + return simultaneous[0] + + def sort_key(e: Event) -> tuple[int, float]: + tier, energy = _get_event_priority(e, shot) + return (tier, -energy) + + return min(simultaneous, key=sort_key) diff --git a/pooltool/evolution/event_based/detect/stick_ball.py b/pooltool/evolution/event_based/detect/stick_ball.py new file mode 100644 index 00000000..b9c65603 --- /dev/null +++ b/pooltool/evolution/event_based/detect/stick_ball.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from typing import Protocol + +import attrs +import numpy as np + +from pooltool.events import Event, EventType, stick_ball_collision +from pooltool.evolution.event_based._utils import _system_has_energy +from pooltool.evolution.event_based.cache import CollisionCache +from pooltool.physics.dimensionality import Dim +from pooltool.system.datatypes import System + + +class StickBallDetectionStrategy(Protocol): + """Stick-ball detection models must satisfy this protocol.""" + + dim: Dim + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: ... + + +@attrs.define +class StickBallDetection: + """Stick-ball collision detection. + + Stick-ball events fire only at t=0, when the system is at rest and a cue strike is + queued (V0 > 0). + """ + + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) + + def get_next(self, shot: System, collision_cache: CollisionCache) -> Event: + cache = collision_cache.times.setdefault(EventType.STICK_BALL, {}) + + obj_ids = (shot.cue.id, shot.cue.cue_ball_id) + + if obj_ids in cache: + return stick_ball_collision( + stick=shot.cue, + ball=shot.balls[shot.cue.cue_ball_id], + time=cache[obj_ids], + ) + + if shot.t == 0 and not _system_has_energy(shot) and shot.cue.V0 > 0: + cache[obj_ids] = 0.0 + else: + cache[obj_ids] = np.inf + + return stick_ball_collision( + stick=shot.cue, + ball=shot.balls[shot.cue.cue_ball_id], + time=cache[obj_ids], + ) diff --git a/pooltool/evolution/event_based/introspection.py b/pooltool/evolution/event_based/introspection.py index 16d0d204..a641c72e 100644 --- a/pooltool/evolution/event_based/introspection.py +++ b/pooltool/evolution/event_based/introspection.py @@ -33,13 +33,13 @@ ball_pocket_collision, stick_ball_collision, ) +from pooltool.evolution.engine import SimulationEngine from pooltool.evolution.event_based.cache import CollisionCache, TransitionCache from pooltool.evolution.event_based.config import INCLUDED_EVENTS from pooltool.evolution.event_based.simulate import ( DEFAULT_ENGINE, _SimulationState, ) -from pooltool.physics.engine import PhysicsEngine from pooltool.serialize import conversion from pooltool.serialize.serializers import Pathish from pooltool.system.datatypes import System @@ -114,7 +114,7 @@ class SimulationSnapshot: next_event: Event collision_cache: CollisionCache transition_cache: TransitionCache - engine: PhysicsEngine + engine: SimulationEngine def get_prospective_events(self) -> list[Event]: """Get all prospective events. @@ -166,7 +166,7 @@ def post_resolve_system(self, event: Event) -> System: @attrs.define class SimulationSnapshotSequence: steps: list[SimulationSnapshot] = attrs.field(factory=list) - engine: PhysicsEngine = attrs.field(factory=PhysicsEngine) + engine: SimulationEngine = attrs.field(factory=SimulationEngine) def add(self, snapshot: SimulationSnapshot) -> None: self.steps.append(snapshot) @@ -190,7 +190,7 @@ def load(cls, path: Pathish) -> SimulationSnapshotSequence: def simulate_with_snapshots( shot: System, output_path: Path | None = None, - engine: PhysicsEngine | None = None, + engine: SimulationEngine | None = None, t_final: float | None = None, include: set[EventType] = INCLUDED_EVENTS, max_events: int = 0, diff --git a/pooltool/evolution/event_based/simulate.py b/pooltool/evolution/event_based/simulate.py index dda487ee..f82996e8 100755 --- a/pooltool/evolution/event_based/simulate.py +++ b/pooltool/evolution/event_based/simulate.py @@ -2,113 +2,25 @@ from __future__ import annotations -from itertools import combinations - import attrs import numpy as np -import pooltool.constants as const import pooltool.physics.evolve as evolve -import pooltool.ptmath as ptmath -from pooltool.events import ( - Event, - EventType, - ball_ball_collision, - ball_circular_cushion_collision, - ball_linear_cushion_collision, - ball_pocket_collision, - null_event, - stick_ball_collision, -) +from pooltool.events import Event, EventType, null_event from pooltool.evolution.continuous import continuize -from pooltool.evolution.event_based import solve +from pooltool.evolution.engine import SimulationEngine from pooltool.evolution.event_based.cache import CollisionCache, TransitionCache from pooltool.evolution.event_based.config import INCLUDED_EVENTS from pooltool.objects.ball.datatypes import BallState -from pooltool.physics.engine import PhysicsEngine from pooltool.system.datatypes import System -DEFAULT_ENGINE = PhysicsEngine() - - -def _system_has_energy(system: System) -> bool: - """Check whether the system has any energy. - - Notes: - - Returns False as soon as first energetic ball is iterated through. - - Cue energy (e.g. setting system.cue.V0 > 0 doesn't count as energy). - """ - return any( - bool( - ptmath.get_ball_energy( - ball.state.rvw, - ball.params.R, - ball.params.m, - ) - ) - for ball in system.balls.values() - ) - - -def get_event_priority(event: Event, shot: System) -> tuple[int, float]: - """Compute priority for an event to resolve ties among simultaneous events. - - Returns a tuple (tier, energy) where: - - Lower tier = higher priority - - Higher energy = higher priority within the same tier - - Priority tiers: - - Tier 1: STICK_BALL (always first) - - Tier 2: Transitions and BALL_POCKET (can resolve without affecting others) - - Tier 3: BALL_BALL and ball-cushion collisions - - Args: - event: The event to compute priority for. - shot: The system state at the time the event was detected. - - Returns: - A tuple of (tier, energy) for sorting. - """ - event_type = event.event_type - - if event_type == EventType.NONE: - return (99, 0.0) - - if event_type == EventType.STICK_BALL: - return (1, shot.cue.V0**2) - - if event_type == EventType.BALL_POCKET: - ball_id = event.ids[0] - ball = shot.balls[ball_id] - energy = ptmath.get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) - return (2, energy) - - if event_type.is_transition(): - ball_id = event.ids[0] - ball = shot.balls[ball_id] - energy = ptmath.get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) - return (2, energy) - - if event_type == EventType.BALL_BALL: - ball1_id, ball2_id = event.ids - v1 = shot.balls[ball1_id].state.rvw[1] - v2 = shot.balls[ball2_id].state.rvw[1] - energy = ptmath.squared_norm3d(v1 - v2) - return (3, energy) - - if event_type in (EventType.BALL_LINEAR_CUSHION, EventType.BALL_CIRCULAR_CUSHION): - ball_id = event.ids[0] - ball = shot.balls[ball_id] - energy = ptmath.get_ball_energy(ball.state.rvw, ball.params.R, ball.params.m) - return (3, energy) - - return (99, 0.0) +DEFAULT_ENGINE = SimulationEngine() @attrs.define class _SimulationState: shot: System - engine: PhysicsEngine + engine: SimulationEngine t_final: float | None = None include: set[EventType] = INCLUDED_EVENTS @@ -128,7 +40,7 @@ def init(self) -> None: self.shot._update_history(null_event(time=0)) def step(self) -> Event: - event = get_next_event( + event = self.engine.detector.get_next_event( self.shot, transition_cache=self.transition_cache, collision_cache=self.collision_cache, @@ -190,7 +102,7 @@ def evolve(shot: System, dt: float): def simulate( shot: System, - engine: PhysicsEngine | None = None, + engine: SimulationEngine | None = None, inplace: bool = False, continuous: bool = False, dt: float | None = None, @@ -206,7 +118,7 @@ def simulate( otherwise there will be nothing to simulate. engine: The engine holds all of the physics. You can instantiate your very own - :class:`pooltool.physics.PhysicsEngine` object, or you can modify + :class:`pooltool.evolution.SimulationEngine` object, or you can modify ``~/.config/pooltool/physics/resolver.json`` to change the default engine. inplace: By default, a copy of the passed system is simulated and returned. This @@ -292,312 +204,3 @@ def simulate( continuize(sim.shot, dt=0.01 if dt is None else dt, inplace=True) return sim.shot - - -def get_next_event( - shot: System, - *, - transition_cache: TransitionCache | None = None, - collision_cache: CollisionCache | None = None, -) -> Event: - # If not passed, unpopulated caches are initialized to pass to delegate functions. - # These empty caches will be populated by the delegate functions, but then thrown - # away when this function returns. - if transition_cache is None: - transition_cache = TransitionCache.create(shot) - if collision_cache is None: - collision_cache = CollisionCache.create() - - # Collect all candidate events from each detection function. - candidates: list[Event] = [] - - # Stick-ball collisions only occur at t=0 (shot initiation), so we skip this - # check after the first timestep as an optimization. Other collision types are - # always checked because they can occur at any time during simulation. Note: even - # at t=0, we still call the remaining detection functions to fully populate the - # collision cache, which is needed by debug/introspection tools. - if shot.t == 0: - candidates.append( - get_next_stick_ball_collision(shot, collision_cache=collision_cache) - ) - - candidates.append(transition_cache.get_next()) - candidates.append( - get_next_ball_ball_collision(shot, collision_cache=collision_cache) - ) - candidates.append( - get_next_ball_circular_cushion_event(shot, collision_cache=collision_cache) - ) - candidates.append( - get_next_ball_linear_cushion_collision(shot, collision_cache=collision_cache) - ) - candidates.append( - get_next_ball_pocket_collision(shot, collision_cache=collision_cache) - ) - - # Find the earliest time among all candidates. - min_time = min(event.time for event in candidates) - - if min_time == np.inf: - return null_event(time=np.inf) - - # Filter to only events occurring at the earliest time. - simultaneous = [e for e in candidates if e.time == min_time] - - if len(simultaneous) == 1: - return simultaneous[0] - - # When multiple events occur at the same time, select by priority tier, then by - # energy within the tier (higher energy first). - def sort_key(e: Event) -> tuple[int, float]: - tier, energy = get_event_priority(e, shot) - return (tier, -energy) - - return min(simultaneous, key=sort_key) - - -def get_next_stick_ball_collision( - shot: System, collision_cache: CollisionCache -) -> Event: - """Returns next stick-ball collision""" - - cache = collision_cache.times.setdefault(EventType.STICK_BALL, {}) - - obj_ids = (shot.cue.id, shot.cue.cue_ball_id) - - if obj_ids in cache: - return stick_ball_collision( - stick=shot.cue, - ball=shot.balls[shot.cue.cue_ball_id], - time=cache[obj_ids], - ) - - if shot.t == 0 and not _system_has_energy(shot) and shot.cue.V0 > 0: - cache[obj_ids] = 0.0 - else: - cache[obj_ids] = np.inf - - return stick_ball_collision( - stick=shot.cue, - ball=shot.balls[shot.cue.cue_ball_id], - time=cache[obj_ids], - ) - - -def get_next_ball_ball_collision( - shot: System, - collision_cache: CollisionCache, -) -> Event: - """Returns next ball-ball collision""" - - cache = collision_cache.times.setdefault(EventType.BALL_BALL, {}) - - for ball1, ball2 in combinations(shot.balls.values(), 2): - ball_pair = (ball1.id, ball2.id) - if ball_pair in cache: - continue - - ball1_state = ball1.state - ball1_params = ball1.params - - ball2_state = ball2.state - ball2_params = ball2.params - - if ball1_state.s == const.pocketed or ball2_state.s == const.pocketed: - cache[ball_pair] = np.inf - elif ( - ball1_state.s in const.nontranslating - and ball2_state.s in const.nontranslating - ): - cache[ball_pair] = np.inf - elif ptmath.is_overlapping( - ball1_state.rvw, - ball2_state.rvw, - ball1_params.R, - ball2_params.R, - ): - cache[ball_pair] = shot.t - else: - dtau_E = solve.ball_ball_collision_time( - rvw1=ball1_state.rvw, - rvw2=ball2_state.rvw, - s1=ball1_state.s, - s2=ball2_state.s, - mu1=( - ball1_params.u_s - if ball1_state.s == const.sliding - else ball1_params.u_r - ), - mu2=( - ball2_params.u_s - if ball2_state.s == const.sliding - else ball2_params.u_r - ), - m1=ball1_params.m, - m2=ball2_params.m, - g1=ball1_params.g, - g2=ball2_params.g, - R=ball1_params.R, - ) - cache[ball_pair] = shot.t + dtau_E - - # The cache is now populated and up-to-date - - if not cache: - return null_event(np.inf) - - ball_pair = min(cache, key=lambda k: cache[k]) - - return ball_ball_collision( - ball1=shot.balls[ball_pair[0]], - ball2=shot.balls[ball_pair[1]], - time=cache[ball_pair], - ) - - -def get_next_ball_circular_cushion_event( - shot: System, - collision_cache: CollisionCache, -) -> Event: - """Returns next ball-cushion collision (circular cushion segment)""" - - if not shot.table.has_circular_cushions: - return null_event(np.inf) - - cache = collision_cache.times.setdefault(EventType.BALL_CIRCULAR_CUSHION, {}) - - for ball in shot.balls.values(): - state = ball.state - params = ball.params - - for cushion in shot.table.cushion_segments.circular.values(): - obj_ids = (ball.id, cushion.id) - - if obj_ids in cache: - continue - - if ball.state.s in const.nontranslating: - cache[obj_ids] = np.inf - continue - - dtau_E = solve.ball_circular_cushion_collision_time( - rvw=state.rvw, - s=state.s, - a=cushion.a, - b=cushion.b, - r=cushion.radius, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) - cache[obj_ids] = shot.t + dtau_E - - # The cache is now populated and up-to-date - - ball_id, cushion_id = min(cache, key=lambda k: cache[k]) - - return ball_circular_cushion_collision( - ball=shot.balls[ball_id], - cushion=shot.table.cushion_segments.circular[cushion_id], - time=cache[(ball_id, cushion_id)], - ) - - -def get_next_ball_linear_cushion_collision( - shot: System, collision_cache: CollisionCache -) -> Event: - """Returns next ball-cushion collision (linear cushion segment)""" - - if not shot.table.has_linear_cushions: - return null_event(np.inf) - - cache = collision_cache.times.setdefault(EventType.BALL_LINEAR_CUSHION, {}) - - for ball in shot.balls.values(): - state = ball.state - params = ball.params - - for cushion in shot.table.cushion_segments.linear.values(): - obj_ids = (ball.id, cushion.id) - - if obj_ids in cache: - continue - - if ball.state.s in const.nontranslating: - cache[obj_ids] = np.inf - continue - - dtau_E = solve.ball_linear_cushion_collision_time( - rvw=state.rvw, - s=state.s, - lx=cushion.lx, - ly=cushion.ly, - l0=cushion.l0, - p1=cushion.p1, - p2=cushion.p2, - direction=cushion.direction, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) - - cache[obj_ids] = shot.t + dtau_E - - obj_ids = min(cache, key=lambda k: cache[k]) - - return ball_linear_cushion_collision( - ball=shot.balls[obj_ids[0]], - cushion=shot.table.cushion_segments.linear[obj_ids[1]], - time=cache[obj_ids], - ) - - -def get_next_ball_pocket_collision( - shot: System, - collision_cache: CollisionCache, -) -> Event: - """Returns next ball-pocket collision""" - - if not shot.table.has_pockets: - return null_event(np.inf) - - cache = collision_cache.times.setdefault(EventType.BALL_POCKET, {}) - - for ball in shot.balls.values(): - state = ball.state - params = ball.params - - for pocket in shot.table.pockets.values(): - obj_ids = (ball.id, pocket.id) - - if obj_ids in cache: - continue - - if ball.state.s in const.nontranslating: - cache[obj_ids] = np.inf - continue - - dtau_E = solve.ball_pocket_collision_time( - rvw=state.rvw, - s=state.s, - a=pocket.a, - b=pocket.b, - r=pocket.radius, - mu=(params.u_s if state.s == const.sliding else params.u_r), - m=params.m, - g=params.g, - R=params.R, - ) - cache[obj_ids] = shot.t + dtau_E - - # The cache is now populated and up-to-date - - ball_id, pocket_id = min(cache, key=lambda k: cache[k]) - - return ball_pocket_collision( - ball=shot.balls[ball_id], - pocket=shot.table.pockets[pocket_id], - time=cache[(ball_id, pocket_id)], - ) diff --git a/pooltool/objects/cue/render.py b/pooltool/objects/cue/render.py index ea9a1b90..ea6e0c28 100644 --- a/pooltool/objects/cue/render.py +++ b/pooltool/objects/cue/render.py @@ -10,7 +10,7 @@ from pooltool.objects.ball.render import BallRender from pooltool.objects.cue.datatypes import Cue from pooltool.objects.datatypes import Render -from pooltool.ptmath.utils import tip_center_offset, tip_contact_offset +from pooltool.physics.utils import tip_center_offset, tip_contact_offset class CueRender(Render): diff --git a/pooltool/objects/table/components.py b/pooltool/objects/table/components.py index c93a8a83..7bd41777 100644 --- a/pooltool/objects/table/components.py +++ b/pooltool/objects/table/components.py @@ -36,7 +36,7 @@ class CushionDirection: Note: This used to inherit from ``Enum``, but accessing the cushion direction in - ``get_next_ball_linear_cushion_collision`` took up 20% of the function's + ball-vs-linear-cushion-segment detection took up 20% of the function's runtime, so it was removed. """ diff --git a/pooltool/physics/__init__.py b/pooltool/physics/__init__.py index 91a5c73f..1a11e02e 100644 --- a/pooltool/physics/__init__.py +++ b/pooltool/physics/__init__.py @@ -1,6 +1,5 @@ """Physics subpackage for pooltool""" -from pooltool.physics.engine import PhysicsEngine from pooltool.physics.evolve import ( evolve_ball_motion, ) @@ -35,9 +34,17 @@ BallTransitionModel, ball_transition_models, ) +from pooltool.physics.utils import ( + get_ball_energy, + get_roll_time, + get_slide_time, + get_spin_time, + get_u_vec, + rel_velocity, + surface_velocity, +) __all__ = [ - "PhysicsEngine", # Resolve "display_models", "Resolver", @@ -48,6 +55,13 @@ "BallPocketModel", "StickBallModel", "BallTransitionModel", + "rel_velocity", + "surface_velocity", + "get_u_vec", + "get_slide_time", + "get_roll_time", + "get_spin_time", + "get_ball_energy", "ball_ball_models", "BallBallFrictionModel", "ball_ball_friction_models", diff --git a/pooltool/physics/dimensionality.py b/pooltool/physics/dimensionality.py new file mode 100644 index 00000000..d72f3783 --- /dev/null +++ b/pooltool/physics/dimensionality.py @@ -0,0 +1,26 @@ +from pooltool.utils.strenum import StrEnum, auto + + +class Dim(StrEnum): + """Dimensionality capability declaration for physics strategies. + + Each Resolver and EventDetector strategy declares one of these as a class-level + attribute. :class:`pooltool.evolution.SimulationEngine` reads these once at + construction to validate that the bundled strategies are compatible with its + ``is_3d`` setting. + + A strategy's ``dim`` is a promise about its behavior, *not* a mode switch. + ``BOTH`` means the strategy behaves identically in either mode; it does not mean + the strategy branches internally based on mode. If a strategy would behave + differently in 2D vs 3D, it should be split into separate ``TWO`` and ``THREE`` + classes. + + Members: + TWO: Safe only when ``SimulationEngine.is_3d`` is ``False``. + THREE: Safe only when ``SimulationEngine.is_3d`` is ``True``. + BOTH: Behavior identical in either mode; safe always. + """ + + TWO = auto() + THREE = auto() + BOTH = auto() diff --git a/pooltool/physics/engine.py b/pooltool/physics/engine.py deleted file mode 100644 index e3df9891..00000000 --- a/pooltool/physics/engine.py +++ /dev/null @@ -1,29 +0,0 @@ -"""The physics engine of pooltool""" - -from __future__ import annotations - -import attrs - -from pooltool.physics.resolve import Resolver - - -@attrs.define -class PhysicsEngine: - """A billiards engine for pluggable physics. - - Important: - Currently, only event resolution is a part of this class. The sliding, rolling, - and spinning ball trajectory evolution is currently "hard-coded", however can in - theory be added to this class to enable alternative trajectory models. - - Attributes: - resolver: - The physics engine responsible for resolving events. - """ - - resolver: Resolver = attrs.field(factory=Resolver.default) - - -__all__ = [ - "PhysicsEngine", -] diff --git a/pooltool/physics/evolve/__init__.py b/pooltool/physics/evolve/__init__.py index 250f66e2..32e0ed3c 100644 --- a/pooltool/physics/evolve/__init__.py +++ b/pooltool/physics/evolve/__init__.py @@ -6,7 +6,7 @@ https://ekiefl.github.io/2020/04/24/pooltool-theory/#3-ball-with-arbitrary-spin -The code should be configurable and passed to `PhysicsEngine` in `physics/engine.py`, +The code should be configurable and passed to `SimulationEngine` in `evolution/engine.py`, just like the `Resolver` class in `physics/resolve/resolver.py` """ @@ -16,6 +16,12 @@ import pooltool.constants as const import pooltool.ptmath as ptmath +from pooltool.physics.utils import ( + get_roll_time, + get_slide_time, + get_spin_time, + rel_velocity, +) @jit(nopython=True, cache=const.use_numba_cache) @@ -42,7 +48,7 @@ def evolve_ball_motion( return rvw, state if state == const.sliding: - dtau_E_slide = ptmath.get_slide_time(rvw, R, u_s, g) + dtau_E_slide = get_slide_time(rvw, R, u_s, g) if t >= dtau_E_slide: rvw = _evolve_slide_state(rvw, R, m, u_s, u_sp, g, dtau_E_slide) @@ -52,7 +58,7 @@ def evolve_ball_motion( return _evolve_slide_state(rvw, R, m, u_s, u_sp, g, t), const.sliding if state == const.rolling: - dtau_E_roll = ptmath.get_roll_time(rvw, u_r, g) + dtau_E_roll = get_roll_time(rvw, u_r, g) if t >= dtau_E_roll: rvw = _evolve_roll_state(rvw, R, u_r, u_sp, g, dtau_E_roll) @@ -62,7 +68,7 @@ def evolve_ball_motion( return _evolve_roll_state(rvw, R, u_r, u_sp, g, t), const.rolling if state == const.spinning: - dtau_E_spin = ptmath.get_spin_time(rvw, R, u_sp, g) + dtau_E_spin = get_spin_time(rvw, R, u_sp, g) if t >= dtau_E_spin: return ( @@ -94,9 +100,7 @@ def _evolve_slide_state( rvw_B0 = ptmath.coordinate_rotation(rvw.T, -phi).T # Relative velocity unit vector in ball frame - u_0 = ptmath.coordinate_rotation( - ptmath.unit_vector(ptmath.rel_velocity(rvw, R)), -phi - ) + u_0 = ptmath.coordinate_rotation(ptmath.unit_vector(rel_velocity(rvw, R)), -phi) # Calculate quantities according to the ball frame. NOTE w_B in this code block # is only accurate of the x and y evolution of angular velocity. z evolution of diff --git a/pooltool/physics/motion/__init__.py b/pooltool/physics/motion/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/pooltool/evolution/event_based/solve.py b/pooltool/physics/motion/solve.py similarity index 99% rename from pooltool/evolution/event_based/solve.py rename to pooltool/physics/motion/solve.py index c1a80ed1..af2e4b4e 100644 --- a/pooltool/evolution/event_based/solve.py +++ b/pooltool/physics/motion/solve.py @@ -7,6 +7,7 @@ import pooltool.constants as const import pooltool.physics.evolve as evolve import pooltool.ptmath as ptmath +from pooltool.physics.utils import rel_velocity from pooltool.ptmath.roots import quartic from pooltool.ptmath.roots.core import get_real_positive_smallest_root @@ -87,7 +88,7 @@ def get_u( if s == const.rolling: return np.array([1, 0, 0], dtype=np.float64) - rel_vel = ptmath.rel_velocity(rvw, R) + rel_vel = rel_velocity(rvw, R) if (rel_vel == 0).all(): return np.array([1, 0, 0], dtype=np.float64) diff --git a/pooltool/physics/resolve/ball_ball/core.py b/pooltool/physics/resolve/ball_ball/core.py index 0a90ebad..2c6be6da 100644 --- a/pooltool/physics/resolve/ball_ball/core.py +++ b/pooltool/physics/resolve/ball_ball/core.py @@ -6,6 +6,7 @@ import pooltool.constants as const import pooltool.ptmath as ptmath from pooltool.objects.ball.datatypes import Ball +from pooltool.physics.dimensionality import Dim class _BaseStrategy(Protocol): @@ -19,6 +20,8 @@ def resolve( class BallBallCollisionStrategy(_BaseStrategy, Protocol): """Ball-ball collision models must satisfy this protocol""" + dim: Dim + def solve(self, ball1: Ball, ball2: Ball) -> tuple[Ball, Ball]: """This method resolves a ball-ball collision""" ... diff --git a/pooltool/physics/resolve/ball_ball/friction.py b/pooltool/physics/resolve/ball_ball/friction.py index f7a50ee5..ffe4360e 100644 --- a/pooltool/physics/resolve/ball_ball/friction.py +++ b/pooltool/physics/resolve/ball_ball/friction.py @@ -5,6 +5,7 @@ import pooltool.ptmath as ptmath from pooltool.objects.ball.datatypes import Ball +from pooltool.physics.utils import tangent_surface_velocity from pooltool.utils.strenum import StrEnum, auto @@ -46,12 +47,8 @@ class AlciatoreBallBallFriction: def calculate_friction(self, ball1: Ball, ball2: Ball) -> float: unit_normal = ptmath.unit_vector(ball2.xyz - ball1.xyz) - v1_c = ptmath.tangent_surface_velocity( - ball1.state.rvw, unit_normal, ball1.params.R - ) - v2_c = ptmath.tangent_surface_velocity( - ball2.state.rvw, -unit_normal, ball2.params.R - ) + v1_c = tangent_surface_velocity(ball1.state.rvw, unit_normal, ball1.params.R) + v2_c = tangent_surface_velocity(ball2.state.rvw, -unit_normal, ball2.params.R) relative_surface_speed = ptmath.norm3d(v1_c - v2_c) return self.a + self.b * math.exp(-self.c * relative_surface_speed) diff --git a/pooltool/physics/resolve/ball_ball/frictional_inelastic/__init__.py b/pooltool/physics/resolve/ball_ball/frictional_inelastic/__init__.py index 4fef09bc..6e3c2035 100644 --- a/pooltool/physics/resolve/ball_ball/frictional_inelastic/__init__.py +++ b/pooltool/physics/resolve/ball_ball/frictional_inelastic/__init__.py @@ -5,12 +5,14 @@ import pooltool.constants as const import pooltool.ptmath as ptmath from pooltool.objects.ball.datatypes import Ball, BallState +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.ball_ball.core import CoreBallBallCollision from pooltool.physics.resolve.ball_ball.friction import ( AlciatoreBallBallFriction, BallBallFrictionStrategy, ) from pooltool.physics.resolve.models import BallBallModel +from pooltool.physics.utils import surface_velocity @jit(nopython=True, cache=const.use_numba_cache) @@ -43,8 +45,8 @@ def _resolve_ball_ball(rvw1, rvw2, R, u_b, e_b): rvw1_f = rvw1.copy() rvw2_f = rvw2.copy() - v1_c = ptmath.surface_velocity(rvw1, unit_x, R) - v2_c = ptmath.surface_velocity(rvw2, -unit_x, R) + v1_c = surface_velocity(rvw1, unit_x, R) + v2_c = surface_velocity(rvw2, -unit_x, R) v12_c = v1_c - v2_c has_relative_velocity = ptmath.norm3d(v12_c) > const.EPS @@ -61,8 +63,8 @@ def _resolve_ball_ball(rvw1, rvw2, R, u_b, e_b): rvw2_f[2] = rvw2[2] + D_w1 # calculate new relative contact velocity - v1_c_slip = ptmath.surface_velocity(rvw1_f, unit_x, R) - v2_c_slip = ptmath.surface_velocity(rvw2_f, -unit_x, R) + v1_c_slip = surface_velocity(rvw1_f, unit_x, R) + v2_c_slip = surface_velocity(rvw2_f, -unit_x, R) v12_c_slip = v1_c_slip - v2_c_slip # if there was no relative velocity to begin with, or if slip changed directions, @@ -115,6 +117,7 @@ class FrictionalInelastic(CoreBallBallCollision): model: BallBallModel = attrs.field( default=BallBallModel.FRICTIONAL_INELASTIC, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve(self, ball1: Ball, ball2: Ball) -> tuple[Ball, Ball]: """Resolves the collision.""" diff --git a/pooltool/physics/resolve/ball_ball/frictional_mathavan/__init__.py b/pooltool/physics/resolve/ball_ball/frictional_mathavan/__init__.py index 87e52413..cdc41452 100644 --- a/pooltool/physics/resolve/ball_ball/frictional_mathavan/__init__.py +++ b/pooltool/physics/resolve/ball_ball/frictional_mathavan/__init__.py @@ -6,6 +6,7 @@ import pooltool.constants as const from pooltool.objects.ball.datatypes import Ball, BallState +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.ball_ball.core import CoreBallBallCollision from pooltool.physics.resolve.ball_ball.friction import ( AlciatoreBallBallFriction, @@ -240,6 +241,7 @@ class FrictionalMathavan(CoreBallBallCollision): model: BallBallModel = attrs.field( default=BallBallModel.FRICTIONAL_MATHAVAN, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve(self, ball1: Ball, ball2: Ball) -> tuple[Ball, Ball]: """Resolve ball-ball collision via Mathavan et al. (2014). diff --git a/pooltool/physics/resolve/ball_ball/frictionless_elastic/__init__.py b/pooltool/physics/resolve/ball_ball/frictionless_elastic/__init__.py index ea873186..5ae5e8f8 100644 --- a/pooltool/physics/resolve/ball_ball/frictionless_elastic/__init__.py +++ b/pooltool/physics/resolve/ball_ball/frictionless_elastic/__init__.py @@ -4,6 +4,7 @@ import pooltool.constants as const import pooltool.ptmath as ptmath from pooltool.objects.ball.datatypes import Ball, BallState +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.ball_ball.core import CoreBallBallCollision from pooltool.physics.resolve.models import BallBallModel @@ -40,6 +41,7 @@ class FrictionlessElastic(CoreBallBallCollision): model: BallBallModel = attrs.field( default=BallBallModel.FRICTIONLESS_ELASTIC, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve(self, ball1: Ball, ball2: Ball) -> tuple[Ball, Ball]: """Resolves the collision.""" diff --git a/pooltool/physics/resolve/ball_cushion/core.py b/pooltool/physics/resolve/ball_cushion/core.py index 0baa6d6c..49ca1319 100644 --- a/pooltool/physics/resolve/ball_cushion/core.py +++ b/pooltool/physics/resolve/ball_cushion/core.py @@ -10,6 +10,7 @@ CircularCushionSegment, LinearCushionSegment, ) +from pooltool.physics.dimensionality import Dim class _BaseLinearStrategy(Protocol): @@ -31,6 +32,8 @@ def resolve( class BallLCushionCollisionStrategy(_BaseLinearStrategy, Protocol): """Ball-linear cushion collision models must satisfy this protocol""" + dim: Dim + def solve( self, ball: Ball, cushion: LinearCushionSegment ) -> tuple[Ball, LinearCushionSegment]: @@ -41,6 +44,8 @@ def solve( class BallCCushionCollisionStrategy(_BaseCircularStrategy, Protocol): """Ball-circular cushion collision models must satisfy this protocol""" + dim: Dim + def solve( self, ball: Ball, cushion: CircularCushionSegment ) -> tuple[Ball, CircularCushionSegment]: diff --git a/pooltool/physics/resolve/ball_cushion/han_2005/model.py b/pooltool/physics/resolve/ball_cushion/han_2005/model.py index 89334333..bce3f30a 100644 --- a/pooltool/physics/resolve/ball_cushion/han_2005/model.py +++ b/pooltool/physics/resolve/ball_cushion/han_2005/model.py @@ -9,6 +9,7 @@ Cushion, LinearCushionSegment, ) +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.ball_cushion.core import ( CoreBallCCushionCollision, CoreBallLCushionCollision, @@ -111,6 +112,7 @@ class Han2005Linear(CoreBallLCushionCollision): model: BallLCushionModel = attrs.field( default=BallLCushionModel.HAN_2005, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: LinearCushionSegment @@ -123,6 +125,7 @@ class Han2005Circular(CoreBallCCushionCollision): model: BallCCushionModel = attrs.field( default=BallCCushionModel.HAN_2005, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: CircularCushionSegment diff --git a/pooltool/physics/resolve/ball_cushion/impulse_frictional_inelastic/model.py b/pooltool/physics/resolve/ball_cushion/impulse_frictional_inelastic/model.py index ec45a523..b8def9c7 100644 --- a/pooltool/physics/resolve/ball_cushion/impulse_frictional_inelastic/model.py +++ b/pooltool/physics/resolve/ball_cushion/impulse_frictional_inelastic/model.py @@ -7,6 +7,7 @@ Cushion, LinearCushionSegment, ) +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.ball_cushion.core import ( CoreBallCCushionCollision, CoreBallLCushionCollision, @@ -39,6 +40,7 @@ class ImpulseFrictionalInelasticLinear(CoreBallLCushionCollision): model: BallLCushionModel = attrs.field( default=BallLCushionModel.IMPULSE_FRICTIONAL_INELASTIC, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: LinearCushionSegment @@ -51,6 +53,7 @@ class ImpulseFrictionalInelasticCircular(CoreBallCCushionCollision): model: BallCCushionModel = attrs.field( default=BallCCushionModel.IMPULSE_FRICTIONAL_INELASTIC, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: CircularCushionSegment diff --git a/pooltool/physics/resolve/ball_cushion/mathavan_2010/model.py b/pooltool/physics/resolve/ball_cushion/mathavan_2010/model.py index b967503a..b3147358 100644 --- a/pooltool/physics/resolve/ball_cushion/mathavan_2010/model.py +++ b/pooltool/physics/resolve/ball_cushion/mathavan_2010/model.py @@ -12,6 +12,7 @@ Cushion, LinearCushionSegment, ) +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.ball_cushion.core import ( CoreBallCCushionCollision, CoreBallLCushionCollision, @@ -700,6 +701,7 @@ class Mathavan2010Linear(CoreBallLCushionCollision): model: BallLCushionModel = attrs.field( default=BallLCushionModel.MATHAVAN_2010, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: LinearCushionSegment @@ -752,6 +754,7 @@ class Mathavan2010Circular(CoreBallCCushionCollision): model: BallCCushionModel = attrs.field( default=BallCCushionModel.MATHAVAN_2010, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: CircularCushionSegment diff --git a/pooltool/physics/resolve/ball_cushion/stronge_compliant/model.py b/pooltool/physics/resolve/ball_cushion/stronge_compliant/model.py index 9f46fe76..e6cd552a 100644 --- a/pooltool/physics/resolve/ball_cushion/stronge_compliant/model.py +++ b/pooltool/physics/resolve/ball_cushion/stronge_compliant/model.py @@ -11,6 +11,7 @@ Cushion, LinearCushionSegment, ) +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.ball_cushion.core import ( CoreBallCCushionCollision, CoreBallLCushionCollision, @@ -19,6 +20,7 @@ from pooltool.physics.resolve.stronge_compliant import ( resolve_collinear_compliant_frictional_inelastic_collision, ) +from pooltool.physics.utils import surface_velocity logger = logging.getLogger(__name__) @@ -29,9 +31,7 @@ def _solve(ball: Ball, cushion: Cushion, omega_ratio: float) -> tuple[Ball, Cush logger.debug(f"v={rvw[1]}, w={rvw[2]}") normal_direction = cushion.get_normal_3d(ball.xyz) - relative_contact_velocity = ptmath.surface_velocity( - rvw, -normal_direction, ball.params.R - ) + relative_contact_velocity = surface_velocity(rvw, -normal_direction, ball.params.R) v_n_0, v_t_0, tangent_direction = ptmath.decompose_normal_tangent( relative_contact_velocity, normal_direction, True @@ -78,7 +78,7 @@ def _solve(ball: Ball, cushion: Cushion, omega_ratio: float) -> tuple[Ball, Cush # and from `rvw` which was modified based on stronge output, # then verify that they're equal v_c_f = v_n_f * normal_direction + v_t_f * tangent_direction - relative_contact_velocity_f = ptmath.surface_velocity( + relative_contact_velocity_f = surface_velocity( rvw, -normal_direction, ball.params.R ) logger.debug(f"v_c_f={v_c_f}") @@ -127,6 +127,7 @@ class StrongeCompliantLinear(CoreBallLCushionCollision): model: BallLCushionModel = attrs.field( default=BallLCushionModel.STRONGE_COMPLIANT, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: LinearCushionSegment @@ -142,6 +143,7 @@ class StrongeCompliantCircular(CoreBallCCushionCollision): model: BallCCushionModel = attrs.field( default=BallCCushionModel.STRONGE_COMPLIANT, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: CircularCushionSegment diff --git a/pooltool/physics/resolve/ball_cushion/unrealistic/__init__.py b/pooltool/physics/resolve/ball_cushion/unrealistic/__init__.py index cdb9478a..5c457882 100644 --- a/pooltool/physics/resolve/ball_cushion/unrealistic/__init__.py +++ b/pooltool/physics/resolve/ball_cushion/unrealistic/__init__.py @@ -11,6 +11,7 @@ Cushion, LinearCushionSegment, ) +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.ball_cushion.core import ( CoreBallCCushionCollision, CoreBallLCushionCollision, @@ -72,6 +73,7 @@ class UnrealisticLinear(CoreBallLCushionCollision): model: BallLCushionModel = attrs.field( default=BallLCushionModel.UNREALISTIC, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: LinearCushionSegment @@ -85,6 +87,7 @@ class UnrealisticCircular(CoreBallCCushionCollision): model: BallCCushionModel = attrs.field( default=BallCCushionModel.UNREALISTIC, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve( self, ball: Ball, cushion: CircularCushionSegment diff --git a/pooltool/physics/resolve/ball_pocket/__init__.py b/pooltool/physics/resolve/ball_pocket/__init__.py index 94967198..1be56770 100644 --- a/pooltool/physics/resolve/ball_pocket/__init__.py +++ b/pooltool/physics/resolve/ball_pocket/__init__.py @@ -14,12 +14,15 @@ import pooltool.constants as const from pooltool.objects.ball.datatypes import Ball, BallState from pooltool.objects.table.components import Pocket +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.models import BallPocketModel class BallPocketStrategy(Protocol): """Ball-pocket collision models must satisfy this protocol""" + dim: Dim + def resolve( self, ball: Ball, pocket: Pocket, inplace: bool = False ) -> tuple[Ball, Pocket]: @@ -32,6 +35,7 @@ class CanonicalBallPocket: model: BallPocketModel = attrs.field( default=BallPocketModel.CANONICAL, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def resolve( self, ball: Ball, pocket: Pocket, inplace: bool = False diff --git a/pooltool/physics/resolve/sphere_half_space_collision.py b/pooltool/physics/resolve/sphere_half_space_collision.py index 4de28c93..3321680f 100644 --- a/pooltool/physics/resolve/sphere_half_space_collision.py +++ b/pooltool/physics/resolve/sphere_half_space_collision.py @@ -4,6 +4,7 @@ import pooltool.constants as const import pooltool.ptmath as ptmath +from pooltool.physics.utils import tangent_surface_velocity def resolve_sphere_half_space_collision(normal, rvw, R, mu_k, e): @@ -35,7 +36,7 @@ def resolve_sphere_half_space_collision_z_normal(rvw, R, mu_k, e): v_i[2] = 0.0 w_i[2] = 0.0 - v_c_i = ptmath.tangent_surface_velocity(rvw, -unit_z, R) + v_c_i = tangent_surface_velocity(rvw, -unit_z, R) v_c_i_magnitude = ptmath.norm3d(v_c_i) has_relative_velocity = v_c_i_magnitude > const.EPS diff --git a/pooltool/physics/resolve/stick_ball/core.py b/pooltool/physics/resolve/stick_ball/core.py index e79a8862..c2af34da 100644 --- a/pooltool/physics/resolve/stick_ball/core.py +++ b/pooltool/physics/resolve/stick_ball/core.py @@ -3,6 +3,7 @@ from pooltool.objects.ball.datatypes import Ball from pooltool.objects.cue.datatypes import Cue +from pooltool.physics.dimensionality import Dim class _BaseStrategy(Protocol): @@ -14,6 +15,8 @@ def resolve( class StickBallCollisionStrategy(_BaseStrategy, Protocol): """Stick-ball collision models must satisfy this protocol""" + dim: Dim + def solve(self, cue: Cue, ball: Ball) -> tuple[Cue, Ball]: """This method resolves a ball-circular cushion collision""" ... diff --git a/pooltool/physics/resolve/stick_ball/instantaneous_point/__init__.py b/pooltool/physics/resolve/stick_ball/instantaneous_point/__init__.py index d287c108..f940c7e7 100644 --- a/pooltool/physics/resolve/stick_ball/instantaneous_point/__init__.py +++ b/pooltool/physics/resolve/stick_ball/instantaneous_point/__init__.py @@ -5,6 +5,7 @@ import pooltool.ptmath as ptmath from pooltool.objects.ball.datatypes import Ball, BallState from pooltool.objects.cue.datatypes import Cue +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.models import StickBallModel from pooltool.physics.resolve.stick_ball.core import CoreStickBallCollision from pooltool.physics.resolve.stick_ball.squirt import get_squirt_angle @@ -129,6 +130,7 @@ class InstantaneousPoint(CoreStickBallCollision): model: StickBallModel = attrs.field( default=StickBallModel.INSTANTANEOUS_POINT, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def solve(self, cue: Cue, ball: Ball) -> tuple[Cue, Ball]: # Transform contact point Q from cue frame to ball frame diff --git a/pooltool/physics/resolve/transition/__init__.py b/pooltool/physics/resolve/transition/__init__.py index 5c5eee19..3bae61fa 100644 --- a/pooltool/physics/resolve/transition/__init__.py +++ b/pooltool/physics/resolve/transition/__init__.py @@ -13,6 +13,7 @@ import pooltool.constants as const from pooltool.events.datatypes import EventType from pooltool.objects.ball.datatypes import Ball +from pooltool.physics.dimensionality import Dim from pooltool.physics.resolve.models import BallTransitionModel _TOLERANCE = 1e-12 @@ -21,6 +22,8 @@ class BallTransitionStrategy(Protocol): """Ball transition models must satisfy this protocol""" + dim: Dim + def resolve(self, ball: Ball, transition: EventType, inplace: bool = False) -> Ball: """This method resolves a ball transition""" ... @@ -31,6 +34,7 @@ class CanonicalTransition: model: BallTransitionModel = attrs.field( default=BallTransitionModel.CANONICAL, init=False, repr=False ) + dim: Dim = attrs.field(default=Dim.TWO, init=False, repr=False) def resolve(self, ball: Ball, transition: EventType, inplace: bool = False) -> Ball: if not inplace: diff --git a/pooltool/physics/utils.py b/pooltool/physics/utils.py new file mode 100644 index 00000000..1c86dba8 --- /dev/null +++ b/pooltool/physics/utils.py @@ -0,0 +1,151 @@ +import numpy as np +from numba import jit +from numpy.typing import NDArray + +import pooltool.constants as const +from pooltool.ptmath.utils import coordinate_rotation, cross, norm3d, unit_vector + + +@jit(nopython=True, cache=const.use_numba_cache) +def surface_velocity( + rvw: NDArray[np.float64], d: NDArray[np.float64], R: float +) -> NDArray[np.float64]: + """Compute velocity of a point on ball's surface (specified by unit direction vector)""" + _, v, w = rvw + return v + cross(w, R * d) + + +@jit(nopython=True, cache=const.use_numba_cache) +def tangent_surface_velocity( + rvw: NDArray[np.float64], d: NDArray[np.float64], R: float +) -> NDArray[np.float64]: + """Compute velocity tangent to surface at a point on ball's surface (specified by unit direction vector)""" + _, v, w = rvw + v_t = v - np.sum(v * d) * d + return v_t + cross(w, R * d) + + +@jit(nopython=True, cache=const.use_numba_cache) +def rel_velocity(rvw: NDArray[np.float64], R: float) -> NDArray[np.float64]: + """Compute velocity of ball's point of contact with the cloth relative to the cloth + + This vector is non-zero whenever the ball is sliding + """ + return surface_velocity(rvw, np.array([0.0, 0.0, -1.0], dtype=np.float64), R) + + +@jit(nopython=True, cache=const.use_numba_cache) +def get_u_vec( + rvw: NDArray[np.float64], phi: float, R: float, s: int +) -> NDArray[np.float64]: + if s == const.rolling: + return np.array([1.0, 0.0, 0.0]) + + rel_vel = rel_velocity(rvw, R) + + if (rel_vel == 0.0).all(): + return np.array([1.0, 0.0, 0.0]) + + return coordinate_rotation(unit_vector(rel_vel), -phi) + + +@jit(nopython=True, cache=const.use_numba_cache) +def get_slide_time(rvw: NDArray[np.float64], R: float, u_s: float, g: float) -> float: + if u_s == 0.0: + return np.inf + + return 2 * norm3d(rel_velocity(rvw, R)) / (7 * u_s * g) + + +@jit(nopython=True, cache=const.use_numba_cache) +def get_roll_time(rvw: NDArray[np.float64], u_r: float, g: float) -> float: + if u_r == 0.0: + return np.inf + + _, v, _ = rvw + return norm3d(v) / (u_r * g) + + +@jit(nopython=True, cache=const.use_numba_cache) +def get_spin_time(rvw: NDArray[np.float64], R: float, u_sp: float, g: float) -> float: + if u_sp == 0.0: + return np.inf + + _, _, w = rvw + return np.abs(w[2]) * 2 / 5 * R / u_sp / g + + +def get_ball_energy(rvw: NDArray[np.float64], R: float, m: float) -> float: + """Get the energy of a ball + + Currently calculating linear and rotational kinetic energy. Need to add potential + energy if z-axis is freed + """ + # Linear + LKE = m * norm3d(rvw[1]) ** 2 / 2 + + # Rotational + RKE = (2 / 5 * m * R**2) * norm3d(rvw[2]) ** 2 / 2 + + return LKE + RKE + + +@jit(nopython=True, cache=const.use_numba_cache) +def tip_contact_offset( + cue_center_offset: NDArray[np.float64], tip_radius: float, ball_radius: float +) -> NDArray[np.float64]: + """Calculate the ball contact point offset from the cue tip center offset. + + This function converts the offset of the cue tip's center (relative to the ball's center, + and normalized by the ball's radius) into the offset of the contact point on the ball's surface. + + The conversion is based on the geometry of two circles in contact. Since the distance from the + ball's center to the cue tip's center is (ball_radius + tip_radius) while the ball's surface is + at a distance ball_radius, the contact point lies along the same line scaled by the factor + + 1 / (1 + tip_radius/ball_radius). + + In other words, if (a, b) represent the cue tip center offset, then the ball is struck at + + (a, b) / (1 + tip_radius/ball_radius). + + Args: + cue_center_offset: + A 2D vector (e.g., [a, b]) representing the offset of the cue tip center + relative to the ball center (normalized by the ball's radius). + tip_radius: The radius of the cue tip. + ball_radius: The radius of the ball. + + Returns: + NDArray[np.float64]: + A 2D vector representing the offset of the contact point on the ball's + surface, normalized by the ball's radius. + """ + return cue_center_offset / (1 + tip_radius / ball_radius) + + +@jit(nopython=True, cache=const.use_numba_cache) +def tip_center_offset( + tip_center_offset: NDArray[np.float64], tip_radius: float, ball_radius: float +) -> NDArray[np.float64]: + """Calculate the cue tip center offset from a given contact point offset on the ball. + + This function performs the inverse transformation of `tip_contact_offset`. Given a 2D contact point + offset on the ball’s surface (normalized by the ball's radius), it computes the corresponding cue tip + center offset. Since the cue tip’s center is located an extra tip_radius beyond the ball’s surface, + the transformation scales the contact offset by + + 1 + tip_radius/ball_radius. + + Args: + cue_center_offset: + A 2D vector (e.g., [a, b]) representing the offset of the cue tip center + relative to the ball center (normalized by the ball's radius). + tip_radius: The radius of the cue tip. + ball_radius: The radius of the ball. + + Returns: + NDArray[np.float64]: A 2D vector representing the offset of the cue tip's center relative to the + ball's center (normalized by the ball's radius). + """ + return tip_center_offset * (1 + tip_radius / ball_radius) diff --git a/pooltool/ptmath/__init__.py b/pooltool/ptmath/__init__.py index 73b03e07..5f7ed085 100644 --- a/pooltool/ptmath/__init__.py +++ b/pooltool/ptmath/__init__.py @@ -10,23 +10,15 @@ cross, decompose_normal_tangent, find_intersection_2D, - get_ball_energy, - get_roll_time, - get_slide_time, - get_spin_time, - get_u_vec, is_overlapping, norm2d, norm3d, point_on_line_closest_to_point, quaternion_from_vector_to_vector, - rel_velocity, rotation_from_vector_to_vector, solve_transcendental, squared_norm2d, squared_norm3d, - surface_velocity, - tangent_surface_velocity, unit_vector, unit_vector_slow, wiggle, @@ -51,15 +43,7 @@ "unit_vector", "unit_vector_slow", "wiggle", - "rel_velocity", "rotation_from_vector_to_vector", "quaternion_from_vector_to_vector", - "surface_velocity", - "tangent_surface_velocity", - "get_u_vec", - "get_slide_time", - "get_roll_time", - "get_spin_time", - "get_ball_energy", "is_overlapping", ] diff --git a/pooltool/ptmath/utils.py b/pooltool/ptmath/utils.py index 253288d5..b77b4282 100644 --- a/pooltool/ptmath/utils.py +++ b/pooltool/ptmath/utils.py @@ -324,90 +324,6 @@ def norm2d(vec: NDArray[np.float64]) -> float: return sqrt(squared_norm2d(vec)) -@jit(nopython=True, cache=const.use_numba_cache) -def surface_velocity( - rvw: NDArray[np.float64], d: NDArray[np.float64], R: float -) -> NDArray[np.float64]: - """Compute velocity of a point on ball's surface (specified by unit direction vector)""" - _, v, w = rvw - return v + cross(w, R * d) - - -@jit(nopython=True, cache=const.use_numba_cache) -def tangent_surface_velocity( - rvw: NDArray[np.float64], d: NDArray[np.float64], R: float -) -> NDArray[np.float64]: - """Compute velocity tangent to surface at a point on ball's surface (specified by unit direction vector)""" - _, v, w = rvw - v_t = v - np.sum(v * d) * d - return v_t + cross(w, R * d) - - -@jit(nopython=True, cache=const.use_numba_cache) -def rel_velocity(rvw: NDArray[np.float64], R: float) -> NDArray[np.float64]: - """Compute velocity of ball's point of contact with the cloth relative to the cloth - - This vector is non-zero whenever the ball is sliding - """ - return surface_velocity(rvw, np.array([0.0, 0.0, -1.0], dtype=np.float64), R) - - -@jit(nopython=True, cache=const.use_numba_cache) -def get_u_vec( - rvw: NDArray[np.float64], phi: float, R: float, s: int -) -> NDArray[np.float64]: - if s == const.rolling: - return np.array([1.0, 0.0, 0.0]) - - rel_vel = rel_velocity(rvw, R) - - if (rel_vel == 0.0).all(): - return np.array([1.0, 0.0, 0.0]) - - return coordinate_rotation(unit_vector(rel_vel), -phi) - - -@jit(nopython=True, cache=const.use_numba_cache) -def get_slide_time(rvw: NDArray[np.float64], R: float, u_s: float, g: float) -> float: - if u_s == 0.0: - return np.inf - - return 2 * norm3d(rel_velocity(rvw, R)) / (7 * u_s * g) - - -@jit(nopython=True, cache=const.use_numba_cache) -def get_roll_time(rvw: NDArray[np.float64], u_r: float, g: float) -> float: - if u_r == 0.0: - return np.inf - - _, v, _ = rvw - return norm3d(v) / (u_r * g) - - -@jit(nopython=True, cache=const.use_numba_cache) -def get_spin_time(rvw: NDArray[np.float64], R: float, u_sp: float, g: float) -> float: - if u_sp == 0.0: - return np.inf - - _, _, w = rvw - return np.abs(w[2]) * 2 / 5 * R / u_sp / g - - -def get_ball_energy(rvw: NDArray[np.float64], R: float, m: float) -> float: - """Get the energy of a ball - - Currently calculating linear and rotational kinetic energy. Need to add potential - energy if z-axis is freed - """ - # Linear - LKE = m * norm3d(rvw[1]) ** 2 / 2 - - # Rotational - RKE = (2 / 5 * m * R**2) * norm3d(rvw[2]) ** 2 / 2 - - return LKE + RKE - - def is_overlapping( rvw1: NDArray[np.float64], rvw2: NDArray[np.float64], @@ -416,64 +332,3 @@ def is_overlapping( min_spacer: float = 0.0, ) -> bool: return norm3d(rvw1[0] - rvw2[0]) < (R1 + R2 + min_spacer) - - -@jit(nopython=True, cache=const.use_numba_cache) -def tip_contact_offset( - cue_center_offset: NDArray[np.float64], tip_radius: float, ball_radius: float -) -> NDArray[np.float64]: - """Calculate the ball contact point offset from the cue tip center offset. - - This function converts the offset of the cue tip's center (relative to the ball's center, - and normalized by the ball's radius) into the offset of the contact point on the ball's surface. - - The conversion is based on the geometry of two circles in contact. Since the distance from the - ball's center to the cue tip's center is (ball_radius + tip_radius) while the ball's surface is - at a distance ball_radius, the contact point lies along the same line scaled by the factor - - 1 / (1 + tip_radius/ball_radius). - - In other words, if (a, b) represent the cue tip center offset, then the ball is struck at - - (a, b) / (1 + tip_radius/ball_radius). - - Args: - cue_center_offset: - A 2D vector (e.g., [a, b]) representing the offset of the cue tip center - relative to the ball center (normalized by the ball's radius). - tip_radius: The radius of the cue tip. - ball_radius: The radius of the ball. - - Returns: - NDArray[np.float64]: - A 2D vector representing the offset of the contact point on the ball's - surface, normalized by the ball's radius. - """ - return cue_center_offset / (1 + tip_radius / ball_radius) - - -@jit(nopython=True, cache=const.use_numba_cache) -def tip_center_offset( - tip_center_offset: NDArray[np.float64], tip_radius: float, ball_radius: float -) -> NDArray[np.float64]: - """Calculate the cue tip center offset from a given contact point offset on the ball. - - This function performs the inverse transformation of `tip_contact_offset`. Given a 2D contact point - offset on the ball’s surface (normalized by the ball's radius), it computes the corresponding cue tip - center offset. Since the cue tip’s center is located an extra tip_radius beyond the ball’s surface, - the transformation scales the contact offset by - - 1 + tip_radius/ball_radius. - - Args: - cue_center_offset: - A 2D vector (e.g., [a, b]) representing the offset of the cue tip center - relative to the ball center (normalized by the ball's radius). - tip_radius: The radius of the cue tip. - ball_radius: The radius of the ball. - - Returns: - NDArray[np.float64]: A 2D vector representing the offset of the cue tip's center relative to the - ball's center (normalized by the ball's radius). - """ - return tip_center_offset * (1 + tip_radius / ball_radius) diff --git a/tests/evolution/event_based/test_detector.py b/tests/evolution/event_based/test_detector.py new file mode 100644 index 00000000..15e9600e --- /dev/null +++ b/tests/evolution/event_based/test_detector.py @@ -0,0 +1,69 @@ +import pytest + +from pooltool.events import ( + Event, + ball_ball_collision, + ball_circular_cushion_collision, + ball_linear_cushion_collision, + ball_pocket_collision, + null_event, + sliding_rolling_transition, + stick_ball_collision, +) +from pooltool.evolution.event_based.detect.detector import _get_event_priority +from pooltool.objects import Ball, Cue, Table +from pooltool.system import System + + +@pytest.fixture +def system() -> System: + return System( + cue=Cue(cue_ball_id="cue"), + table=Table.default(), + balls=( + Ball.create("cue", xy=(0.5, 0.5)), + Ball.create("1", xy=(0.7, 0.5)), + Ball.create("2", xy=(0.9, 0.5)), + ), + ) + + +def _make_events(s: System) -> dict[str, Event]: + """One event of each type, all at the same simulation time.""" + return { + "stick_ball": stick_ball_collision(stick=s.cue, ball=s.balls["cue"], time=0), + "pocket": ball_pocket_collision( + ball=s.balls["cue"], + pocket=next(iter(s.table.pockets.values())), + time=0, + ), + "transition": sliding_rolling_transition(s.balls["cue"], time=0), + "ball_ball": ball_ball_collision(s.balls["cue"], s.balls["1"], time=0), + "linear_cushion": ball_linear_cushion_collision( + ball=s.balls["cue"], + cushion=next(iter(s.table.cushion_segments.linear.values())), + time=0, + ), + "circular_cushion": ball_circular_cushion_collision( + ball=s.balls["cue"], + cushion=next(iter(s.table.cushion_segments.circular.values())), + time=0, + ), + "none": null_event(time=0), + } + + +def test_event_priority_sorts_by_tier(system): + """All event types at the same time sort in priority order.""" + events = _make_events(system) + + def sort_key(event: Event) -> tuple[int, float]: + tier, energy = _get_event_priority(event, system) + return (tier, -energy) + + sorted_events = sorted(events.values(), key=sort_key) + tiers = [_get_event_priority(e, system)[0] for e in sorted_events] + + assert tiers == sorted(tiers), "tiers should be non-decreasing" + assert sorted_events[0] is events["stick_ball"] + assert sorted_events[-1] is events["none"] diff --git a/tests/evolution/event_based/test_simulate.py b/tests/evolution/event_based/test_simulate.py index 16aa3ebb..2467c131 100644 --- a/tests/evolution/event_based/test_simulate.py +++ b/tests/evolution/event_based/test_simulate.py @@ -3,24 +3,24 @@ from numpy.typing import NDArray import pooltool.constants as const +import pooltool.physics as physics import pooltool.ptmath as ptmath from pooltool import aim, events from pooltool.events import EventType, ball_ball_collision, ball_pocket_collision +from pooltool.evolution.event_based._utils import _system_has_energy from pooltool.evolution.event_based.cache import CollisionCache -from pooltool.evolution.event_based.simulate import ( - _system_has_energy, - get_next_ball_ball_collision, - get_next_event, - simulate, -) -from pooltool.evolution.event_based.solve import ball_ball_collision_time +from pooltool.evolution.event_based.detect import BallBallDetection, EventDetector +from pooltool.evolution.event_based.simulate import simulate from pooltool.objects import Ball, BilliardTableSpecs, Cue, Table from pooltool.objects.ball.params import BallParams from pooltool.objects.ball.sets import BallSet +from pooltool.physics.motion.solve import ball_ball_collision_time from pooltool.ptmath.roots import quadratic from pooltool.system import System from tests.evolution.event_based.test_data import TEST_DIR +_DETECTOR = EventDetector() + def test_simulate_inplace(): # First, we don't modify in place @@ -92,7 +92,7 @@ def test_case1(): ball.state = ball.history[0] shot.reset_history() - next_event = get_next_event(shot) + next_event = _DETECTOR.get_next_event(shot) expected = ball_ball_collision( shot.balls["1"], shot.balls["cue"], 0.048943195217641386 @@ -113,7 +113,7 @@ def test_case2(): """ shot = System.load(TEST_DIR / "case2.msgpack") - next_event = get_next_event(shot) + next_event = _DETECTOR.get_next_event(shot) expected = ball_pocket_collision( shot.balls["8"], shot.table.pockets["lc"], 0.08933033587481054 @@ -148,7 +148,7 @@ def test_case3(): ball1 = shot.balls["2"] ball2 = shot.balls["5"] - event = get_next_event(shot) + event = _DETECTOR.get_next_event(shot) expected = pytest.approx(5.810383731499328e-06, abs=1e-9) calculated = ball_ball_collision_time( @@ -218,7 +218,7 @@ def test_case4(): def _assert_rolling(rvw: NDArray[np.float64], R: float) -> None: - assert np.isclose(ptmath.rel_velocity(rvw, R), 0).all() + assert np.isclose(physics.rel_velocity(rvw, R), 0).all() def test_grazing_ball_ball_collision(): @@ -448,8 +448,8 @@ def test_ball_ball_collision_for_intersecting_balls(): # The cue is truly rolling _assert_rolling(system.balls["cue"].state.rvw, system.balls["cue"].params.R) - assert get_next_event(system).event_type == EventType.BALL_BALL - collision_event = get_next_ball_ball_collision(system, CollisionCache()) + assert _DETECTOR.get_next_event(system).event_type == EventType.BALL_BALL + collision_event = BallBallDetection().get_next(system, CollisionCache()) assert collision_event.time != np.inf assert collision_event.time == 0 @@ -514,14 +514,15 @@ def test_stick_ball_event_detection(): - No ball energy (all stationary) - Cue with V0 > 0 (ready to strike) - The stick-ball collision should be detected as the next event by get_next_event(). - This event should be: + The stick-ball collision should be detected as the next event by + EventDetector.get_next_event(). This event should be: - At time t=0 - Type STICK_BALL - Processed through the normal event resolution pipeline This validates the refactor that moved stick-ball detection from initialization - into get_next_event(), treating it as a first-class event rather than a special case. + into EventDetector.get_next_event(), treating it as a first-class event rather + than a special case. """ system = System.example() @@ -529,7 +530,7 @@ def test_stick_ball_event_detection(): assert not _system_has_energy(system) assert system.cue.V0 > 0 - event = get_next_event(system) + event = _DETECTOR.get_next_event(system) assert event.event_type == EventType.STICK_BALL assert event.time == 0 diff --git a/tests/evolution/test_engine.py b/tests/evolution/test_engine.py new file mode 100644 index 00000000..5351dd6f --- /dev/null +++ b/tests/evolution/test_engine.py @@ -0,0 +1,87 @@ +import attrs +import pytest + +from pooltool.evolution.engine import SimulationEngine +from pooltool.evolution.event_based.detect import EventDetector +from pooltool.physics.dimensionality import Dim +from pooltool.physics.resolve.resolver import Resolver + + +def _patch_all_dims( + resolver: Resolver, + detector: EventDetector, + dim_value: Dim, +) -> None: + """Set every strategy's dim to dim_value, in place.""" + for bundle in (resolver, detector): + for field in attrs.fields(type(bundle)): + strategy = getattr(bundle, field.name) + if hasattr(strategy, "dim"): + strategy.dim = dim_value + + +@pytest.fixture +def engine_3d() -> SimulationEngine: + """A SimulationEngine constructed with every strategy patched to Dim.THREE.""" + resolver = SimulationEngine().resolver + detector = SimulationEngine().detector + _patch_all_dims(resolver, detector, Dim.THREE) + return SimulationEngine(resolver=resolver, detector=detector, is_3d=True) + + +def test_default_engine_constructs(): + engine = SimulationEngine() + assert engine.is_3d is False + + +def test_3d_engine_constructs(engine_3d: SimulationEngine): + assert engine_3d.is_3d is True + + +def test_3d_engine_with_all_2d_strategies_raises(): + with pytest.raises(ValueError, match="incompatible with is_3d=True"): + SimulationEngine(is_3d=True) + + +def test_validation_error_identifies_offending_strategy(): + with pytest.raises(ValueError, match=r"Resolver\.|EventDetector\."): + SimulationEngine(is_3d=True) + + +def test_strategy_missing_dim_raises(): + @attrs.define + class DummyBallBall: + pass + + resolver = SimulationEngine().resolver + resolver.ball_ball = DummyBallBall() # type: ignore + + with pytest.raises(AttributeError, match="missing required 'dim'"): + SimulationEngine(resolver=resolver) + + +def test_3d_engine_rejects_one_dim_two_strategy(engine_3d: SimulationEngine): + """Reverting one strategy to Dim.TWO causes validation to raise, naming it.""" + engine_3d.resolver.ball_ball.dim = Dim.TWO + + with pytest.raises(ValueError, match=r"ball_ball.*incompatible with is_3d=True"): + SimulationEngine( + resolver=engine_3d.resolver, + detector=engine_3d.detector, + is_3d=True, + ) + + +def test_dim_both_strategy_accepted_in_2d(): + resolver = SimulationEngine().resolver + resolver.ball_ball.dim = Dim.BOTH + SimulationEngine(resolver=resolver, is_3d=False) + + +def test_dim_both_strategy_accepted_in_3d(engine_3d: SimulationEngine): + engine_3d.resolver.ball_ball.dim = Dim.BOTH + SimulationEngine( + resolver=engine_3d.resolver, + detector=engine_3d.detector, + is_3d=True, + ) diff --git a/tests/physics/resolve/ball_ball/test_ball_ball.py b/tests/physics/resolve/ball_ball/test_ball_ball.py index 6d8815df..4e30f6db 100644 --- a/tests/physics/resolve/ball_ball/test_ball_ball.py +++ b/tests/physics/resolve/ball_ball/test_ball_ball.py @@ -10,6 +10,7 @@ from pooltool.physics.resolve.ball_ball.frictional_inelastic import FrictionalInelastic from pooltool.physics.resolve.ball_ball.frictional_mathavan import FrictionalMathavan from pooltool.physics.resolve.ball_ball.frictionless_elastic import FrictionlessElastic +from pooltool.physics.utils import tangent_surface_velocity def vector_from_magnitude_and_direction(magnitude: float, angle_radians: float): @@ -189,7 +190,7 @@ def test_gearing_z_spin( ) # sanity check the initial conditions - v_c = ptmath.tangent_surface_velocity(cb_i.state.rvw, unit_normal, cb_i.params.R) + v_c = tangent_surface_velocity(cb_i.state.rvw, unit_normal, cb_i.params.R) assert ptmath.norm3d(v_c) < 1e-10, "Relative surface contact speed should be zero" cb_f, ob_f = model.resolve(cb_i, ob_i, inplace=False) @@ -239,19 +240,15 @@ def test_low_relative_surface_velocity( ) # from v = w * R -> w = v / R # sanity check the initial conditions - v_c = ptmath.tangent_surface_velocity(cb_i.state.rvw, unit_normal, cb_i.params.R) + v_c = tangent_surface_velocity(cb_i.state.rvw, unit_normal, cb_i.params.R) assert abs(relative_surface_speed - ptmath.norm3d(v_c)) < 1e-10, ( f"Relative surface contact speed should be {relative_surface_speed}" ) cb_f, ob_f = model.resolve(cb_i, ob_i, inplace=False) - cb_v_c_f = ptmath.tangent_surface_velocity( - cb_f.state.rvw, unit_normal, cb_f.params.R - ) - ob_v_c_f = ptmath.tangent_surface_velocity( - ob_f.state.rvw, -unit_normal, ob_f.params.R - ) + cb_v_c_f = tangent_surface_velocity(cb_f.state.rvw, unit_normal, cb_f.params.R) + ob_v_c_f = tangent_surface_velocity(ob_f.state.rvw, -unit_normal, ob_f.params.R) assert ptmath.norm3d(cb_v_c_f - ob_v_c_f) < 1e-3, ( "Final relative contact velocity should be zero" ) diff --git a/tests/physics/resolve/ball_ball/test_frictional_mathavan.py b/tests/physics/resolve/ball_ball/test_frictional_mathavan.py index b7325cae..38340a20 100644 --- a/tests/physics/resolve/ball_ball/test_frictional_mathavan.py +++ b/tests/physics/resolve/ball_ball/test_frictional_mathavan.py @@ -56,9 +56,9 @@ def test_collide_balls(initial_conditions, expected): def calc_rolling_velocity(v, w): rvw = np.zeros((3, 3), dtype=np.float64) rvw[1], rvw[2] = v, w - u = pt.ptmath.rel_velocity(rvw, R) + u = pt.physics.rel_velocity(rvw, R) a = -mu_s * g * u / np.linalg.norm(u) - return v + a * pt.ptmath.get_slide_time(rvw, R, mu_s, g) + return v + a * pt.physics.get_slide_time(rvw, R, mu_s, g) v_iS = calc_rolling_velocity(v_i1, w_i1) v_jS = calc_rolling_velocity(v_j1, w_j1) diff --git a/tests/physics/resolve/ball_cushion/test_ball_cushion.py b/tests/physics/resolve/ball_cushion/test_ball_cushion.py index e590e008..488381ae 100644 --- a/tests/physics/resolve/ball_cushion/test_ball_cushion.py +++ b/tests/physics/resolve/ball_cushion/test_ball_cushion.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from pooltool import ptmath +from pooltool import physics, ptmath from pooltool.constants import sliding, stationary from pooltool.objects import ( Ball, @@ -71,7 +71,7 @@ def test_energy( ball.state.rvw[1] = vel ball.state.s = sliding - initial_energy = ptmath.get_ball_energy( + initial_energy = physics.get_ball_energy( ball.state.rvw, ball.params.R, ball.params.m, @@ -81,7 +81,7 @@ def test_energy( model = ball_lcushion_models[model_name]() ball_after, _ = model.resolve(ball=ball, cushion=cushion_yaxis, inplace=False) - final_energy = ptmath.get_ball_energy( + final_energy = physics.get_ball_energy( ball_after.state.rvw, ball_after.params.R, ball_after.params.m, diff --git a/tests/ptmath/test_ptmath_surface_velocity.py b/tests/physics/test_utils.py similarity index 98% rename from tests/ptmath/test_ptmath_surface_velocity.py rename to tests/physics/test_utils.py index 09030b8e..232e9e85 100644 --- a/tests/ptmath/test_ptmath_surface_velocity.py +++ b/tests/physics/test_utils.py @@ -1,7 +1,7 @@ import numpy as np import pytest -from pooltool.ptmath.utils import surface_velocity, tangent_surface_velocity +from pooltool.physics.utils import surface_velocity, tangent_surface_velocity def test_surface_velocity_no_angular_velocity():