@@ -233,15 +233,14 @@ def Qratio(z1, z2, nstop, dns1 = None, dns2 = None,
233233 -----
234234 Logarithmic derivatives calculated automatically if not specified.
235235
236- Inputs z1 and z2 should be 2d complex arrays with shape [num_values,
237- num_layers], where num_values could be the number of wavelengths or other
238- variable.
236+ Inputs z1 and z2 should be complex arrays with shape (..., num_layers),
237+ where "..." means any number of leading dimensions.
239238
240239 Parameters
241240 ----------
242- z1 : array-like with shape [num_values , num_layers]
241+ z1 : array-like with shape (... , num_layers)
243242 m for layer * x for previous layer
244- z2 : array-like with shape [num_values , num_layers]
243+ z2 : array-like with shape (... , num_layers)
245244 m for layer * x for layer
246245 nstop : integer
247246 maximum order of computation
@@ -252,10 +251,10 @@ def Qratio(z1, z2, nstop, dns1 = None, dns2 = None,
252251
253252 Returns
254253 -------
255- Qnl : array-like with shape [num_values , num_layers, order]
254+ Qnl : array-like with shape (... , num_layers, order)
256255 Q_n^l for all values (e.g. wavelengths) and layers in z
257256 """
258- if (dns1 is None ) and (dns2 is None ):
257+ if (dns1 is None ) or (dns2 is None ):
259258 logdersz1 = log_der_13 (z1 , nstop , eps1 , eps2 )
260259 logdersz2 = log_der_13 (z2 , nstop , eps1 , eps2 )
261260 d1z1 = logdersz1 [0 ]
@@ -275,14 +274,15 @@ def Qratio(z1, z2, nstop, dns1 = None, dns2 = None,
275274 b2 = np .imag (z2 )
276275 qns0 = (np .exp (- 2. * (b2 - b1 )) * (np .exp (- 1j * 2. * a1 )- np .exp (- 2. * b1 ))
277276 / (np .exp (- 1j * 2. * a2 ) - np .exp (- 2. * b2 )))
278- # shape is [num_values , num_layers, order]
279- qns0 = qns0 [:, : , np .newaxis ]
277+ # resulting shape is (... , num_layers, order)
278+ qns0 = qns0 [... , np .newaxis ]
280279
281280 # Vectorized loop (using np.cumprod) to do upwards recursion in eqn. 33
282- irange = np .arange (1 , nstop + 1 )
283- # shape is [num_values, num_layers, order]
284- i_over_z1 = irange [np .newaxis , np .newaxis , :]/ z1 [:, :, np .newaxis ]
285- i_over_z2 = irange [np .newaxis , np .newaxis , :]/ z2 [:, :, np .newaxis ]
281+ irange = np .arange (1 , nstop + 1 )[np .newaxis , :]
282+ # irange shape is (1, order) where 1 is for layer axis; Need to add order
283+ # axis to z1, z2 for broadcasting
284+ i_over_z1 = irange / z1 [..., np .newaxis ]
285+ i_over_z2 = irange / z2 [..., np .newaxis ]
286286 prod = ((d3z1 [..., 1 :] + i_over_z1 ) * (d1z2 [..., 1 :] + i_over_z2 )
287287 / ((d3z2 [..., 1 :] + i_over_z2 ) * (d1z1 [..., 1 :] + i_over_z1 )))
288288 qns = np .concatenate ((qns0 , qns0 * np .cumprod (prod , axis = - 1 )), axis = - 1 )
@@ -306,17 +306,18 @@ def R_psi(z1, z2, nmax, eps1 = DEFAULT_EPS1, eps2 = DEFAULT_EPS2):
306306
307307 See Mackowski eqns. 65-66.
308308
309- z1, z2 are complex arrays with shape [num_values , 1]
309+ z1, z2 are complex arrays with shape (... , 1)
310310 '''
311311 # Vectorized loop (using np.cumprod) to do up recursion
312- output_0 = (np .sin (z1 ) / np .sin (z2 ))[:, : , np .newaxis ]
312+ output_0 = (np .sin (z1 ) / np .sin (z2 ))[... , np .newaxis ]
313313 dnz1 = dn_1_down (z1 , nmax + 1 , nmax , lentz_dn1 (z1 , nmax + 1 , eps1 , eps2 ))
314314 dnz2 = dn_1_down (z2 , nmax + 1 , nmax , lentz_dn1 (z2 , nmax + 1 , eps1 , eps2 ))
315315
316- irange = np .arange (1 , nmax + 1 )
317- # shape is [num_values, num_layers, order]
318- i_over_z1 = irange [np .newaxis , np .newaxis , :]/ z1 [:, :, np .newaxis ]
319- i_over_z2 = irange [np .newaxis , np .newaxis , :]/ z2 [:, :, np .newaxis ]
316+ irange = np .arange (1 , nmax + 1 )[np .newaxis , :]
317+ # irange shape is (1, order) where 1 is for layer axis; Need to add order
318+ # axis to z1, z2 for broadcasting
319+ i_over_z1 = irange / z1 [..., np .newaxis ]
320+ i_over_z2 = irange / z2 [..., np .newaxis ]
320321 prod = (dnz2 [..., 1 :] + i_over_z2 ) / (dnz1 [..., 1 :] + i_over_z1 )
321322 output_vec = np .concatenate ((output_0 ,
322323 output_0 * np .cumprod (prod , axis = - 1 )),
0 commit comments