diff --git a/pooltool/physics/motion/solve.py b/pooltool/physics/motion/solve.py index af2e4b4e..6b1af5c9 100644 --- a/pooltool/physics/motion/solve.py +++ b/pooltool/physics/motion/solve.py @@ -1,4 +1,4 @@ -from math import acos, isnan +from math import acos import numpy as np from numba import jit @@ -254,7 +254,7 @@ def ball_linear_cushion_collision_time( min_time = np.inf for root in roots: - if isnan(root): + if np.isnan(root): continue if np.abs(root.imag) > const.EPS: @@ -263,7 +263,7 @@ def ball_linear_cushion_collision_time( if root.real <= const.EPS: continue - rvw_dtau, _ = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root) + rvw_dtau, _ = evolve.evolve_ball_motion(s, rvw, R, m, mu, 1, mu, g, root.real) s_score = -np.dot(p1 - rvw_dtau[0], p2 - p1) / np.dot(p2 - p1, p2 - p1) if not (0 <= s_score <= 1): diff --git a/pooltool/physics/resolve/ball_ball/core.py b/pooltool/physics/resolve/ball_ball/core.py index 2c6be6da..49f8db67 100644 --- a/pooltool/physics/resolve/ball_ball/core.py +++ b/pooltool/physics/resolve/ball_ball/core.py @@ -107,7 +107,7 @@ def make_kiss(self, ball1: Ball, ball2: Ball) -> tuple[Ball, Ball]: + Cz * Cz - (2 * ball1.params.R + spacer) * (2 * ball1.params.R + spacer) ) - roots_complex = ptmath.roots.quadratic.solve_complex(alpha, beta, gamma) + roots_complex = ptmath.roots.quadratic.solve(alpha, beta, gamma) imag_mag = np.abs(roots_complex.imag) real_mag = np.abs(roots_complex.real) diff --git a/pooltool/physics/resolve/ball_cushion/core.py b/pooltool/physics/resolve/ball_cushion/core.py index 49ca1319..3d4a89bb 100644 --- a/pooltool/physics/resolve/ball_cushion/core.py +++ b/pooltool/physics/resolve/ball_cushion/core.py @@ -174,7 +174,7 @@ def make_kiss(self, ball: Ball, cushion: CircularCushionSegment) -> Ball: beta = 2 * (diff[0] * v[0] + diff[1] * v[1]) gamma = diff[0] ** 2 + diff[1] ** 2 - target**2 - roots_complex = ptmath.roots.quadratic.solve_complex(alpha, beta, gamma) + roots_complex = ptmath.roots.quadratic.solve(alpha, beta, gamma) imag_mag = np.abs(roots_complex.imag) real_mag = np.abs(roots_complex.real) diff --git a/pooltool/physics/utils.py b/pooltool/physics/utils.py index 125d3f7b..dcbef489 100644 --- a/pooltool/physics/utils.py +++ b/pooltool/physics/utils.py @@ -1,5 +1,3 @@ -import math - import numpy as np from numba import jit from numpy.typing import NDArray @@ -57,18 +55,13 @@ def get_airborne_time(rvw: NDArray[np.float64], R: float, g: float) -> float: """Time until an airborne ball's bottom touches the table plane (z = R). Solves ``-0.5 * g * t**2 + v_z * t + (z - R) = 0`` and returns the later root - (the descending-leg intersection). Returns ``np.inf`` when gravity is zero, or - when the discriminant is negative. + (the descending-leg intersection). Returns ``np.inf`` when gravity is zero. """ if g == 0.0: return np.inf t1, t2 = quadratic.solve(-0.5 * g, rvw[1, 2], rvw[0, 2] - R) - - if math.isnan(t1): - return np.inf - - return max(t1, t2) + return max(t1.real, t2.real) @jit(nopython=True, cache=const.use_numba_cache) diff --git a/pooltool/ptmath/roots/quadratic.py b/pooltool/ptmath/roots/quadratic.py index 0c770832..718418a4 100644 --- a/pooltool/ptmath/roots/quadratic.py +++ b/pooltool/ptmath/roots/quadratic.py @@ -1,5 +1,4 @@ import cmath -import math import numpy as np from numba import jit @@ -9,24 +8,7 @@ @jit(nopython=True, cache=const.use_numba_cache) -def solve(a: float, b: float, c: float) -> tuple[float, float]: - """Solve a quadratic equation :math:`A t^2 + B t + C = 0` (just-in-time compiled)""" - if np.abs(a) < const.EPS: - if np.abs(b) < const.EPS: - return math.nan, math.nan - u = -c / b - return u, u - bp = b / 2 - delta = bp * bp - a * c - u1 = (-bp - delta**0.5) / a - u2 = -u1 - b / a - return u1, u2 - - -# TODO: In the branch `3d`, which will eventually be merged into main, this function has -# replaced `solve`. -@jit(nopython=True, cache=const.use_numba_cache) -def solve_complex(a: float, b: float, c: float) -> NDArray[np.complex128]: +def solve(a: float, b: float, c: float) -> NDArray[np.complex128]: _a = complex(a) _b = complex(b) _c = complex(c) diff --git a/tests/evolution/event_based/test_simulate.py b/tests/evolution/event_based/test_simulate.py index 2467c131..ea627052 100644 --- a/tests/evolution/event_based/test_simulate.py +++ b/tests/evolution/event_based/test_simulate.py @@ -370,8 +370,8 @@ def true_time_to_collision(eps, V0, mu_r, g): """ collision_time = np.inf for t in quadratic.solve(0.5 * mu_r * g, -V0, eps): - if t >= 0 and t < collision_time: - collision_time = t + if t.real >= 0 and t.real < collision_time: + collision_time = t.real return collision_time V0 = 2 diff --git a/tests/ptmath/roots/test_quadratic.py b/tests/ptmath/roots/test_quadratic.py new file mode 100644 index 00000000..567d6a9d --- /dev/null +++ b/tests/ptmath/roots/test_quadratic.py @@ -0,0 +1,103 @@ +import numpy as np +import pytest + +from pooltool.ptmath.roots.quadratic import solve + + +def test_solve_standard_quadratic(): + # x^2 - 5x + 6 = 0 + # Solutions: x = 2, x = 3 + u1, u2 = solve(1.0, -5.0, 6.0) + solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag)) + # First root -> 2.0 + 0.0j + assert solutions[0].real == 2.0 + assert solutions[0].imag == 0.0 + # Second root -> 3.0 + 0.0j + assert solutions[1].real == 3.0 + assert solutions[1].imag == 0.0 + + # x^2 - x - 2 = 0 + # Solutions: x = -1, x = 2 + u1, u2 = solve(1.0, -1.0, -2.0) + solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag)) + # First root -> -1.0 + 0.0j + assert solutions[0].real == -1.0 + assert solutions[0].imag == 0.0 + # Second root -> 2.0 + 0.0j + assert solutions[1].real == 2.0 + assert solutions[1].imag == 0.0 + + # Perfect square: x^2 - 4x + 4 = 0 + # Single repeated solution: x = 2 + u1, u2 = solve(1.0, -4.0, 4.0) + solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag)) + # Both roots -> 2.0 + 0.0j + for root in solutions: + assert root.real == 2.0 + assert root.imag == 0.0 + + # Difference of squares: x^2 - 4 = 0 + # Solutions: x = -2, x = 2 + u1, u2 = solve(1.0, 0.0, -4.0) + solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag)) + # First root -> -2.0 + 0.0j + assert solutions[0].real == -2.0 + assert solutions[0].imag == 0.0 + # Second root -> 2.0 + 0.0j + assert solutions[1].real == 2.0 + assert solutions[1].imag == 0.0 + + # Complex roots: x^2 + 1 = 0 + # Solutions: x = i, x = -i + u1, u2 = solve(1.0, 0.0, 1.0) + solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag)) + # First root -> -i -> (0.0, -1.0) + assert solutions[0].real == 0.0 + assert solutions[0].imag == -1.0 + # Second root -> i -> (0.0, 1.0) + assert solutions[1].real == 0.0 + assert solutions[1].imag == 1.0 + + +def test_solve_large_values(): + """Test large coefficients for numerical stability.""" + + # Equation: x^2 - 1e7*x + 1 = 0 + # This should give one very large and one very small solution. + a, b, c = 1.0, -1e7, 1.0 + u1, u2 = solve(a, b, c) + solutions = sorted([u1, u2], key=lambda z: (z.real, z.imag)) + + # The large root should be close to 1e7, the smaller should be close to 1e-7. We're + # able to use a very small relative tolerance due to the way the solver avoids + # catastrophic cancellation. + assert pytest.approx(solutions[0].real, rel=1e-12) == 1e-7 + assert pytest.approx(solutions[1].real, rel=1e-12) == 1e7 + + assert solutions[0].imag == 0.0 + assert solutions[1].imag == 0.0 + + +def test_solve_linear_equation(): + # a=0, b≠0 => linear equation b*t + c = 0 => t=-c/b + # e.g. 2t + 4 = 0 => t=-2 + r1, r2 = solve(0.0, 2.0, 4.0) + assert r1.real == -2.0 + assert r1.imag == 0.0 + assert np.isnan(r2) + + +def test_solve_degenerate_no_solution(): + # a=0, b=0, c≠0 => no solutions + # e.g. 0*t^2 + 0*t + 5 = 0 => no real solution + r1, r2 = solve(0.0, 0.0, 5.0) + assert np.isnan(r1) + assert np.isnan(r2) + + +def test_solve_degenerate_infinite_solutions(): + # a=0, b=0, c=0 => infinite solutions + # e.g. 0*t^2 + 0*t + 0 = 0 => t can be anything + r1, r2 = solve(0.0, 0.0, 0.0) + assert np.isnan(r1) + assert np.isnan(r2)