Skip to content

Commit a1f401a

Browse files
committed
changed calc_g() to not squeeze wavelen axis
1 parent edc888d commit a1f401a

File tree

2 files changed

+11
-16
lines changed

2 files changed

+11
-16
lines changed

pymie/mie.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -206,7 +206,7 @@ def calc_g(m, x, nstop=None):
206206
coeffs = _scatcoeffs(m, x, nstop)
207207

208208
# for multilayer particle, need to scale by the x of the outermost layer
209-
outer_x = np.array(x).max(axis=-1).squeeze()
209+
outer_x = np.array(x).max(axis=-1)
210210
cscat = _cross_sections(coeffs[0], coeffs[1])[0] * 2./outer_x**2
211211
g = ((4./(outer_x**2 * cscat))
212212
* _asymmetry_parameter(coeffs[0], coeffs[1]))

pymie/tests/test_mie_vectorized.py

Lines changed: 10 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -765,21 +765,16 @@ def test_vectorized_asymmetry_parameter(self, num_wavelen, num_layer):
765765
m, x = mx(num_wavelen, num_layer, **self.mxargs)
766766
# make sure shape is [num_wavelen]
767767
g = mie.calc_g(m,x)
768-
if num_wavelen == 1:
769-
expected_shape = ()
770-
assert g.shape == expected_shape
771-
# no further test needed since no loop is required in this case
772-
else:
773-
expected_shape = (num_wavelen,)
774-
assert g.shape == expected_shape
775-
776-
# we should get same values from loop. Need to set nstop to the
777-
# same value as used in the vectorized calculation.
778-
g_loop = np.zeros(expected_shape, dtype=float)
779-
nstop = mie._nstop(x.max())
780-
for i in range(num_wavelen):
781-
g_loop[i] = mie.calc_g(m[i], x[i], nstop=nstop)
782-
assert_equal(g, g_loop)
768+
expected_shape = (num_wavelen,)
769+
assert g.shape == expected_shape
770+
771+
# we should get same values from loop. Need to set nstop to the
772+
# same value as used in the vectorized calculation.
773+
g_loop = np.zeros(expected_shape, dtype=float)
774+
nstop = mie._nstop(x.max())
775+
for i in range(num_wavelen):
776+
g_loop[i] = mie.calc_g(m[i], x[i], nstop=nstop)
777+
assert_equal(g, g_loop)
783778

784779
@pytest.mark.parametrize("num_wavelen, num_layer",
785780
[(10, 1), (1, 5), (10, 5)])

0 commit comments

Comments
 (0)