Skip to content

Commit 77a9c02

Browse files
authored
v0.0.4 (#7)
1 parent 150e33c commit 77a9c02

12 files changed

Lines changed: 158 additions & 132 deletions

hypersolver/__init__.py

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
hypersolver revolves around solving hyperbolic
44
partial differential equations (PDEs) of the form
55
6-
∂n/∂t + ∂(fn)/∂x = ∂n/∂t + f ∂n/∂x - n ∂f/∂x = g
6+
∂n/∂t + ∂(fn)/∂x = ∂n/∂t + f ∂n/∂x + n ∂f/∂x = g
77
88
where
99
@@ -23,7 +23,8 @@
2323
2424
available methods:
2525
- "lax_friedrichs" (default)
26-
- "lax_wendroff"
26+
- "lax_wendroff" (still unstable, wip)
27+
- "method_of_characteristics" (experimental)
2728
2829
"""
2930

@@ -36,11 +37,11 @@
3637
"method_of_characteristics",
3738
]
3839

39-
from hypersolver.basic_solver import solver
40+
from hypersolver.step_solver import solver_
4041

4142

42-
def select_solver(method="lax_friedrichs"):
43+
def solver(*args, method="lax_friedrichs", **kwargs):
4344
""" wrapper function to select solvers """
4445
if method not in __hyper_solvers__:
4546
raise ValueError("method not supported")
46-
return solver(method)
47+
return solver_(*args, method=method, **kwargs)

hypersolver/accurate_derivative.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,20 +3,26 @@
33
import numpy as np
44
from scipy.misc import central_diff_weights
55

6+
weights = np.zeros((4, 4))
7+
weights[0, -1:] = np.array([-1/2])
8+
weights[1, -2:] = np.array([1/12, -2/3])
9+
weights[2, -3:] = np.array([-1/60, 3/20, -3/4])
10+
weights[3, -4:] = np.array([1/280, -4/105, 1/5, -4/5])
11+
612

713
def _derivative(_func, _xvar, _nacc):
814
""" 1st derivative following central finite differencing
915
"""
1016
derivative = np.zeros_like(_xvar)
1117
derivative[0] = (_func[1] - _func[0])/(_xvar[1] - _xvar[0])
1218
derivative[-1] = (_func[-1] - _func[-2])/(_xvar[-1] - _xvar[-2])
13-
weights = central_diff_weights(_nacc+1)
19+
_weights = central_diff_weights(_nacc+1)
1420

1521
for nac, idx in zip(range(2, _nacc + 1, 2), range(_nacc//2-1, -1, -1)):
1622
derivative[nac//2:-nac//2] += (
17-
weights[idx]*_func[:-nac] /
23+
_weights[idx]*_func[:-nac] /
1824
((_xvar[nac:] - _xvar[:-nac]) / float(nac)) +
19-
weights[-idx-1]*_func[nac:] /
25+
_weights[-idx-1]*_func[nac:] /
2026
((_xvar[nac:] - _xvar[:-nac]) / float(nac)))
2127

2228
if _nacc > 2:
@@ -26,6 +32,30 @@ def _derivative(_func, _xvar, _nacc):
2632
return derivative
2733

2834

35+
def acc4_derivative(_func, _xvar, _nacc=4):
36+
""" 1st derivative, central differencing, accuracy up to 4
37+
"""
38+
derivative = np.zeros_like(_xvar)
39+
derivative[0] = (_func[1] - _func[0])/(_xvar[1] - _xvar[0])
40+
derivative[-1] = (_func[-1] - _func[-2])/(_xvar[-1] - _xvar[-2])
41+
42+
for nac, idx in zip(range(2, _nacc + 1, 2), range(_nacc//2)):
43+
44+
derivative[nac//2] += (
45+
weights[nac//2-1, -idx-1]*(_func[0] - _func[nac]) /
46+
((_xvar[nac] - _xvar[0]) / float(nac)))
47+
48+
derivative[-nac//2-1] += (
49+
weights[nac//2-1, -idx-1]*(_func[-nac-1] - _func[-1]) /
50+
((_xvar[-1] - _xvar[-nac-1]) / float(nac)))
51+
52+
derivative[nac//2+1:-nac//2-1] += (
53+
weights[_nacc//2-1, -idx-1]*(_func[1:-nac-1] - _func[nac+1:-1]) /
54+
((_xvar[nac+1:-1] - _xvar[1:-nac-1]) / float(nac)))
55+
56+
return derivative
57+
58+
2959
def acc_derivative(func, xvar, nacc):
3060
""" 1st derivative of a _func wrt _xvar with _nacc accuracy
3161
@@ -40,15 +70,15 @@ def acc_derivative(func, xvar, nacc):
4070
NOTES:
4171
minor bug necessitates repeating calculation if nacc > 2
4272
"""
43-
if nacc == 0 or nacc % 2 == 1:
73+
if nacc <= 0 or nacc % 2 == 1:
4474
raise ValueError("n must be positive even")
4575

76+
if nacc < 6:
77+
return acc4_derivative(func, xvar, nacc)
78+
4679
axx_derivative = np.zeros_like(xvar)
4780
axx_derivative = _derivative(func, xvar, 2)
4881

49-
if nacc == 2:
50-
return axx_derivative
51-
5282
for nax in range(4, nacc + 1, 2):
5383
derivative = _derivative(func, xvar, nax)
5484
nan_idx = np.isnan(derivative)

hypersolver/lax_wendroff.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,16 @@ def lw_next(
2828
2929
numerics
3030
--------
31+
n(j+1, i) = (
32+
n(j,i) +
33+
time_step * ( -(fn)x + g ) + # first term (order)
34+
0.5 * (time_step)**2 * (
35+
-fx (-(fn)x + g) - f(-(fn)xx + gx) + gt) # second term (order)
36+
)
37+
)
38+
39+
# FIXME: wrong implementation for now
40+
# TODO: to be revisited
3141
n(j+1, i) = (
3242
n(j,i) -
3343
time_step / (x(i+1) - x(i-1)) * (
@@ -53,9 +63,9 @@ def lw_next(
5363
1.0 * time_step / (vars_vals[2:] - vars_vals[:-2]) * (
5464
init_vals[2:] * flux_term[2:] -
5565
init_vals[:-2] * flux_term[:-2]
56-
) + (time_step / (vars_vals[2:] - vars_vals[:-2])/2)**2 * (
66+
) + 0.5 * (time_step / ((vars_vals[2:] - vars_vals[:-2])/2.0))**2.0 * (
5767
init_vals[:-2] * flux_term[:-2] -
58-
2 * init_vals[1:-1] * flux_term[1:-1] +
68+
2.0 * init_vals[1:-1] * flux_term[1:-1] +
5969
init_vals[2:] * flux_term[2:]
6070
) + sink_term[1:-1] * time_step
6171
)
@@ -65,7 +75,7 @@ def lw_next(
6575
1.0 * time_step / (vars_vals[1] - vars_vals[0]) * (
6676
init_vals[1] * flux_term[1] -
6777
init_vals[0] * flux_term[0]
68-
) + (time_step / (vars_vals[1] - vars_vals[0])/1)**2 * (
78+
) + 0.5 * ((time_step / (vars_vals[1] - vars_vals[0])/1.0))**2.0 * (
6979
0.0 * init_vals[0] * flux_term[0] -
7080
2.0 * init_vals[0] * flux_term[0] +
7181
1.0 * init_vals[1] * flux_term[1]
@@ -77,7 +87,7 @@ def lw_next(
7787
1.0 * time_step / (vars_vals[-1] - vars_vals[-2]) * (
7888
init_vals[-1] * flux_term[-1] -
7989
init_vals[-2] * flux_term[-2]
80-
) + (time_step / (vars_vals[-1] - vars_vals[-2])/1)**2 * (
90+
) + 0.5 * ((time_step / (vars_vals[-1] - vars_vals[-2])/1.0))**2.0 * (
8191
1.0 * init_vals[-2] * flux_term[-2] -
8292
2.0 * init_vals[-1] * flux_term[-1] +
8393
0.0 * init_vals[-1] * flux_term[-1]

hypersolver/method_of_characteristics.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -2,17 +2,17 @@
22
"""
33
import numpy as np
44
from scipy.integrate import odeint
5+
from scipy.interpolate import interp1d
56

67
from hypersolver.accurate_derivative import acc_derivative
78

89

9-
def moc(
10+
def moc_next(
1011
init_vals,
1112
vars_vals,
12-
time_span,
1313
flux_term,
1414
sink_term,
15-
**kwargs
15+
time_step,
1616
):
1717
""" method of characteristics
1818
@@ -29,20 +29,18 @@ def moc(
2929
------
3030
init_step: n
3131
vars_vals: x
32-
time_span:
3332
flux_term: f
3433
sink_term: g
34+
time_step:
3535
3636
outputs
3737
-------
38-
next_vars: x
3938
next_vals: n
4039
4140
numerics
4241
--------
43-
use from scipy.integrate.odeint
42+
use scipy.integrate.odeint
4443
"""
45-
_ = kwargs
4644

4745
# pylint: disable=unused-argument
4846
# pylint: disable=unused-variable
@@ -66,7 +64,14 @@ def _func(yval, tval):
6664
yval0 = np.empty((vars_vals.size + init_vals.size))
6765
yval0[::2] = vars_vals
6866
yval0[1::2] = init_vals
69-
tspan = np.linspace(time_span[0], time_span[-1], 100)
67+
tspan = np.linspace(0, time_step, 10)
7068
results = odeint(_func, yval0, tspan, ml=2, mu=2)
7169

72-
return (results[:, ::2], results[:, 1::2])
70+
fill = interp1d(
71+
results[-1, ::2],
72+
results[-1, 1::2],
73+
fill_value=(0.0, 0.0),
74+
bounds_error=False,
75+
kind='cubic')
76+
77+
return fill(vars_vals)
Lines changed: 18 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -1,23 +1,25 @@
11
""" shared solver between schemes
22
"""
3+
34
import numpy as np
45

56
from hypersolver.util import term_util
67
from hypersolver.lax_friedrichs import lx_next
78
from hypersolver.lax_wendroff import lw_next
8-
from hypersolver.method_of_characteristics import moc
9+
from hypersolver.method_of_characteristics import moc_next
910

1011

11-
def solver(method): # noqa: C901
12+
def solver_(*args, **kwargs):
1213
""" set the solver """
14+
method = kwargs.get("method", "lax_friedrichs")
1315
if method == "lax_friedrichs":
1416
next_step = lx_next
15-
if method == "lax_wendroff":
17+
elif method == "lax_wendroff":
1618
next_step = lw_next
17-
if method == "method_of_characteristics":
18-
return moc
19+
else:
20+
next_step = moc_next
1921

20-
def _solver( # pylint: disable=too-many-arguments
22+
def _solver_(
2123
init_vals,
2224
vars_vals,
2325
time_span,
@@ -39,11 +41,15 @@ def _solver( # pylint: disable=too-many-arguments
3941
flux_term = term_util(flux_term, init_vals)
4042
sink_term = term_util(sink_term, init_vals)
4143

42-
time_step = (
43-
stability_factor *
44-
np.diff(vars_vals).min() /
45-
np.abs(flux_term).max()
46-
)
44+
stability_factor, time_step = (stability_factor,
45+
stability_factor *
46+
np.diff(vars_vals).min() /
47+
np.abs(flux_term).max()
48+
) if method in [
49+
"lax_friedrichs", "lax_wendroff"
50+
] else (
51+
np.array((time_span[-1] - time_span[0])/5.0),
52+
np.array((time_span[-1] - time_span[0])/5.0))
4753

4854
tidx = np.arange(time_span[0], time_span[-1]+time_step, time_step)
4955
sols = np.zeros((tidx.size, init_vals.size))
@@ -74,4 +80,4 @@ def _solver( # pylint: disable=too-many-arguments
7480

7581
return sols
7682

77-
return _solver
83+
return _solver_(*args, **kwargs)

hypersolver/tests/lax_friedrichs_test.py

Lines changed: 0 additions & 36 deletions
This file was deleted.

hypersolver/tests/lax_wendroff_test.py

Lines changed: 0 additions & 37 deletions
This file was deleted.

hypersolver/tests/method_of_characteristics_test.py

Lines changed: 0 additions & 20 deletions
This file was deleted.

0 commit comments

Comments
 (0)