For long-term maintenance simplicity, consider replacing the custom CUDA kernels redrock.zscan.batch_dot_product_3d3d and batch_dot_product_3d2d with einsum magic, as suggested by @dmargala.
For example, batched A.T.dot(A) and A.T.dot(b) would be:
cp.einsum("...ji,...jk", A, A)
cp.einsum("...ji,...j", A, b)
Those aren't drop-in replacements for the call signature of batch_dot_product_3d3d, but I think we only use it for that A.T.dot(A) purpose. Profile the einsum versions against the current implementation and also check them for correctness.
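A minimal sketch of the correctness check, assuming NumPy as a CPU stand-in (cp.einsum accepts the same subscript strings, so swapping np for cp should give the GPU version). The helper names batch_ata and batch_atb and the array shapes are hypothetical, chosen only for illustration; they are not existing redrock functions.

```python
import numpy as np

def batch_ata(A):
    """Batched A.T.dot(A): shape (..., n, m) -> (..., m, m). Hypothetical helper."""
    return np.einsum("...ji,...jk", A, A)

def batch_atb(A, b):
    """Batched A.T.dot(b): shapes (..., n, m) and (..., n) -> (..., m). Hypothetical helper."""
    return np.einsum("...ji,...j", A, b)

# Arbitrary test shapes: a batch of 5 matrices, each 7 x 3.
rng = np.random.default_rng(0)
A = rng.normal(size=(5, 7, 3))
b = rng.normal(size=(5, 7))

# Reference results from an explicit loop over the batch dimension.
ata_ref = np.stack([A[i].T.dot(A[i]) for i in range(A.shape[0])])
atb_ref = np.stack([A[i].T.dot(b[i]) for i in range(A.shape[0])])

assert np.allclose(batch_ata(A), ata_ref)
assert np.allclose(batch_atb(A, b), atb_ref)
```

The same comparison could be run against the outputs of the existing kernels, and the two implementations timed on realistically sized inputs, before switching over.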
Also consider moving functions like this into redrock.utils or a separate redrock.linalg or similar module instead of zscan.