Skip to content

Commit d351dda

Browse files
authored
v0.0.8 (#14)
1 parent 74aab48 commit d351dda

9 files changed

Lines changed: 84 additions & 12 deletions

File tree

.devcontainer/devcontainer.json

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,6 @@
4040
"ms-vscode.vscode-github-issue-notebooks",
4141
"ms-vscode.vscode-markdown-notebook",
4242
"ms-python.gather",
43-
"github.copilot",
4443
"github.vscode-pull-request-github",
4544
"visualstudioexptteam.vscodeintellicode",
4645
"ms-vscode.github-issues-prs",

.github/workflows/test.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ jobs:
1414
fail-fast: false
1515
matrix:
1616
python-version: ["3.7", "3.8", "3.9", "3.10"]
17-
backend: ["jax", "numpy", "numba"]
17+
backend: ["numpy", "numba"]
1818
os: ["ubuntu-latest"]
1919

2020
steps:

hypersolver/__init__.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,21 +37,27 @@
3737

3838
import os
3939

40-
from hypersolver.util import jxt as jit
40+
from hypersolver.util import xnp
41+
from hypersolver.util import jxt
42+
4143
from hypersolver.lax_friedrichs import lx_loop
44+
4245
from hypersolver.lax_wendroff import lw_loop
46+
4347
from hypersolver.runge_kutta import rk_loop
4448

4549

46-
__version__ = "0.0.7"
50+
__version__ = "0.0.8"
4751

4852
__hyper_methods__ = [
4953
"lax_friedrichs",
5054
"lax_wendroff",
51-
"method_of_characteristics",
5255
"runge_kutta_2",
5356
]
5457

58+
np = xnp
59+
jit = jxt
60+
5561

5662
def set_solver(
5763
method=os.environ.get("HS_METHOD", "lax_friedrichs"),

hypersolver/derivative.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,9 @@
22
33
available functions:
44
- ord1_acc2: order=1, accuracy=2
5+
- ord1_acc4: order=1, accuracy=4
56
- ord2_acc2: order=2, accuracy=2
7+
- ord2_acc4: order=2, accuracy=4
68
"""
79

810
from hypersolver.util import xnp as np
@@ -79,3 +81,39 @@ def ord2_acc2(_func, _xvar):
7981
)
8082

8183
return np.concatenate((_derivative0, _derivative, _derivative1))
84+
85+
86+
@jit(nopython=True)
87+
def ord2_acc4(_func, _xvar):
88+
""" central differencing: order=2, accuracy=4 """
89+
90+
_func, _xvar = np.asarray(_func), np.asarray(_xvar)
91+
92+
_result10 = (
93+
_func[:-2] - 2.0*_func[1:-1] + _func[2:]
94+
)/((_xvar[2:] - _xvar[:-2])/2.0)**2
95+
96+
_result20 = (
97+
(4/3)*_func[:-2] - (5/2)*_func[1:-1] + (4/3)*_func[2:]
98+
)/((_xvar[2:] - _xvar[:-2])/2.0)**2
99+
100+
_result2 = np.concatenate((
101+
np.array([_result10[0]]),
102+
(
103+
-(1/12)*_func[:-4] - (1/12)*_func[4:]
104+
)/((_xvar[4:] - _xvar[:-4])/4.0)**2 + _result20[1:-1],
105+
np.array([_result10[-1]]),
106+
))
107+
108+
_derivative = _result2
109+
110+
_derivative0 = np.asarray(
111+
[(_func[2] - 2.0*_func[1] + _func[0])/(_xvar[1] - _xvar[0])**2]
112+
)
113+
_derivative1 = np.asarray(
114+
[(_func[-1] - 2.0*_func[-2] + _func[-3])/(_xvar[-1] - _xvar[-2])**2]
115+
)
116+
117+
return np.concatenate(
118+
(_derivative0, _derivative, _derivative1)
119+
)

hypersolver/lax_wendroff.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
from hypersolver.util import jxt as jit
3636
from hypersolver.util import xnp as np
3737
from hypersolver.util import term_util, time_step_util
38+
3839
from hypersolver.derivative import ord1_acc2, ord2_acc2
3940

4041

hypersolver/runge_kutta.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
init_: n
88
vars_: x
99
func_: f
10+
sink_: 0
1011
time_: Δt
1112
1213
outputs:
@@ -15,7 +16,7 @@
1516
1617
numerics:
1718
---------
18-
- next_vals: n + Δt*f(n + Δt/2*f(n, x))
19+
- rk2 next_: n + Δt*f(n + Δt/2*f(n, x))
1920
2021
"""
2122

hypersolver/tests/derivative_test.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55

66
from hypersolver.util import xnp as np
77
from hypersolver.derivative import ord1_acc2, ord2_acc2
8-
from hypersolver.derivative import ord1_acc4
8+
from hypersolver.derivative import ord1_acc4, ord2_acc4
99

1010
xvar = np.linspace(0, 10, 1000)
1111
yvar = np.sin(xvar)
@@ -75,3 +75,31 @@ def test_ord2_acc2():
7575
assert (ord2_acc2(
7676
yvar, xvar) - d2ydx2).sum() == pytest.approx(
7777
(d2ydx2-d2ydx2).sum(), abs=1e-1)
78+
79+
80+
def test_ord2_acc4():
81+
""" test: central differncing: order=2, accuracy=2 """
82+
83+
assert ord2_acc4(
84+
yvar, xvar).shape == d2ydx2.shape
85+
86+
assert ord2_acc4(
87+
yvar, xvar) == pytest.approx(d2ydx2, abs=1e-1)
88+
89+
if os.environ.get("HS_BACKEND", "numpy") == "numpy":
90+
assert (ord2_acc4(
91+
yvar, xvar) - d2ydx2).sum() == pytest.approx(0.0, abs=1e-1)
92+
else:
93+
assert (ord2_acc4(
94+
yvar, xvar) - d2ydx2).sum() == pytest.approx(
95+
(d2ydx2-d2ydx2).sum(), abs=1e-1)
96+
97+
98+
def test_comp_ord2():
99+
""" test: compare ord1 acc(2,4) """
100+
101+
assert (
102+
(ord2_acc4(yvar, xvar)-d2ydx2)**2
103+
).sum() <= (
104+
(ord2_acc2(yvar, xvar)-d2ydx2)**2
105+
).sum()

hypersolver/util.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,8 @@ def set_xnp(backend=os.environ.get("HS_BACKEND", "numpy")):
1414

1515
if backend == "jax":
1616
warnings.warn(
17-
"experimental jax support is suboptimal with no performance gain")
18-
import jax.numpy as jnp # pylint: disable=import-outside-toplevel
19-
return jnp
17+
f"no more {backend} support, reverting to numpy")
18+
return np
2019

2120
return np
2221

@@ -66,7 +65,7 @@ def term_util(term, orig):
6665
xnp.asarray(term, dtype=orig.dtype), orig)[0]
6766

6867

69-
# @jxt(nopython=True, parallel=True)
68+
# @jxt(nopython=True)
7069
def func_util(func, _vals, _vars, **kwargs):
7170
""" evaluate function if one """
7271
return func(_vals, _vars, **kwargs) if callable(func) else func

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@ repository = "https://github.com/ngam/hypersolver"
3131
dev = [
3232
"pylint", "pytest", "autopep8", "jupyterlab", "pytype",
3333
"typing", "build", "flake8", "jupyter-book", "ghp-import",
34-
"matplotlib", "jax", "jaxlib", "numba", "pandas", "sympy",
34+
"matplotlib", "numba",
3535
]
3636

3737
[build-system]

0 commit comments

Comments
 (0)