Skip to content

Commit 9fad8b4

Browse files
committed
generalized broadcasting in mie special functions
1 parent a8efbc8 commit 9fad8b4

File tree

1 file changed

+20
-19
lines changed

1 file changed

+20
-19
lines changed

pymie/mie_specfuncs.py

Lines changed: 20 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -233,15 +233,14 @@ def Qratio(z1, z2, nstop, dns1 = None, dns2 = None,
233233
-----
234234
Logarithmic derivatives calculated automatically if not specified.
235235
236-
Inputs z1 and z2 should be 2d complex arrays with shape [num_values,
237-
num_layers], where num_values could be the number of wavelengths or other
238-
variable.
236+
Inputs z1 and z2 should be complex arrays with shape (..., num_layers),
237+
where "..." means any number of leading dimensions.
239238
240239
Parameters
241240
----------
242-
z1 : array-like with shape [num_values, num_layers]
241+
z1 : array-like with shape (..., num_layers)
243242
m for layer * x for previous layer
244-
z2 : array-like with shape [num_values, num_layers]
243+
z2 : array-like with shape (..., num_layers)
245244
m for layer * x for layer
246245
nstop : integer
247246
maximum order of computation
@@ -252,10 +251,10 @@ def Qratio(z1, z2, nstop, dns1 = None, dns2 = None,
252251
253252
Returns
254253
-------
255-
Qnl : array-like with shape [num_values, num_layers, order]
254+
Qnl : array-like with shape (..., num_layers, order)
256255
Q_n^l for all values (e.g. wavelengths) and layers in z
257256
"""
258-
if (dns1 is None) and (dns2 is None):
257+
if (dns1 is None) or (dns2 is None):
259258
logdersz1 = log_der_13(z1, nstop, eps1, eps2)
260259
logdersz2 = log_der_13(z2, nstop, eps1, eps2)
261260
d1z1 = logdersz1[0]
@@ -275,14 +274,15 @@ def Qratio(z1, z2, nstop, dns1 = None, dns2 = None,
275274
b2 = np.imag(z2)
276275
qns0 = (np.exp(-2.*(b2-b1)) * (np.exp(-1j*2.*a1)-np.exp(-2.*b1))
277276
/ (np.exp(-1j*2.*a2) - np.exp(-2.*b2)))
278-
# shape is [num_values, num_layers, order]
279-
qns0 = qns0[:, :, np.newaxis]
277+
# resulting shape is (..., num_layers, order)
278+
qns0 = qns0[..., np.newaxis]
280279

281280
# Vectorized loop (using np.cumprod) to do upwards recursion in eqn. 33
282-
irange = np.arange(1, nstop+1)
283-
# shape is [num_values, num_layers, order]
284-
i_over_z1 = irange[np.newaxis, np.newaxis, :]/z1[:, :, np.newaxis]
285-
i_over_z2 = irange[np.newaxis, np.newaxis, :]/z2[:, :, np.newaxis]
281+
irange = np.arange(1, nstop+1)[np.newaxis, :]
282+
# irange shape is (1, order) where 1 is for layer axis; Need to add order
283+
# axis to z1, z2 for broadcasting
284+
i_over_z1 = irange / z1[..., np.newaxis]
285+
i_over_z2 = irange / z2[..., np.newaxis]
286286
prod = ((d3z1[..., 1:] + i_over_z1) * (d1z2[..., 1:] + i_over_z2)
287287
/ ((d3z2[..., 1:] + i_over_z2) * (d1z1[..., 1:] + i_over_z1)))
288288
qns = np.concatenate((qns0, qns0 * np.cumprod(prod, axis=-1)), axis=-1)
@@ -306,17 +306,18 @@ def R_psi(z1, z2, nmax, eps1 = DEFAULT_EPS1, eps2 = DEFAULT_EPS2):
306306
307307
See Mackowski eqns. 65-66.
308308
309-
z1, z2 are complex arrays with shape [num_values, 1]
309+
z1, z2 are complex arrays with shape (..., 1)
310310
'''
311311
# Vectorized loop (using np.cumprod) to do up recursion
312-
output_0 = (np.sin(z1) / np.sin(z2))[:, :, np.newaxis]
312+
output_0 = (np.sin(z1) / np.sin(z2))[..., np.newaxis]
313313
dnz1 = dn_1_down(z1, nmax + 1, nmax, lentz_dn1(z1, nmax + 1, eps1, eps2))
314314
dnz2 = dn_1_down(z2, nmax + 1, nmax, lentz_dn1(z2, nmax + 1, eps1, eps2))
315315

316-
irange = np.arange(1, nmax+1)
317-
# shape is [num_values, num_layers, order]
318-
i_over_z1 = irange[np.newaxis, np.newaxis, :]/z1[:, :, np.newaxis]
319-
i_over_z2 = irange[np.newaxis, np.newaxis, :]/z2[:, :, np.newaxis]
316+
irange = np.arange(1, nmax+1)[np.newaxis, :]
317+
# irange shape is (1, order) where 1 is for layer axis; Need to add order
318+
# axis to z1, z2 for broadcasting
319+
i_over_z1 = irange / z1[..., np.newaxis]
320+
i_over_z2 = irange / z2[..., np.newaxis]
320321
prod = (dnz2[..., 1:] + i_over_z2) / (dnz1[..., 1:] + i_over_z1)
321322
output_vec = np.concatenate((output_0,
322323
output_0 * np.cumprod(prod, axis=-1)),

0 commit comments

Comments
 (0)