Skip to content

Commit f88740f

Browse files
committed
changed _internal_coeffs() to not squeeze wavelen axis
1 parent e8b588b commit f88740f

File tree

2 files changed

+11
-14
lines changed

2 files changed

+11
-14
lines changed

pymie/mie.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -612,7 +612,9 @@ def _internal_coeffs(m, x, n_max, eps1 = DEFAULT_EPS1, eps2 = DEFAULT_EPS2):
612612
dl = (m[..., np.newaxis] * ratio * (D3x - D1x)
613613
/ (m[..., np.newaxis] * D3x - D1mx))
614614
# start from l = 1
615-
return np.array([cl[..., 1:], dl[..., 1:]]).squeeze()
615+
cldl = np.array([cl[..., 1:], dl[..., 1:]])
616+
# remove unneeded layer axis
617+
return cldl[..., 0, :]
616618

617619
def _trans_coeffs(m, x, n_max, eps1 = DEFAULT_EPS1, eps2 = DEFAULT_EPS2):
618620
'''

pymie/tests/test_mie_vectorized.py

Lines changed: 8 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -468,21 +468,16 @@ def test_vectorized_internal_coeffs(self, num_wavelen, num_layer):
468468
coeffs = mie._internal_coeffs(m, x, nstop)
469469

470470
# make sure shape is correct
471-
if num_wavelen == 1:
472-
expected_shape = (2, nstop)
473-
assert coeffs.shape == expected_shape
474-
# no further test since no loop required in this case
475-
else:
476-
expected_shape = (2, num_wavelen, nstop)
477-
assert coeffs.shape == expected_shape
471+
expected_shape = (2, num_wavelen, nstop)
472+
assert coeffs.shape == expected_shape
478473

479-
# we should get same values from loop
480-
coeffs_loop = np.zeros(expected_shape, dtype=complex)
481-
for i in range(m.shape[0]):
482-
c = mie._internal_coeffs(m[i], x[i], nstop)
483-
coeffs_loop[:, i] = c
474+
# we should get same values from loop
475+
coeffs_loop = []
476+
for i in range(m.shape[0]):
477+
coeffs_loop.append(mie._internal_coeffs(m[i], x[i], nstop))
478+
coeffs_loop = np.concatenate(coeffs_loop, axis=1)
484479

485-
assert_equal(coeffs, coeffs_loop)
480+
assert_equal(coeffs, coeffs_loop)
486481

487482
@pytest.mark.parametrize("n_matrix",
488483
[1.33, pytest.param(1.33+0.001j),

0 commit comments

Comments
 (0)