Skip to content

Commit b73f2c2

Browse files
authored
v0.0.5 (#8)
1 parent 77a9c02 commit b73f2c2

15 files changed

Lines changed: 419 additions & 360 deletions

hypersolver/__init__.py

Lines changed: 53 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -12,36 +12,77 @@
1212
f is speed n moves along x, and
1313
g lumps sources and sinks
1414
15-
functionally, n(x; t), f(x), and g(x; t; n)
15+
functionally, n(x; t), f(x), and g(n; x)
1616
1717
note, fn is the flux across x.
1818
1919
Usage:
20-
>>> from hypersolver import select_solver
21-
>>> solver = select_solver(method)
20+
>>> from hypersolver import solver
2221
>>> solver(n0, x, t, f, g, **kwargs)
22+
>>> # kwargs include "method", "backend", etc.
2323
24-
available methods:
24+
available `method`s:
25+
pde:
2526
- "lax_friedrichs" (default)
26-
- "lax_wendroff" (still unstable, wip)
27-
- "method_of_characteristics" (experimental)
27+
- "lax_wendroff"
28+
- "method_of_characteristics" (broken, experimental)
29+
ode:
30+
- "rk2"
2831
32+
available `backend`s:
33+
- "numpy" (default)
34+
- "jax" (experimental)
35+
36+
available `solver_type`s:
37+
- "unsplit" (default)
38+
- "split"
2939
"""
3040

41+
import os
42+
43+
from hypersolver.pde_solver_unsplit import solver_ as solver_upde
44+
# from hypersolver.pde_solver_split import solver_ as solver_spde
45+
from hypersolver.ode_solver import solver_ as solver_ode
3146

32-
__version__ = "0.0.4"
3347

34-
__hyper_solvers__ = [
48+
__version__ = "0.0.5"
49+
50+
__hyper_methods__ = [
3551
"lax_friedrichs",
3652
"lax_wendroff",
3753
"method_of_characteristics",
54+
"rk2",
3855
]
3956

40-
from hypersolver.step_solver import solver_
57+
__hyper_solver_types__ = [
58+
"unsplit",
59+
"split",
60+
]
4161

4262

43-
def solver(*args, method="lax_friedrichs", **kwargs):
63+
def solver(
64+
*args,
65+
method="lax_friedrichs",
66+
backend=os.environ.get("HS_BACKEND", "numpy"),
67+
verbosity=os.environ.get("HS_VERBOSITY", "0"),
68+
solver_type="unsplit",
69+
**kwargs
70+
):
4471
""" wrapper function to select solvers """
45-
if method not in __hyper_solvers__:
72+
73+
os.environ["HS_BACKEND"] = str(backend)
74+
os.environ["HS_VERBOSITY"] = str(verbosity)
75+
76+
if method not in __hyper_methods__ or \
77+
solver_type not in __hyper_solver_types__:
4678
raise ValueError("method not supported")
47-
return solver_(*args, method=method, **kwargs)
79+
80+
if method.startswith("rk"):
81+
return solver_ode(*args, method=method, **kwargs)
82+
# if method.endswith("_split"):
83+
# return solver_spde(*args, method=method, **kwargs)
84+
return solver_upde(
85+
*args,
86+
method=method,
87+
**kwargs
88+
)

hypersolver/accurate_derivative.py

Lines changed: 0 additions & 88 deletions
This file was deleted.

hypersolver/derivative.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
""" derivative: calculating derivatives central differencing
2+
3+
available functions:
4+
- ord1_acc2: order=1, accuracy=2
5+
- ord2_acc2: order=2, accuracy=2
6+
"""
7+
8+
from hypersolver.util import xnp as np
9+
10+
11+
def ord1_acc2(_func, _xvar):
12+
""" central differencing: order=1, accuracy=2 """
13+
14+
_func, _xvar = np.array(_func), np.array(_xvar)
15+
16+
_derivative = (
17+
_func[2:] - _func[:-2]
18+
)/(_xvar[2:] - _xvar[:-2])
19+
20+
return np.pad(
21+
_derivative,
22+
(1, 1),
23+
mode='constant',
24+
constant_values=(
25+
(_func[1] - _func[0])/(_xvar[1] - _xvar[0]),
26+
(_func[-1] - _func[-2])/(_xvar[-1] - _xvar[-2])
27+
),)
28+
29+
30+
def ord2_acc2(_func, _xvar):
31+
""" central differencing: order=2, accuracy=2 """
32+
33+
_func, _xvar = np.array(_func), np.array(_xvar)
34+
35+
_derivative = (
36+
_func[2:] - 2.0*_func[1:-1] + _func[:-2]
37+
)/((_xvar[2:] - _xvar[:-2])/2.0)**2
38+
39+
return np.pad(
40+
_derivative,
41+
(1, 1),
42+
mode='constant',
43+
constant_values=(
44+
(_func[2] - 2.0*_func[1] + _func[0])/(_xvar[1] - _xvar[0])**2,
45+
(_func[-1] - 2.0*_func[-2] + _func[-3])/(_xvar[-1] - _xvar[-2])**2
46+
),)

hypersolver/lax_friedrichs.py

Lines changed: 28 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,24 @@
11
""" Lax-Friedrics finite-difference scheme """
22

3-
from hypersolver.util import prep_next_step
3+
from hypersolver.util import xnp as np
4+
from hypersolver.util import time_step_util
5+
from hypersolver.derivative import ord1_acc2
6+
7+
8+
def lx_init(init_vals):
9+
""" initialize the array
10+
11+
pad the array with the prescribed first and last values
12+
"""
13+
14+
return np.pad(
15+
init_vals,
16+
(1, 1),
17+
mode="constant",
18+
constant_values=(
19+
0.5 * (init_vals[1] + init_vals[0]),
20+
0.5 * (init_vals[-1] + init_vals[-2])
21+
),)
422

523

624
def lx_next(
@@ -16,7 +34,7 @@ def lx_next(
1634
1735
inputs
1836
------
19-
init_step: n
37+
init_vals: n
2038
vars_vals: x
2139
flux_term: f
2240
sink_term: g
@@ -28,46 +46,16 @@ def lx_next(
2846
2947
numerics
3048
--------
31-
n(j+1, i) = (
32-
0.5 * (n(j,i+1) + n(j+1,i-1)) -
33-
time_step / (x(i+1) - x(i-1)) * (
34-
n(j,i+1) * f(i+1) -
35-
n(j,i-1) * f(i-1)
36-
) +
37-
g(j,i) * time_step
49+
n(j+1, i) = n(j, i) + Δt (g - Δ(fn)/Δx)(j, i)
3850
39-
time_step = (
40-
stability *
41-
(x(i+1) - x(i-1)).min() /
42-
(f(i)).max()
43-
)
51+
Δt ≤ λΔx/f ∀ x
52+
Δ(fn)/Δx is first-order derivative with accuracy of 2
53+
n(j, i) = (n(j, i-1) + n(j, i+1))/2
4454
"""
4555

46-
(time_step, next_vals) = prep_next_step(
47-
stability, vars_vals, flux_term, init_vals)
48-
49-
next_vals[1:-1] = (
50-
0.5 * (init_vals[2:] + init_vals[:-2]) -
51-
1.0 * time_step / (vars_vals[2:] - vars_vals[:-2]) * (
52-
init_vals[2:] * flux_term[2:] -
53-
init_vals[:-2] * flux_term[:-2]
54-
) + sink_term[1:-1] * time_step
55-
)
56-
57-
next_vals[0] = (
58-
0.5 * (init_vals[1] + init_vals[0]) -
59-
1.0 * time_step / (vars_vals[1] - vars_vals[0]) * (
60-
init_vals[1] * flux_term[1] -
61-
init_vals[0] * flux_term[0]
62-
) + sink_term[0] * time_step
63-
)
56+
time_step = time_step_util(vars_vals, flux_term, stability)
6457

65-
next_vals[-1] = (
66-
0.5 * (init_vals[-1] + init_vals[-2]) -
67-
1.0 * time_step / (vars_vals[-1] - vars_vals[-2]) * (
68-
init_vals[-1] * flux_term[-1] -
69-
init_vals[-2] * flux_term[-2]
70-
) + sink_term[-1] * time_step
71-
)
58+
_init_vals = lx_init(0.5 * (init_vals[2:] + init_vals[:-2]))
7259

73-
return next_vals
60+
return _init_vals - time_step * (
61+
ord1_acc2(init_vals*flux_term, vars_vals) - sink_term)

0 commit comments

Comments
 (0)