Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions pooltool/physics/motion/solve.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
from math import acos, isnan
from math import acos

import numpy as np
from numba import jit
Expand Down Expand Up @@ -254,7 +254,7 @@ def ball_linear_cushion_collision_time(

min_time = np.inf
for root in roots:
if isnan(root):
if np.isnan(root):
continue

if np.abs(root.imag) > const.EPS:
Expand All @@ -263,7 +263,7 @@ def ball_linear_cushion_collision_time(
if root.real <= const.EPS:
continue

rvw_dtau, _ = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root)
rvw_dtau, _ = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root.real)
s_score = -np.dot(p1 - rvw_dtau[0], p2 - p1) / np.dot(p2 - p1, p2 - p1)

if not (0 <= s_score <= 1):
Expand Down
2 changes: 1 addition & 1 deletion pooltool/physics/resolve/ball_ball/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,7 @@ def make_kiss(self, ball1: Ball, ball2: Ball) -> tuple[Ball, Ball]:
+ Cz * Cz
- (2 * ball1.params.R + spacer) * (2 * ball1.params.R + spacer)
)
roots_complex = ptmath.roots.quadratic.solve_complex(alpha, beta, gamma)
roots_complex = ptmath.roots.quadratic.solve(alpha, beta, gamma)

imag_mag = np.abs(roots_complex.imag)
real_mag = np.abs(roots_complex.real)
Expand Down
2 changes: 1 addition & 1 deletion pooltool/physics/resolve/ball_cushion/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,7 +174,7 @@ def make_kiss(self, ball: Ball, cushion: CircularCushionSegment) -> Ball:
beta = 2 * (diff[0] * v[0] + diff[1] * v[1])
gamma = diff[0] ** 2 + diff[1] ** 2 - target**2

roots_complex = ptmath.roots.quadratic.solve_complex(alpha, beta, gamma)
roots_complex = ptmath.roots.quadratic.solve(alpha, beta, gamma)

imag_mag = np.abs(roots_complex.imag)
real_mag = np.abs(roots_complex.real)
Expand Down
11 changes: 2 additions & 9 deletions pooltool/physics/utils.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
import math

import numpy as np
from numba import jit
from numpy.typing import NDArray
Expand Down Expand Up @@ -57,18 +55,13 @@ def get_airborne_time(rvw: NDArray[np.float64], R: float, g: float) -> float:
"""Time until an airborne ball's bottom touches the table plane (z = R).

Solves ``-0.5 * g * t**2 + v_z * t + (z - R) = 0`` and returns the later root
(the descending-leg intersection). Returns ``np.inf`` when gravity is zero, or
when the discriminant is negative.
(the descending-leg intersection). Returns ``np.inf`` when gravity is zero.
"""
if g == 0.0:
return np.inf

t1, t2 = quadratic.solve(-0.5 * g, rvw[1, 2], rvw[0, 2] - R)

if math.isnan(t1):
return np.inf

return max(t1, t2)
return max(t1.real, t2.real)
Comment on lines 63 to +64
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Preserve no-real-root handling for airborne touchdown time.

Line 64 currently returns a finite time from complex roots by taking .real, which can schedule a non-physical touchdown when there is no real intersection.

Proposed fix
-    t1, t2 = quadratic.solve(-0.5 * g, rvw[1, 2], rvw[0, 2] - R)
-    return max(t1.real, t2.real)
+    t1, t2 = quadratic.solve(-0.5 * g, rvw[1, 2], rvw[0, 2] - R)
+    t1_is_real = np.abs(t1.imag) <= const.EPS
+    t2_is_real = np.abs(t2.imag) <= const.EPS
+
+    if not t1_is_real and not t2_is_real:
+        return np.inf
+
+    r1 = t1.real if t1_is_real else -np.inf
+    r2 = t2.real if t2_is_real else -np.inf
+    return max(r1, r2)
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
t1, t2 = quadratic.solve(-0.5 * g, rvw[1, 2], rvw[0, 2] - R)
if math.isnan(t1):
return np.inf
return max(t1, t2)
return max(t1.real, t2.real)
t1, t2 = quadratic.solve(-0.5 * g, rvw[1, 2], rvw[0, 2] - R)
t1_is_real = np.abs(t1.imag) <= const.EPS
t2_is_real = np.abs(t2.imag) <= const.EPS
if not t1_is_real and not t2_is_real:
return np.inf
r1 = t1.real if t1_is_real else -np.inf
r2 = t2.real if t2_is_real else -np.inf
return max(r1, r2)
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@pooltool/physics/utils.py` around lines 63 - 64, The code takes the .real
part of roots from quadratic.solve and can return a non-physical touchdown time
when roots are complex; instead, after calling quadratic.solve (variables t1,
t2), check whether the roots are real (e.g., test abs(t1.imag) and abs(t2.imag)
against a tiny tolerance or use numpy.isreal); if both are real return
max(t1.real, t2.real) as before, otherwise preserve the prior “no real root”
behavior by returning None (or the module’s sentinel for “no touchdown”) rather
than coercing complex values to real.



@jit(nopython=True, cache=const.use_numba_cache)
Expand Down
20 changes: 1 addition & 19 deletions pooltool/ptmath/roots/quadratic.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
import cmath
import math

import numpy as np
from numba import jit
Expand All @@ -9,24 +8,7 @@


@jit(nopython=True, cache=const.use_numba_cache)
def solve(a: float, b: float, c: float) -> tuple[float, float]:
"""Solve a quadratic equation :math:`A t^2 + B t + C = 0` (just-in-time compiled)"""
if np.abs(a) < const.EPS:
if np.abs(b) < const.EPS:
return math.nan, math.nan
u = -c / b
return u, u
bp = b / 2
delta = bp * bp - a * c
u1 = (-bp - delta**0.5) / a
u2 = -u1 - b / a
return u1, u2


# TODO: In the branch `3d`, which will eventually be merged into main, this function has
# replaced `solve`.
@jit(nopython=True, cache=const.use_numba_cache)
def solve_complex(a: float, b: float, c: float) -> NDArray[np.complex128]:
def solve(a: float, b: float, c: float) -> NDArray[np.complex128]:
_a = complex(a)
_b = complex(b)
_c = complex(c)
Expand Down
4 changes: 2 additions & 2 deletions tests/evolution/event_based/test_simulate.py
Original file line number Diff line number Diff line change
Expand Up @@ -370,8 +370,8 @@ def true_time_to_collision(eps, V0, mu_r, g):
"""
collision_time = np.inf
for t in quadratic.solve(0.5 * mu_r * g, -V0, eps):
if t >= 0 and t < collision_time:
collision_time = t
if t.real >= 0 and t.real < collision_time:
collision_time = t.real
Comment on lines 372 to +374
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

⚠️ Potential issue | 🟠 Major | ⚡ Quick win

Validate that roots are real before using their real component.

The code extracts t.real without verifying that the imaginary part is negligible. For a collision-time calculation, only real-valued roots are physically meaningful. If the solver returns a root with a significant imaginary component (e.g., 2.0 + 0.1j), the current code would incorrectly use 2.0 as a valid collision time.

Since this helper computes the ground-truth collision time for test validation, it should be as strict as production code about root validity.

🛡️ Proposed fix to filter complex roots
     collision_time = np.inf
     for t in quadratic.solve(0.5 * mu_r * g, -V0, eps):
-        if t.real >= 0 and t.real < collision_time:
+        if not np.isnan(t) and np.isclose(t.imag, 0.0, atol=1e-9) and t.real >= 0 and t.real < collision_time:
             collision_time = t.real
     return collision_time

Alternatively, to match the pattern in other physics modules, check the imaginary part separately:

     collision_time = np.inf
     for t in quadratic.solve(0.5 * mu_r * g, -V0, eps):
-        if t.real >= 0 and t.real < collision_time:
+        if np.isnan(t) or not np.isclose(t.imag, 0.0, atol=1e-9):
+            continue
+        if t.real >= 0 and t.real < collision_time:
             collision_time = t.real
     return collision_time
📝 Committable suggestion

‼️ IMPORTANT
Carefully review the code before committing. Ensure that it accurately replaces the highlighted code, contains no missing lines, and has no issues with indentation. Thoroughly test & benchmark the code to ensure it meets the requirements.

Suggested change
for t in quadratic.solve(0.5 * mu_r * g, -V0, eps):
if t >= 0 and t < collision_time:
collision_time = t
if t.real >= 0 and t.real < collision_time:
collision_time = t.real
for t in quadratic.solve(0.5 * mu_r * g, -V0, eps):
if not np.isnan(t) and np.isclose(t.imag, 0.0, atol=1e-9) and t.real >= 0 and t.real < collision_time:
collision_time = t.real
🤖 Prompt for AI Agents
Verify each finding against current code. Fix only still-valid issues, skip the
rest with a brief reason, keep changes minimal, and validate.

In `@tests/evolution/event_based/test_simulate.py` around lines 372 - 374, The
loop that uses quadratic.solve to pick a collision_time currently reads t.real
without ensuring the root is effectively real; modify the loop that iterates
over quadratic.solve(...) and only consider roots whose imaginary part is
negligible (e.g., abs(t.imag) < a small tolerance like 1e-9) before using
t.real, and ignore complex roots with significant imaginary components so
collision_time is set only from physically meaningful real roots.

return collision_time

V0 = 2
Expand Down
103 changes: 103 additions & 0 deletions tests/ptmath/roots/test_quadratic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,103 @@
import numpy as np
import pytest

from pooltool.ptmath.roots.quadratic import solve


def test_solve_standard_quadratic():
# x^2 - 5x + 6 = 0
# Solutions: x = 2, x = 3
u1, u2 = solve(1.0, -5.0, 6.0)
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# First root -> 2.0 + 0.0j
assert solutions[0].real == 2.0
assert solutions[0].imag == 0.0
# Second root -> 3.0 + 0.0j
assert solutions[1].real == 3.0
assert solutions[1].imag == 0.0

# x^2 - x - 2 = 0
# Solutions: x = -1, x = 2
u1, u2 = solve(1.0, -1.0, -2.0)
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# First root -> -1.0 + 0.0j
assert solutions[0].real == -1.0
assert solutions[0].imag == 0.0
# Second root -> 2.0 + 0.0j
assert solutions[1].real == 2.0
assert solutions[1].imag == 0.0

# Perfect square: x^2 - 4x + 4 = 0
# Single repeated solution: x = 2
u1, u2 = solve(1.0, -4.0, 4.0)
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# Both roots -> 2.0 + 0.0j
for root in solutions:
assert root.real == 2.0
assert root.imag == 0.0

# Difference of squares: x^2 - 4 = 0
# Solutions: x = -2, x = 2
u1, u2 = solve(1.0, 0.0, -4.0)
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# First root -> -2.0 + 0.0j
assert solutions[0].real == -2.0
assert solutions[0].imag == 0.0
# Second root -> 2.0 + 0.0j
assert solutions[1].real == 2.0
assert solutions[1].imag == 0.0

# Complex roots: x^2 + 1 = 0
# Solutions: x = i, x = -i
u1, u2 = solve(1.0, 0.0, 1.0)
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))
# First root -> -i -> (0.0, -1.0)
assert solutions[0].real == 0.0
assert solutions[0].imag == -1.0
# Second root -> i -> (0.0, 1.0)
assert solutions[1].real == 0.0
assert solutions[1].imag == 1.0


def test_solve_large_values():
"""Test large coefficients for numerical stability."""

# Equation: x^2 - 1e7*x + 1 = 0
# This should give one very large and one very small solution.
a, b, c = 1.0, -1e7, 1.0
u1, u2 = solve(a, b, c)
solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag))

# The large root should be close to 1e7, the smaller should be close to 1e-7. We're
# able to use a very small relative tolerance due to the way the solver avoids
# catastrophic cancellation.
assert pytest.approx(solutions[0].real, rel=1e-12) == 1e-7
assert pytest.approx(solutions[1].real, rel=1e-12) == 1e7

assert solutions[0].imag == 0.0
assert solutions[1].imag == 0.0


def test_solve_linear_equation():
# a=0, b≠0 => linear equation b*t + c = 0 => t=-c/b
# e.g. 2t + 4 = 0 => t=-2
r1, r2 = solve(0.0, 2.0, 4.0)
assert r1.real == -2.0
assert r1.imag == 0.0
assert np.isnan(r2)


def test_solve_degenerate_no_solution():
# a=0, b=0, c≠0 => no solutions
# e.g. 0*t^2 + 0*t + 5 = 0 => no real solution
r1, r2 = solve(0.0, 0.0, 5.0)
assert np.isnan(r1)
assert np.isnan(r2)


def test_solve_degenerate_infinite_solutions():
# a=0, b=0, c=0 => infinite solutions
# e.g. 0*t^2 + 0*t + 0 = 0 => t can be anything
r1, r2 = solve(0.0, 0.0, 0.0)
assert np.isnan(r1)
assert np.isnan(r2)
Loading