Skip to content

Commit 618ae7c

Browse files
committed
allow non-square matrix-valued functions in convolve function
1 parent b1d3e24 commit 618ae7c

1 file changed

Lines changed: 9 additions & 8 deletions

File tree

c++/cppdlr/dlr_imtime.hpp

Lines changed: 9 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -307,7 +307,6 @@ namespace cppdlr {
307307
typename T::regular_type convolve(double beta, statistic_t statistic, T const &fc, T const &gc, bool time_order = false) const {
308308

309309
if (r != fc.shape(0) || r != gc.shape(0)) throw std::runtime_error("First dim of input arrays must be equal to DLR rank r.");
310-
if (fc.shape() != gc.shape()) throw std::runtime_error("Input arrays must have the same shape.");
311310

312311
// TODO: implement bosonic case and remove
313312
if (statistic == 0) throw std::runtime_error("imtime_ops::convolve not yet implemented for bosonic Green's functions.");
@@ -335,17 +334,19 @@ namespace cppdlr {
335334

336335
} else if (T::rank == 3) { // Matrix-valued Green's function
337336

337+
if (fc.shape(2) != gc.shape(1)) throw std::runtime_error("Input array dimensions incompatible.");
338+
338339
// Diagonal contribution
339-
auto fcgc = nda::array<S, 3>(fc.shape()); // Product of coefficients of f and g
340-
for (int i = 0; i < r; ++i) { fcgc(i, _, _) = matmul(fc(i, _, _), gc(i, _, _)); }
341-
auto h = arraymult(tcf2it_v, fcgc);
340+
auto tmp = nda::array<S, 3>(r, fc.shape(1), gc.shape(2)); // m x p temporary array
341+
for (int i = 0; i < r; ++i) { tmp(i, _, _) = matmul(fc(i, _, _), gc(i, _, _)); } // Product of coefficients of f and g
342+
auto h = arraymult(tcf2it_v, tmp);
342343

343344
// Off-diagonal contribution
344-
auto tmp1 = arraymult(hilb_v, fc);
345-
auto tmp2 = arraymult(hilb_v, gc);
346-
for (int i = 0; i < r; ++i) { tmp1(i, _, _) = matmul(tmp1(i, _, _), gc(i, _, _)) + matmul(fc(i, _, _), tmp2(i, _, _)); }
345+
for (int i = 0; i < r; ++i) {
346+
tmp(i, _, _) = matmul(arraymult(hilb_v, fc)(i, _, _), gc(i, _, _)) + matmul(fc(i, _, _), arraymult(hilb_v, gc)(i, _, _));
347+
}
347348

348-
return beta * (h + arraymult(cf2it, tmp1));
349+
return beta * (h + arraymult(cf2it, tmp));
349350

350351
} else {
351352
throw std::runtime_error("Input arrays must be rank 1 (scalar-valued Green's function) or 3 (matrix-valued Green's function).");

0 commit comments

Comments
 (0)