Skip to content

Commit 3bca993

Browse files
committed
changed integrate_intensity_complex_medium to not squeeze
1 parent a075a30 commit 3bca993

File tree

2 files changed

+18
-16
lines changed

2 files changed

+18
-16
lines changed

pymie/mie.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1368,6 +1368,11 @@ def integrate_intensity_complex_medium(dscat, distance, thetas, k,
13681368
if isinstance(dsigma_2, Quantity):
13691369
sigma_2 = Quantity(sigma_2, dsigma_2.units)
13701370

1371+
# k has trailing axes for theta (and possibly phi) that are no longer
1372+
# needed after the integration. We remove them here
1373+
k_shape = k.shape
1374+
k = k.reshape(num_values)
1375+
13711376
# multiply by factor that accounts for attenuation in the incident light
13721377
# (see Sudiarta and Chylek (2001), eq 10).
13731378
# if the imaginary part of k is close to 0 (because the medium index is
@@ -1381,16 +1386,12 @@ def integrate_intensity_complex_medium(dscat, distance, thetas, k,
13811386
1 / (exponent / (2*distance*k.imag)
13821387
+ (1 - exponent) / (2*distance*k.imag)**2))
13831388

1384-
# prepare for broadcasting (this will add trailing axes of size 1)
1385-
sigma_1 = sigma_1.reshape(factor.shape)
1386-
sigma_2 = sigma_2.reshape(factor.shape)
1387-
13881389
# calculate the averaged sigma
13891390
sigma = (sigma_1 + sigma_2)/2 * factor
13901391

1391-
return(sigma.squeeze(), (sigma_1*factor).squeeze(),
1392-
(sigma_2*factor).squeeze(), (dsigma_1*factor/2).squeeze(),
1393-
(dsigma_2*factor/2).squeeze())
1392+
return(sigma, (sigma_1*factor),
1393+
(sigma_2*factor), (dsigma_1*factor.reshape(k_shape)/2),
1394+
(dsigma_2*factor.reshape(k_shape)/2))
13941395

13951396
def diff_abs_intensity_complex_medium(m, x, thetas, ktd):
13961397
'''

pymie/tests/test_mie_vectorized.py

Lines changed: 10 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -689,9 +689,9 @@ def test_vectorized_angular_functions(self, num_wavelen,
689689
# by trying to assign units to each element, so we have to take
690690
# the magnitudes here
691691
units = integral_loop[0].units
692-
sigma[i] = integral_loop[0].magnitude
693-
sigma_1[i] = integral_loop[1].magnitude
694-
sigma_2[i] = integral_loop[2].magnitude
692+
sigma[i] = integral_loop[0].magnitude.squeeze()
693+
sigma_1[i] = integral_loop[1].magnitude.squeeze()
694+
sigma_2[i] = integral_loop[2].magnitude.squeeze()
695695
dsigma_1[i] = integral_loop[3].magnitude
696696
dsigma_2[i] = integral_loop[4].magnitude
697697

@@ -706,11 +706,11 @@ def test_vectorized_angular_functions(self, num_wavelen,
706706
assert_equal(i12[0], i1)
707707
assert_equal(i12[1], i2)
708708

709-
assert_equal(integral[0].magnitude, sigma.squeeze())
710-
assert_equal(integral[1].magnitude, sigma_1.squeeze())
711-
assert_equal(integral[2].magnitude, sigma_2.squeeze())
712-
assert_equal(integral[3].magnitude, dsigma_1.squeeze())
713-
assert_equal(integral[4].magnitude, dsigma_2.squeeze())
709+
assert_equal(integral[0].magnitude, sigma)
710+
assert_equal(integral[1].magnitude, sigma_1)
711+
assert_equal(integral[2].magnitude, sigma_2)
712+
assert_equal(integral[3].magnitude, dsigma_1)
713+
assert_equal(integral[4].magnitude, dsigma_2)
714714

715715

716716
class TestVectorizedUserFunctions():
@@ -940,7 +940,8 @@ def test_vectorized_calc_reflectance(self, n_medium, num_wavelen,
940940
for i in range(num_wavelen):
941941
reflectance = mie.calc_reflectance(radius, n_medium, n_particle[i],
942942
wavelen[i]).magnitude
943-
refl_loop[i] = reflectance
943+
# squeeze to remove singleton wavelength dimension
944+
refl_loop[i] = reflectance.squeeze()
944945

945946
assert_allclose(refl.magnitude, refl_loop, rtol=1e-14)
946947
assert refl.units == 1/wavelen.units**2

0 commit comments

Comments
 (0)