diff --git a/perfbench/perfbench_ddbta_cpu.py b/perfbench/perfbench_ddbta_cpu.py new file mode 100644 index 00000000..ff4e0fcb --- /dev/null +++ b/perfbench/perfbench_ddbta_cpu.py @@ -0,0 +1,309 @@ +import time + +tic = time.perf_counter() +import numpy as np +import argparse + +from serinv.utils.check_dd import check_ddbta +from serinv.algs import ddbtasc, ddbtasci + +def generate_dataset( + n_blocks: int, + diagonal_blocksize: int, + arrowhead_blocksize: int, + bsym: bool, + quadratic: bool, + dtype=np.float64, +): + print(f"Generating sequential quadratic dataset...", flush=True) + print(f" - Generating A", flush=True) + rc = (1.0 + 1.0j) if dtype == np.complex128 else 1.0 + A_diagonal_blocks = rc * np.random.rand( + n_blocks, diagonal_blocksize, diagonal_blocksize + ) + A_lower_diagonal_blocks = rc * np.random.rand( + n_blocks - 1, diagonal_blocksize, diagonal_blocksize + ) + A_upper_diagonal_blocks = rc * np.random.rand( + n_blocks - 1, diagonal_blocksize, diagonal_blocksize + ) + # A arrowhead part + A_lower_arrow_blocks = rc * np.random.rand( + n_blocks, arrowhead_blocksize, diagonal_blocksize + ) + A_upper_arrow_blocks = rc * np.random.rand( + n_blocks, + diagonal_blocksize, + arrowhead_blocksize, + ) + A_arrow_tip_block = rc * np.random.rand( + arrowhead_blocksize, arrowhead_blocksize + ) + arrow_colsum = np.zeros((arrowhead_blocksize), dtype=A_diagonal_blocks.dtype) + for i in range(A_diagonal_blocks.shape[0]): + colsum = np.sum(A_diagonal_blocks[i], axis=1) - np.diag( + A_diagonal_blocks[i] + ) + if i > 0: + colsum += np.sum(A_lower_diagonal_blocks[i - 1], axis=1) + if i < n_blocks - 1: + colsum += np.sum(A_upper_diagonal_blocks[i], axis=1) + colsum += np.sum(A_upper_arrow_blocks[i], axis=1) + A_diagonal_blocks[i] += np.diag(colsum) + arrow_colsum[:] += np.sum(A_lower_arrow_blocks[i], axis=1) + A_arrow_tip_block[:, :] += np.diag( + arrow_colsum + np.sum(A_arrow_tip_block[:, :], axis=1) + ) + A_ddbta = check_ddbta( + A_diagonal_blocks, + 
A_lower_diagonal_blocks, + A_upper_diagonal_blocks, + A_lower_arrow_blocks, + A_upper_arrow_blocks, + A_arrow_tip_block, + ) + + A = { + "A_diagonal_blocks": A_diagonal_blocks, + "A_lower_diagonal_blocks": A_lower_diagonal_blocks, + "A_upper_diagonal_blocks": A_upper_diagonal_blocks, + "A_lower_arrow_blocks": A_lower_arrow_blocks, + "A_upper_arrow_blocks": A_upper_arrow_blocks, + "A_arrow_tip_block": A_arrow_tip_block, + } + + if quadratic: + print(f" - Generating B (Quadratic Equation)", flush=True) + B_diagonal_blocks = rc * np.random.rand( + n_blocks, diagonal_blocksize, diagonal_blocksize + ) + B_lower_diagonal_blocks = rc * np.random.rand( + n_blocks - 1, diagonal_blocksize, diagonal_blocksize + ) + B_upper_diagonal_blocks = rc * np.random.rand( + n_blocks - 1, diagonal_blocksize, diagonal_blocksize + ) + # B arrowhead part + B_lower_arrow_blocks = rc * np.random.rand( + n_blocks, arrowhead_blocksize, diagonal_blocksize + ) + B_upper_arrow_blocks = rc * np.random.rand( + n_blocks, + diagonal_blocksize, + arrowhead_blocksize, + ) + B_arrow_tip_block = rc * np.random.rand( + arrowhead_blocksize, arrowhead_blocksize + ) + arrow_colsum = np.zeros((arrowhead_blocksize), dtype=B_diagonal_blocks.dtype) + for i in range(B_diagonal_blocks.shape[0]): + colsum = np.sum(B_diagonal_blocks[i], axis=1) - np.diag( + B_diagonal_blocks[i] + ) + if i > 0: + colsum += np.sum(B_lower_diagonal_blocks[i - 1], axis=1) + if i < n_blocks - 1: + colsum += np.sum(B_upper_diagonal_blocks[i], axis=1) + colsum += np.sum(B_upper_arrow_blocks[i], axis=1) + B_diagonal_blocks[i] += np.diag(colsum) + arrow_colsum[:] += np.sum(B_lower_arrow_blocks[i], axis=1) + B_arrow_tip_block[:, :] += np.diag( + arrow_colsum + np.sum(B_arrow_tip_block[:, :], axis=1) + ) + + B_ddbta = check_ddbta( + B_diagonal_blocks, + B_lower_diagonal_blocks, + B_upper_diagonal_blocks, + B_lower_arrow_blocks, + B_upper_arrow_blocks, + B_arrow_tip_block, + ) + + if bsym: + for i in range(n_blocks): + B_diagonal_blocks[i] = ( + 
B_diagonal_blocks[i] + B_diagonal_blocks[i].conj().T + ) / 2 + if i < n_blocks - 1: + B_upper_diagonal_blocks[i] = B_lower_diagonal_blocks[i].conj().T + B_upper_arrow_blocks[i] = B_lower_arrow_blocks[i].conj().T + B_arrow_tip_block = (B_arrow_tip_block + B_arrow_tip_block.conj().T) / 2 + + B = { + "B_diagonal_blocks": B_diagonal_blocks, + "B_lower_diagonal_blocks": B_lower_diagonal_blocks, + "B_upper_diagonal_blocks": B_upper_diagonal_blocks, + "B_lower_arrow_blocks": B_lower_arrow_blocks, + "B_upper_arrow_blocks": B_upper_arrow_blocks, + "B_arrow_tip_block": B_arrow_tip_block, + } + else: + B_ddbta = True + + B = None + + + if np.all(A_ddbta) and np.all(B_ddbta): + print("All rows are diagonally dominant!", flush=True) + else: + raise ValueError("Some rows are not diagonally dominant!") + + + + return A, B + +if __name__ == "__main__": + parser = argparse.ArgumentParser(description="Process some integers.") + parser.add_argument( + "--b", + type=int, + default=128, + help="an integer for the diagonal block size", + ) + parser.add_argument( + "--a", + type=int, + default=0, + help="an integer for the arrowhead block size", + ) + parser.add_argument( + "--n", + type=int, + default=8, + help="an integer for the number of diagonal blocks", + ) + parser.add_argument( + "--bsym", + type=lambda s: s.strip().lower() in ("1", "true", "yes"), # type=bool would map every non-empty string, even "False", to True + default=True, + help="whether to make B block-symmetric or not (true/false, 1/0, yes/no)", + ) + parser.add_argument( + "--q", + type=int, + help="whether to run the quadratic variant (0 = no; omitted or any other integer = yes)", + ) + args = parser.parse_args() + toc = time.perf_counter() + print(f"Import and parsing took: {toc - tic:.5f} sec", flush=True) + + quadratic = False if args.q == 0 else True + n_iterations = 10 + n_warmups = 2 + + tic = time.perf_counter() + A, B = generate_dataset( + n_blocks = args.n, + diagonal_blocksize = args.b, + arrowhead_blocksize = args.a, + bsym = args.bsym, + quadratic = quadratic, + ) + toc = time.perf_counter() + print(f"Dataset generation took: {toc - tic:.5f} sec", flush=True) + + A_diagonal_blocks_init = 
A["A_diagonal_blocks"] + A_lower_diagonal_blocks_init = A["A_lower_diagonal_blocks"] + A_upper_diagonal_blocks_init = A["A_upper_diagonal_blocks"] + A_lower_arrow_blocks_init = A["A_lower_arrow_blocks"] + A_upper_arrow_blocks_init = A["A_upper_arrow_blocks"] + A_arrow_tip_block_init = A["A_arrow_tip_block"] + + # Init device arrays + A_diagonal_blocks_cpu = np.empty_like(A_diagonal_blocks_init) + A_lower_diagonal_blocks_cpu = np.empty_like(A_lower_diagonal_blocks_init) + A_upper_diagonal_blocks_cpu = np.empty_like(A_upper_diagonal_blocks_init) + A_lower_arrow_blocks_cpu = np.empty_like(A_lower_arrow_blocks_init) + A_upper_arrow_blocks_cpu = np.empty_like(A_upper_arrow_blocks_init) + A_arrow_tip_block_cpu = np.empty_like(A_arrow_tip_block_init) + + if quadratic: + B_diagonal_blocks_init = B["B_diagonal_blocks"] + B_lower_diagonal_blocks_init = B["B_lower_diagonal_blocks"] + B_upper_diagonal_blocks_init = B["B_upper_diagonal_blocks"] + B_lower_arrow_blocks_init = B["B_lower_arrow_blocks"] + B_upper_arrow_blocks_init = B["B_upper_arrow_blocks"] + B_arrow_tip_block_init = B["B_arrow_tip_block"] + + # Init device arrays + B_diagonal_blocks_cpu = np.empty_like(B_diagonal_blocks_init) + B_lower_diagonal_blocks_cpu = np.empty_like(B_lower_diagonal_blocks_init) + B_upper_diagonal_blocks_cpu = np.empty_like(B_upper_diagonal_blocks_init) + B_lower_arrow_blocks_cpu = np.empty_like(B_lower_arrow_blocks_init) + B_upper_arrow_blocks_cpu = np.empty_like(B_upper_arrow_blocks_init) + B_arrow_tip_block_cpu = np.empty_like(B_arrow_tip_block_init) + + t_ddbtasc = [] + t_ddbtasci = [] + + for i in range(n_warmups + n_iterations): + print(f"Iteration: {i+1}/{n_warmups+n_iterations}", flush=True) + + tic = time.perf_counter() + A_diagonal_blocks_cpu[:] = A_diagonal_blocks_init + A_lower_diagonal_blocks_cpu[:] = A_lower_diagonal_blocks_init + A_upper_diagonal_blocks_cpu[:] = A_upper_diagonal_blocks_init + A_lower_arrow_blocks_cpu[:] = A_lower_arrow_blocks_init + A_upper_arrow_blocks_cpu[:] 
= A_upper_arrow_blocks_init + A_arrow_tip_block_cpu[:] = A_arrow_tip_block_init + + if quadratic: + B_diagonal_blocks_cpu[:] = B_diagonal_blocks_init + B_lower_diagonal_blocks_cpu[:] = B_lower_diagonal_blocks_init + B_upper_diagonal_blocks_cpu[:] = B_upper_diagonal_blocks_init + B_lower_arrow_blocks_cpu[:] = B_lower_arrow_blocks_init + B_upper_arrow_blocks_cpu[:] = B_upper_arrow_blocks_init + B_arrow_tip_block_cpu[:] = B_arrow_tip_block_init + rhs = { + "B_diagonal_blocks": B_diagonal_blocks_cpu, + "B_lower_diagonal_blocks": B_lower_diagonal_blocks_cpu, + "B_upper_diagonal_blocks": B_upper_diagonal_blocks_cpu, + "B_lower_arrow_blocks": B_lower_arrow_blocks_cpu, + "B_upper_arrow_blocks": B_upper_arrow_blocks_cpu, + "B_arrow_tip_block": B_arrow_tip_block_cpu, + } + + toc = time.perf_counter() + print(f"Copying data from ref took: {toc - tic:.5f} sec", flush=True) + + tic = time.perf_counter() + ddbtasc( + A_diagonal_blocks_cpu, + A_lower_diagonal_blocks_cpu, + A_upper_diagonal_blocks_cpu, + A_lower_arrow_blocks_cpu, + A_upper_arrow_blocks_cpu, + A_arrow_tip_block_cpu, + rhs=rhs if quadratic else None, + quadratic=quadratic, + ) + toc = time.perf_counter() + elapsed = toc - tic + print(f"t_ddbtasc took: {elapsed:.5f} sec", flush=True) + if i >= n_warmups: + t_ddbtasc.append(elapsed) + + tic = time.perf_counter() + ddbtasci( + A_diagonal_blocks_cpu, + A_lower_diagonal_blocks_cpu, + A_upper_diagonal_blocks_cpu, + A_lower_arrow_blocks_cpu, + A_upper_arrow_blocks_cpu, + A_arrow_tip_block_cpu, + rhs=rhs if quadratic else None, + quadratic=quadratic, + ) + toc = time.perf_counter() + elapsed = toc - tic + print(f"t_ddbtasci took: {elapsed:.5f} sec", flush=True) + if i >= n_warmups: + t_ddbtasci.append(elapsed) + + print(f"t_ddbtasc: {t_ddbtasc}", flush=True) + print(f"t_ddbtasci: {t_ddbtasci}", flush=True) + + print(f"avg t_ddbtasc: {np.mean(np.array(t_ddbtasc)):.5f} sec", flush=True) + print(f"avg t_ddbtasci: {np.mean(np.array(t_ddbtasci)):.5f} sec", flush=True) + 
print(f"avg total time: {np.mean(np.array(t_ddbtasc)) + np.mean(np.array(t_ddbtasci)):.5f} sec", flush=True) \ No newline at end of file diff --git a/perfbench/perfbench_ddbta_cpu.sh b/perfbench/perfbench_ddbta_cpu.sh new file mode 100644 index 00000000..5feb643d --- /dev/null +++ b/perfbench/perfbench_ddbta_cpu.sh @@ -0,0 +1,25 @@ +#!/bin/bash -l +#SBATCH --job-name="perfbench_ddbta_cpu" +#SBATCH --output=%x.%j.out +#SBATCH --error=%x.%j.err +#SBATCH --time=00:05:00 +#SBATCH --nodes=1 +#SBATCH --ntasks-per-node=1 +#SBATCH --cpus-per-task=32 +#SBATCH --partition=spr1tb +#SBATCH --hint=nomultithread + +unset SLURM_EXPORT_ENV + + +export OMP_NUM_THREADS=$SLURM_CPUS_PER_TASK +echo $OMP_NUM_THREADS + + +CPU_BIND="mask_cpu:0xffff00000000,0xffff000000000000" +CPU_BIND="${CPU_BIND},0xffff,0xffff0000" +CPU_BIND="${CPU_BIND},0xffff000000000000000000000000,0xffff0000000000000000000000000000" +CPU_BIND="${CPU_BIND},0xffff0000000000000000,0xffff00000000000000000000" + +srun python ./perfbench_ddbta_cpu.py --b 2048 --a 256 --n 8 --q 0 + diff --git a/src/serinv/__init__.py b/src/serinv/__init__.py index ca7f21dd..bad9c860 100644 --- a/src/serinv/__init__.py +++ b/src/serinv/__init__.py @@ -163,7 +163,7 @@ def _use_nccl(comm): return False -def _get_nccl_parameters(arr, comm, op: str): +def _get_nccl_parameters(arr, comm, rank, op: str): """Get the NCCL parameters for the given operation.""" if np.iscomplexobj(arr): factor = 2 @@ -172,8 +172,8 @@ def _get_nccl_parameters(arr, comm, op: str): if backend_flags["nccl_avail"]: if op == "allgather": - count = (arr.size // comm.size) * factor - displacement = count * comm.rank * arr.dtype.itemsize + count = (arr.size // comm.size()) * factor + displacement = count * rank * (arr.dtype.itemsize // factor) elif op == "allreduce": count = arr.size * factor displacement = 0 diff --git a/src/serinv/algs/ddbtasc.py b/src/serinv/algs/ddbtasc.py index 53a06869..3914c3c3 100644 --- a/src/serinv/algs/ddbtasc.py +++ b/src/serinv/algs/ddbtasc.py @@ 
-193,43 +193,47 @@ def _ddbtasc( A_arrow_tip_block: ArrayLike, invert_last_block: bool, ): + """ + Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 2xMM(bbb) + 1xMM(abb) + 2xMM(bba) + 1xMM(aba) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) for n_i in range(1, A_diagonal_blocks.shape[0]): # Invert previous diagonal block - A_diagonal_blocks[n_i - 1] = xp.linalg.inv(A_diagonal_blocks[n_i - 1]) + A_diagonal_blocks[n_i - 1] = xp.linalg.inv( + A_diagonal_blocks[n_i - 1] + ) # C: LU(b) + 2xTRSM(b) + temp = ( + A_diagonal_blocks[n_i - 1] @ A_upper_diagonal_blocks[n_i - 1] + ) # C: MM(bbb) # Update next diagonal block A_diagonal_blocks[n_i] = ( - A_diagonal_blocks[n_i] - - A_lower_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ A_upper_diagonal_blocks[n_i - 1] - ) + A_diagonal_blocks[n_i] - A_lower_diagonal_blocks[n_i - 1] @ temp + ) # C: MM(bbb) # Update next lower arrow block A_lower_arrow_blocks[n_i] = ( - A_lower_arrow_blocks[n_i] - - A_lower_arrow_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ A_upper_diagonal_blocks[n_i - 1] - ) + A_lower_arrow_blocks[n_i] - A_lower_arrow_blocks[n_i - 1] @ temp + ) # C: MM(abb) + temp = A_diagonal_blocks[n_i - 1] @ A_upper_arrow_blocks[n_i - 1] # C: MM(bba) # Update next upper arrow block A_upper_arrow_blocks[n_i] = ( - A_upper_arrow_blocks[n_i] - - A_lower_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ A_upper_arrow_blocks[n_i - 1] - ) + A_upper_arrow_blocks[n_i] - A_lower_diagonal_blocks[n_i - 1] @ temp + ) # C: MM(bba) # Update tip arrow block A_arrow_tip_block[:] = ( - A_arrow_tip_block[:] - - A_lower_arrow_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ A_upper_arrow_blocks[n_i - 1] - ) + A_arrow_tip_block[:] - A_lower_arrow_blocks[n_i - 1] @ temp + ) # C: MM(aba) if invert_last_block: A_diagonal_blocks[-1] = xp.linalg.inv(A_diagonal_blocks[-1]) @@ -251,6 +255,16 @@ def _ddbtasc_permuted( A_lower_buffer_blocks: ArrayLike, A_upper_buffer_blocks: ArrayLike, ): + """ + 
Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 6xMM(bbb) + 3xMM(abb) + 2xMM(bba) + 1xMM(aba) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) A_lower_buffer_blocks[0] = A_upper_diagonal_blocks[0] @@ -258,78 +272,58 @@ def _ddbtasc_permuted( for n_i in range(1, A_diagonal_blocks.shape[0] - 1): # Inverse current diagonal block - A_diagonal_blocks[n_i] = xp.linalg.inv(A_diagonal_blocks[n_i]) + A_diagonal_blocks[n_i] = xp.linalg.inv( + A_diagonal_blocks[n_i] + ) # C: LU(b) + 2xTRSM(b) # Update next diagonal block + temp_1 = A_lower_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i] # C: MM(bbb) A_diagonal_blocks[n_i + 1] = ( - A_diagonal_blocks[n_i + 1] - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_diagonal_blocks[n_i] - ) + A_diagonal_blocks[n_i + 1] - temp_1 @ A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) # Update next lower arrow block + temp_2 = A_lower_arrow_blocks[n_i] @ A_diagonal_blocks[n_i] # C: MM(abb) A_lower_arrow_blocks[n_i + 1] = ( - A_lower_arrow_blocks[n_i + 1] - - A_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_diagonal_blocks[n_i] - ) + A_lower_arrow_blocks[n_i + 1] - temp_2 @ A_upper_diagonal_blocks[n_i] + ) # C: MM(abb) # Update next upper arrow block A_upper_arrow_blocks[n_i + 1] = ( - A_upper_arrow_blocks[n_i + 1] - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_arrow_blocks[n_i] - ) + A_upper_arrow_blocks[n_i + 1] - temp_1 @ A_upper_arrow_blocks[n_i] + ) # C: MM(bba) # Update tip arrow block A_arrow_tip_block[:] = ( - A_arrow_tip_block[:] - - A_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_arrow_blocks[n_i] - ) + A_arrow_tip_block[:] - temp_2 @ A_upper_arrow_blocks[n_i] + ) # C: MM(aba) # --- Update of working buffer linked to permuted partition # Lower buffer block + temp_3 = A_lower_buffer_blocks[n_i - 1] @ A_diagonal_blocks[n_i] # C: MM(bbb) A_lower_buffer_blocks[n_i] = ( - -A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ 
A_upper_diagonal_blocks[n_i] - ) + -temp_3 @ A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) # Upper buffer block A_upper_buffer_blocks[n_i] = ( - -A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] - ) + -temp_1 @ A_upper_buffer_blocks[n_i - 1] + ) # C: MM(bbb) # 0-diagonal block (first) A_diagonal_blocks[0] = ( - A_diagonal_blocks[0] - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] - ) + A_diagonal_blocks[0] - temp_3 @ A_upper_buffer_blocks[n_i - 1] + ) # C: MM(bbb) # 0-lower arrow block (first) A_lower_arrow_blocks[0] = ( - A_lower_arrow_blocks[0] - - A_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] - ) + A_lower_arrow_blocks[0] - temp_2 @ A_upper_buffer_blocks[n_i - 1] + ) # C: MM(abb) # 0-upper arrow block (first) A_upper_arrow_blocks[0] = ( - A_upper_arrow_blocks[0] - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ A_upper_arrow_blocks[n_i] - ) + A_upper_arrow_blocks[0] - temp_3 @ A_upper_arrow_blocks[n_i] + ) # C: MM(bba) def _ddbtasc_quadratic( @@ -347,115 +341,109 @@ def _ddbtasc_quadratic( B_arrow_tip_block: ArrayLike, invert_last_block: bool, ): + """ + Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 8xMM(bbb) + 5xMM(abb) + 5xMM(bba) + 4xMM(aba) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) for n_i in range(1, A_diagonal_blocks.shape[0]): # Invert previous diagonal block of A - A_diagonal_blocks[n_i - 1] = xp.linalg.inv(A_diagonal_blocks[n_i - 1]) + A_diagonal_blocks[n_i - 1] = xp.linalg.inv( + A_diagonal_blocks[n_i - 1] + ) # C: LU(b) + 2xTRSM(b) # Update next diagonal block + temp_a_diag = ( + A_lower_diagonal_blocks[n_i - 1] @ A_diagonal_blocks[n_i - 1] + ) # C: MM(bbb) + temp_a_diag_conjt = ( + temp_a_diag.conj().T + ) # A_diagonal_blocks[n_i - 1].conj().T @ A_lower_diagonal_blocks[n_i - 1].conj().T + A_diagonal_blocks[n_i] = ( - A_diagonal_blocks[n_i] - - A_lower_diagonal_blocks[n_i - 
1] - @ A_diagonal_blocks[n_i - 1] - @ A_upper_diagonal_blocks[n_i - 1] - ) + A_diagonal_blocks[n_i] - temp_a_diag @ A_upper_diagonal_blocks[n_i - 1] + ) # C: MM(bbb) # Update next lower arrow block + temp_a_arrow = ( + A_lower_arrow_blocks[n_i - 1] @ A_diagonal_blocks[n_i - 1] + ) # C: MM(abb) + temp_a_arrow_conjt = ( + temp_a_arrow.conj().T + ) # A_diagonal_blocks[n_i - 1].conj().T @ A_lower_arrow_blocks[n_i - 1].conj().T + A_lower_arrow_blocks[n_i] = ( - A_lower_arrow_blocks[n_i] - - A_lower_arrow_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ A_upper_diagonal_blocks[n_i - 1] - ) + A_lower_arrow_blocks[n_i] - temp_a_arrow @ A_upper_diagonal_blocks[n_i - 1] + ) # C: MM(abb) # Update next upper arrow block A_upper_arrow_blocks[n_i] = ( - A_upper_arrow_blocks[n_i] - - A_lower_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ A_upper_arrow_blocks[n_i - 1] - ) + A_upper_arrow_blocks[n_i] - temp_a_diag @ A_upper_arrow_blocks[n_i - 1] + ) # C: MM(bba) # Update tip arrow block A_arrow_tip_block[:] = ( - A_arrow_tip_block[:] - - A_lower_arrow_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ A_upper_arrow_blocks[n_i - 1] - ) + A_arrow_tip_block[:] - temp_a_arrow @ A_upper_arrow_blocks[n_i - 1] + ) # C: MM(aba) # --- Xl --- # Inverse previous diagonal block of B B_diagonal_blocks[n_i - 1] = ( A_diagonal_blocks[n_i - 1] @ B_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1].T - ) + @ A_diagonal_blocks[n_i - 1].conj().T + ) # C: 2xMM(bbb) + temp_b_diag = ( + B_diagonal_blocks[n_i - 1] @ A_lower_diagonal_blocks[n_i - 1].conj().T + ) # C: MM(bbb) B_diagonal_blocks[n_i] = ( B_diagonal_blocks[n_i] - + A_lower_diagonal_blocks[n_i - 1] - @ B_diagonal_blocks[n_i - 1] - @ A_lower_diagonal_blocks[n_i - 1].T - - B_lower_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1].T - @ A_lower_diagonal_blocks[n_i - 1].T - - A_lower_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ B_upper_diagonal_blocks[n_i - 1] - ) - + + A_lower_diagonal_blocks[n_i - 1] @ 
temp_b_diag + - B_lower_diagonal_blocks[n_i - 1] @ temp_a_diag_conjt + - temp_a_diag @ B_upper_diagonal_blocks[n_i - 1] + ) # C: 3xMM(bbb) + + temp_b_arrow = ( + B_diagonal_blocks[n_i - 1] @ A_lower_arrow_blocks[n_i - 1].conj().T + ) # C: MM(bba) B_upper_arrow_blocks[n_i] = ( B_upper_arrow_blocks[n_i] - + A_lower_diagonal_blocks[n_i - 1] - @ B_diagonal_blocks[n_i - 1] - @ A_lower_arrow_blocks[n_i - 1].T - - B_lower_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1].T - @ A_lower_arrow_blocks[n_i - 1].T - - A_lower_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ B_upper_arrow_blocks[n_i - 1] - ) + + A_lower_diagonal_blocks[n_i - 1] @ temp_b_arrow + - B_lower_diagonal_blocks[n_i - 1] @ temp_a_arrow_conjt + - temp_a_diag @ B_upper_arrow_blocks[n_i - 1] + ) # C: 3xMM(bba) B_lower_arrow_blocks[n_i] = ( B_lower_arrow_blocks[n_i] - + A_lower_arrow_blocks[n_i - 1] - @ B_diagonal_blocks[n_i - 1] - @ A_lower_diagonal_blocks[n_i - 1].T - - B_lower_arrow_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1].T - @ A_lower_diagonal_blocks[n_i - 1].T - - A_lower_arrow_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ B_upper_diagonal_blocks[n_i - 1] - ) + + A_lower_arrow_blocks[n_i - 1] @ temp_b_diag + - B_lower_arrow_blocks[n_i - 1] @ temp_a_diag_conjt + - temp_a_arrow @ B_upper_diagonal_blocks[n_i - 1] + ) # C: 3xMM(abb) B_arrow_tip_block[:, :] = ( B_arrow_tip_block - + A_lower_arrow_blocks[n_i - 1] - @ B_diagonal_blocks[n_i - 1] - @ A_lower_arrow_blocks[n_i - 1].T - - B_lower_arrow_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1].T - @ A_lower_arrow_blocks[n_i - 1].T - - A_lower_arrow_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1] - @ B_upper_arrow_blocks[n_i - 1] - ) + + A_lower_arrow_blocks[n_i - 1] @ temp_b_arrow + - B_lower_arrow_blocks[n_i - 1] @ temp_a_arrow_conjt + - temp_a_arrow @ B_upper_arrow_blocks[n_i - 1] + ) # C: 3xMM(aba) if invert_last_block: A_diagonal_blocks[-1] = xp.linalg.inv(A_diagonal_blocks[-1]) B_diagonal_blocks[-1] = ( - A_diagonal_blocks[-1] 
@ B_diagonal_blocks[-1] @ A_diagonal_blocks[-1].T + A_diagonal_blocks[-1] @ B_diagonal_blocks[-1] @ A_diagonal_blocks[-1].conj().T ) + temp_a_arrow = A_lower_arrow_blocks[-1] @ A_diagonal_blocks[-1] A_arrow_tip_block[:] = xp.linalg.inv( - A_arrow_tip_block[:] - - A_lower_arrow_blocks[-1] - @ A_diagonal_blocks[-1] - @ A_upper_arrow_blocks[-1] + A_arrow_tip_block[:] - temp_a_arrow @ A_upper_arrow_blocks[-1] ) B_arrow_tip_block[:] = ( A_arrow_tip_block[:] @@ -463,15 +451,11 @@ def _ddbtasc_quadratic( B_arrow_tip_block[:] + A_lower_arrow_blocks[-1] @ B_diagonal_blocks[-1] - @ A_lower_arrow_blocks[-1].T - - B_lower_arrow_blocks[-1] - @ A_diagonal_blocks[-1].T - @ A_lower_arrow_blocks[-1].T - - A_lower_arrow_blocks[-1] - @ A_diagonal_blocks[-1] - @ B_upper_arrow_blocks[-1] + @ A_lower_arrow_blocks[-1].conj().T + - B_lower_arrow_blocks[-1] @ temp_a_arrow.conj().T + - temp_a_arrow @ B_upper_arrow_blocks[-1] ) - @ A_arrow_tip_block[:].T + @ A_arrow_tip_block[:].conj().T ) @@ -493,6 +477,16 @@ def _ddbtasc_quadratic_permuted( B_lower_buffer_blocks: ArrayLike, B_upper_buffer_blocks: ArrayLike, ): + """ + Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 22xMM(bbb) + 10xMM(abb) + 8xMM(bba) + 4xMM(aba) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) A_lower_buffer_blocks[0] = A_upper_diagonal_blocks[0] @@ -503,208 +497,145 @@ def _ddbtasc_quadratic_permuted( for n_i in range(1, A_diagonal_blocks.shape[0] - 1): # Inverse current diagonal block - A_diagonal_blocks[n_i] = xp.linalg.inv(A_diagonal_blocks[n_i]) + A_diagonal_blocks[n_i] = xp.linalg.inv( + A_diagonal_blocks[n_i] + ) # C: LU(b) + 2xTRSM(b) # Update next diagonal block + temp_1 = A_lower_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i] # C: MM(bbb) + temp_1_conjt = temp_1.conj().T A_diagonal_blocks[n_i + 1] = ( - A_diagonal_blocks[n_i + 1] - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_diagonal_blocks[n_i] - ) + A_diagonal_blocks[n_i + 1] - temp_1 @ 
A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) # Update next lower arrow block + temp_a_arrow = A_lower_arrow_blocks[n_i] @ A_diagonal_blocks[n_i] # C: MM(abb) + temp_a_arrow_conjt = ( + temp_a_arrow.conj().T + ) # A_diagonal_blocks[n_i].conj().T @ A_lower_arrow_blocks[n_i].conj().T + A_lower_arrow_blocks[n_i + 1] = ( - A_lower_arrow_blocks[n_i + 1] - - A_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_diagonal_blocks[n_i] - ) + A_lower_arrow_blocks[n_i + 1] - temp_a_arrow @ A_upper_diagonal_blocks[n_i] + ) # C: MM(abb) # Update next upper arrow block A_upper_arrow_blocks[n_i + 1] = ( - A_upper_arrow_blocks[n_i + 1] - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_arrow_blocks[n_i] - ) + A_upper_arrow_blocks[n_i + 1] - temp_1 @ A_upper_arrow_blocks[n_i] + ) # C: MM(bba) # Update tip arrow block A_arrow_tip_block[:] = ( - A_arrow_tip_block[:] - - A_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_arrow_blocks[n_i] - ) + A_arrow_tip_block[:] - temp_a_arrow @ A_upper_arrow_blocks[n_i] + ) # C: MM(aba) # --- Update of working buffer linked to permuted partition # Lower buffer block + temp_2 = A_lower_buffer_blocks[n_i - 1] @ A_diagonal_blocks[n_i] # C: MM(bbb) + temp_2_conjt = temp_2.conj().T A_lower_buffer_blocks[n_i] = ( - -A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ A_upper_diagonal_blocks[n_i] - ) + -temp_2 @ A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) # Upper buffer block A_upper_buffer_blocks[n_i] = ( - -A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] - ) + -temp_1 @ A_upper_buffer_blocks[n_i - 1] + ) # C: MM(bbb) # 0-diagonal block (first) A_diagonal_blocks[0] = ( - A_diagonal_blocks[0] - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] - ) + A_diagonal_blocks[0] - temp_2 @ A_upper_buffer_blocks[n_i - 1] + ) # C: MM(bbb) # 0-lower arrow block (first) A_lower_arrow_blocks[0] = ( - A_lower_arrow_blocks[0] - 
- A_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] - ) + A_lower_arrow_blocks[0] - temp_a_arrow @ A_upper_buffer_blocks[n_i - 1] + ) # C: MM(abb) # 0-upper arrow block (first) A_upper_arrow_blocks[0] = ( - A_upper_arrow_blocks[0] - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ A_upper_arrow_blocks[n_i] - ) + A_upper_arrow_blocks[0] - temp_2 @ A_upper_arrow_blocks[n_i] + ) # C: MM(bba) # --- Xl --- # Inverse current diagonal block B_diagonal_blocks[n_i] = ( - A_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i].T - ) + A_diagonal_blocks[n_i] + @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T + ) # C: 2xMM(bbb) # Update next diagonal block + temp_3 = A_lower_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i] # C: MM(bbb) B_diagonal_blocks[n_i + 1] = ( B_diagonal_blocks[n_i + 1] - + A_lower_diagonal_blocks[n_i] - @ B_diagonal_blocks[n_i] - @ A_lower_diagonal_blocks[n_i].T - - B_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i].T - @ A_lower_diagonal_blocks[n_i].T - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ B_upper_diagonal_blocks[n_i] - ) + + temp_3 @ A_lower_diagonal_blocks[n_i].conj().T + - B_lower_diagonal_blocks[n_i] @ temp_1_conjt + - temp_1 @ B_upper_diagonal_blocks[n_i] + ) # C: 3xMM(bbb) # Update next lower arrow block + temp_4 = A_lower_arrow_blocks[n_i] @ B_diagonal_blocks[n_i] # C: MM(abb) B_lower_arrow_blocks[n_i + 1] = ( B_lower_arrow_blocks[n_i + 1] - + A_lower_arrow_blocks[n_i] - @ B_diagonal_blocks[n_i] - @ A_lower_diagonal_blocks[n_i].T - - B_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i].T - @ A_lower_diagonal_blocks[n_i].T - - A_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ B_upper_diagonal_blocks[n_i] - ) + + temp_4 @ A_lower_diagonal_blocks[n_i].conj().T + - B_lower_arrow_blocks[n_i] @ temp_1_conjt + - temp_a_arrow @ B_upper_diagonal_blocks[n_i] + ) # C: 3xMM(abb) # Update next upper arrow block B_upper_arrow_blocks[n_i + 1] = ( 
B_upper_arrow_blocks[n_i + 1] - + A_lower_diagonal_blocks[n_i] - @ B_diagonal_blocks[n_i] - @ A_lower_arrow_blocks[n_i].T - - B_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i].T - @ A_lower_arrow_blocks[n_i].T - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ B_upper_arrow_blocks[n_i] - ) + + temp_3 @ A_lower_arrow_blocks[n_i].conj().T + - B_lower_diagonal_blocks[n_i] @ temp_a_arrow_conjt + - temp_1 @ B_upper_arrow_blocks[n_i] + ) # C: 3xMM(bba) # Update tip arrow block B_arrow_tip_block[:, :] = ( B_arrow_tip_block - + A_lower_arrow_blocks[n_i] - @ B_diagonal_blocks[n_i] - @ A_lower_arrow_blocks[n_i].T - - B_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i].T - @ A_lower_arrow_blocks[n_i].T - - A_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ B_upper_arrow_blocks[n_i] - ) + + temp_4 @ A_lower_arrow_blocks[n_i].conj().T + - B_lower_arrow_blocks[n_i] @ temp_a_arrow_conjt + - temp_a_arrow @ B_upper_arrow_blocks[n_i] + ) # C: 3xMM(aba) # --- Update of working buffer linked to permuted partition # Lower buffer block + temp_5 = A_lower_buffer_blocks[n_i - 1] @ B_diagonal_blocks[n_i] # C: MM(bbb) B_lower_buffer_blocks[n_i] = ( B_lower_buffer_blocks[n_i] - + A_lower_buffer_blocks[n_i - 1] - @ B_diagonal_blocks[n_i] - @ A_lower_diagonal_blocks[n_i].T - - B_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i].T - @ A_lower_diagonal_blocks[n_i].T - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ B_upper_diagonal_blocks[n_i] - ) + + temp_5 @ A_lower_diagonal_blocks[n_i].conj().T + - B_lower_buffer_blocks[n_i - 1] @ temp_1_conjt + - temp_2 @ B_upper_diagonal_blocks[n_i] + ) # C: 3xMM(bbb) # Upper buffer block B_upper_buffer_blocks[n_i] = ( B_upper_buffer_blocks[n_i] - + A_lower_diagonal_blocks[n_i] - @ B_diagonal_blocks[n_i] - @ A_lower_buffer_blocks[n_i - 1].T - - B_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i].T - @ A_lower_buffer_blocks[n_i - 1].T - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ 
B_upper_buffer_blocks[n_i - 1] - ) + + temp_3 @ A_lower_buffer_blocks[n_i - 1].conj().T + - B_lower_diagonal_blocks[n_i] @ temp_2_conjt + - temp_1 @ B_upper_buffer_blocks[n_i - 1] + ) # C: 3xMM(bbb) # 0-diagonal block (first) B_diagonal_blocks[0] = ( B_diagonal_blocks[0] - + A_lower_buffer_blocks[n_i - 1] - @ B_diagonal_blocks[n_i] - @ A_lower_buffer_blocks[n_i - 1].T - - B_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i].T - @ A_lower_buffer_blocks[n_i - 1].T - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ B_upper_buffer_blocks[n_i - 1] - ) + + temp_5 @ A_lower_buffer_blocks[n_i - 1].conj().T + - B_lower_buffer_blocks[n_i - 1] @ temp_2_conjt + - temp_2 @ B_upper_buffer_blocks[n_i - 1] + ) # C: 3xMM(bbb) # 0-lower arrow block (first) B_lower_arrow_blocks[0] = ( B_lower_arrow_blocks[0] - + A_lower_arrow_blocks[n_i] - @ B_diagonal_blocks[n_i] - @ A_lower_buffer_blocks[n_i - 1].T - - B_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i].T - @ A_lower_buffer_blocks[n_i - 1].T - - A_lower_arrow_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ B_upper_buffer_blocks[n_i - 1] - ) + + temp_4 @ A_lower_buffer_blocks[n_i - 1].conj().T + - B_lower_arrow_blocks[n_i] @ temp_2_conjt + - temp_a_arrow @ B_upper_buffer_blocks[n_i - 1] + ) # C: 3xMM(abb) # 0-upper arrow block (first) B_upper_arrow_blocks[0] = ( B_upper_arrow_blocks[0] - + A_lower_buffer_blocks[n_i - 1] - @ B_diagonal_blocks[n_i] - @ A_lower_arrow_blocks[n_i].T - - B_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i].T - @ A_lower_arrow_blocks[n_i].T - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ B_upper_arrow_blocks[n_i] - ) + + temp_5 @ A_lower_arrow_blocks[n_i].conj().T + - B_lower_buffer_blocks[n_i - 1] @ temp_a_arrow_conjt + - temp_2 @ B_upper_arrow_blocks[n_i] + ) # C: 3xMM(bba) diff --git a/src/serinv/algs/ddbtasci.py b/src/serinv/algs/ddbtasci.py index 266c7a94..a733be7e 100644 --- a/src/serinv/algs/ddbtasci.py +++ b/src/serinv/algs/ddbtasci.py @@ -193,6 +193,16 @@ def 
_ddbtasci( A_arrow_tip_block: ArrayLike, invert_last_block: bool, ): + """ + Operations Counts: + ------------------ + 7xMM(bbb) + 2xMM(bba) + 1xMM(baa) + 2xMM(abb) + 1xMM(aab) + 3xMM(bab) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) if A_diagonal_blocks.shape[0] > 1: @@ -222,38 +232,36 @@ def _ddbtasci( B1[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1] + A_upper_arrow_blocks[n_i] @ A_lower_arrow_blocks[n_i + 1] - ) + ) # C: MM(bbb) + MM(bab) B2[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_upper_arrow_blocks[n_i + 1] - + A_upper_arrow_blocks[n_i] @ A_arrow_tip_block[:, :] - ) + + A_upper_arrow_blocks[n_i] @ A_arrow_tip_block + ) # C: MM(bba) + MM(baa) C1[:, :] = ( A_diagonal_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] + A_upper_arrow_blocks[n_i + 1] @ A_lower_arrow_blocks[n_i] - ) + ) # C: MM(bbb) + MM(bab) C2[:, :] = ( A_lower_arrow_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] - + A_arrow_tip_block[:, :] @ A_lower_arrow_blocks[n_i] - ) + + A_arrow_tip_block @ A_lower_arrow_blocks[n_i] + ) # C: MM(abb) + MM(aab) - A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1[:, :] - A_upper_arrow_blocks[n_i] = -A_diagonal_blocks[n_i] @ B2[:, :] + A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1 # C: MM(bbb) + A_upper_arrow_blocks[n_i] = -A_diagonal_blocks[n_i] @ B2 # C: MM(bba) D1[:, :] = A_lower_diagonal_blocks[n_i] D2[:, :] = A_lower_arrow_blocks[n_i] - A_lower_diagonal_blocks[n_i] = -C1[:, :] @ A_diagonal_blocks[n_i] - A_lower_arrow_blocks[n_i] = -C2[:, :] @ A_diagonal_blocks[n_i] + A_lower_diagonal_blocks[n_i] = -C1 @ A_diagonal_blocks[n_i] # C: MM(bbb) + A_lower_arrow_blocks[n_i] = -C2 @ A_diagonal_blocks[n_i] # C: MM(abb) A_diagonal_blocks[n_i] = ( A_diagonal_blocks[n_i] - + A_diagonal_blocks[n_i] - @ (B1[:, :] @ D1[:, :] + B2[:, :] @ D2[:, :]) - @ A_diagonal_blocks[n_i] - ) + + A_diagonal_blocks[n_i] @ (B1 @ D1 + B2 @ D2) @ A_diagonal_blocks[n_i] + ) # C: 3xMM(bbb) + MM(bab) def _ddbtasci_permuted( @@ -266,6 
+274,16 @@ def _ddbtasci_permuted( A_lower_buffer_blocks: ArrayLike, A_upper_buffer_blocks: ArrayLike, ): + """ + Operations Counts: + ------------------ + 14xMM(bbb) + 5xMM(bba) + 1xMM(baa) + 3xMM(abb) + 1xMM(aab) + 5xMM(bab) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) B1 = xp.empty_like(A_lower_diagonal_blocks[0]) @@ -285,56 +303,56 @@ def _ddbtasci_permuted( A_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1] + A_upper_arrow_blocks[n_i] @ A_lower_arrow_blocks[n_i + 1] + A_upper_buffer_blocks[n_i - 1] @ A_lower_buffer_blocks[n_i] - ) + ) # C: 2xMM(bbb) + MM(bab) B2[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_upper_buffer_blocks[n_i] + A_upper_buffer_blocks[n_i - 1] @ A_diagonal_blocks[0] + A_upper_arrow_blocks[n_i] @ A_lower_arrow_blocks[0] - ) + ) # C: 2xMM(bbb) + MM(bab) B3[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_upper_arrow_blocks[n_i + 1] + A_upper_arrow_blocks[n_i] @ A_arrow_tip_block[:, :] + A_upper_buffer_blocks[n_i - 1] @ A_upper_arrow_blocks[0] - ) + ) # C: 2xMM(bba) + MM(baa) C1[:, :] = ( A_diagonal_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] + A_upper_arrow_blocks[n_i + 1] @ A_lower_arrow_blocks[n_i] + A_upper_buffer_blocks[n_i] @ A_lower_buffer_blocks[n_i - 1] - ) + ) # C: 2xMM(bbb) + MM(bab) C2[:, :] = ( A_lower_buffer_blocks[n_i] @ A_lower_diagonal_blocks[n_i] + A_diagonal_blocks[0] @ A_lower_buffer_blocks[n_i - 1] + A_upper_arrow_blocks[0] @ A_lower_arrow_blocks[n_i] - ) + ) # C: 2xMM(bba) + MM(bab) C3[:, :] = ( A_lower_arrow_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] - + A_arrow_tip_block[:, :] @ A_lower_arrow_blocks[n_i] + + A_arrow_tip_block @ A_lower_arrow_blocks[n_i] + A_lower_arrow_blocks[0] @ A_lower_buffer_blocks[n_i - 1] - ) + ) # C: 2xMM(abb) + MM(aab) - A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1[:, :] - A_upper_buffer_blocks[n_i - 1] = -A_diagonal_blocks[n_i] @ B2[:, :] - A_upper_arrow_blocks[n_i] = -A_diagonal_blocks[n_i] @ B3[:, :] + A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1 
# C: MM(bbb) + A_upper_buffer_blocks[n_i - 1] = -A_diagonal_blocks[n_i] @ B2 # C: MM(bbb) + A_upper_arrow_blocks[n_i] = -A_diagonal_blocks[n_i] @ B3 # C: MM(bba) D1[:, :] = A_lower_diagonal_blocks[n_i] D2[:, :] = A_lower_buffer_blocks[n_i - 1] D3[:, :] = A_lower_arrow_blocks[n_i] - A_lower_diagonal_blocks[n_i] = -C1[:, :] @ A_diagonal_blocks[n_i] - A_lower_buffer_blocks[n_i - 1] = -C2[:, :] @ A_diagonal_blocks[n_i] - A_lower_arrow_blocks[n_i] = -C3[:, :] @ A_diagonal_blocks[n_i] + A_lower_diagonal_blocks[n_i] = -C1 @ A_diagonal_blocks[n_i] # C: MM(bbb) + A_lower_buffer_blocks[n_i - 1] = -C2 @ A_diagonal_blocks[n_i] # C: MM(bbb) + A_lower_arrow_blocks[n_i] = -C3 @ A_diagonal_blocks[n_i] # C: MM(abb) A_diagonal_blocks[n_i] = ( A_diagonal_blocks[n_i] + A_diagonal_blocks[n_i] - @ (B1[:, :] @ D1[:, :] + B2[:, :] @ D2[:, :] + B3[:, :] @ D3[:, :]) + @ (B1 @ D1 + B2 @ D2 + B3 @ D3) @ A_diagonal_blocks[n_i] - ) + ) # C: 4xMM(bbb) + MM(bab) A_lower_diagonal_blocks[0] = A_upper_buffer_blocks[0] A_upper_diagonal_blocks[0] = A_lower_buffer_blocks[0] @@ -355,6 +373,16 @@ def _ddbtasci_quadratic( B_arrow_tip_block: ArrayLike, invert_last_block: bool, ): + """ + Operations Counts: + ------------------ + 34xMM(bbb) + 9xMM(bba) + 4xMM(baa) + 6xMM(abb) + 3xMM(aab) + 13xMM(bab) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) if A_diagonal_blocks.shape[0] > 1: @@ -376,79 +404,69 @@ def _ddbtasci_quadratic( # --- Xl --- if invert_last_block: B2[:, :] = A_diagonal_blocks[-1] @ A_upper_arrow_blocks[-1] - C2[:, :] = A_upper_arrow_blocks[-1].T @ A_diagonal_blocks[-1].T - D2[:, :] = ( - A_arrow_tip_block[:, :] @ A_lower_arrow_blocks[-1] @ B_diagonal_blocks[-1] - ) + C2[:, :] = B2.conj().T + D2[:, :] = A_arrow_tip_block @ A_lower_arrow_blocks[-1] @ B_diagonal_blocks[-1] temp_B_13[:, :] = B_upper_arrow_blocks[-1] temp_B_31[:, :] = B_lower_arrow_blocks[-1] B_upper_arrow_blocks[-1] = ( - -B2[:, :] @ B_arrow_tip_block[:, :] + -B2 @ B_arrow_tip_block - B_diagonal_blocks[-1] - @ 
A_lower_arrow_blocks[-1].T - @ A_arrow_tip_block[:, :].T + @ A_lower_arrow_blocks[-1].conj().T + @ A_arrow_tip_block.conj().T + A_diagonal_blocks[-1] @ B_upper_arrow_blocks[-1] - @ A_arrow_tip_block[:, :].T + @ A_arrow_tip_block.conj().T ) B_lower_arrow_blocks[-1] = ( - -B_arrow_tip_block[:, :] @ C2[:, :] - - D2[:, :] - + A_arrow_tip_block[:, :] + -B_arrow_tip_block @ C2 + - D2 + + A_arrow_tip_block @ B_lower_arrow_blocks[-1] - @ A_diagonal_blocks[-1].T + @ A_diagonal_blocks[-1].conj().T ) B_diagonal_blocks[-1] = ( B_diagonal_blocks[-1] - + B2[:, :] @ B_arrow_tip_block[:, :] @ C2[:, :] - + B2[:, :] @ D2[:, :] - + B_diagonal_blocks[-1].T - @ A_lower_arrow_blocks[-1].T - @ A_arrow_tip_block[:, :].T - @ C2[:, :] - - B2[:, :] - @ A_arrow_tip_block[:, :] - @ temp_B_31[:, :] - @ A_diagonal_blocks[-1].T - - A_diagonal_blocks[-1] - @ temp_B_13[:, :] - @ A_arrow_tip_block[:, :].T - @ C2[:, :] + + B2 @ B_arrow_tip_block @ C2 + + B2 @ D2 + + B_diagonal_blocks[-1] + @ A_lower_arrow_blocks[-1].conj().T + @ A_arrow_tip_block.conj().T + @ C2 + - B2 @ A_arrow_tip_block @ temp_B_31 @ A_diagonal_blocks[-1].conj().T + - A_diagonal_blocks[-1] @ temp_B_13 @ A_arrow_tip_block.conj().T @ C2 ) # --- Xr --- A_lower_arrow_blocks[-1] = ( -A_arrow_tip_block[:] @ A_lower_arrow_blocks[-1] @ A_diagonal_blocks[-1] ) - A_upper_arrow_blocks[-1] = -B2[:, :] @ A_arrow_tip_block[:] - A_diagonal_blocks[-1] = ( - A_diagonal_blocks[-1] - B2[:, :] @ A_lower_arrow_blocks[-1] - ) + A_upper_arrow_blocks[-1] = -B2 @ A_arrow_tip_block[:] + A_diagonal_blocks[-1] = A_diagonal_blocks[-1] - B2 @ A_lower_arrow_blocks[-1] for n_i in range(A_diagonal_blocks.shape[0] - 2, -1, -1): B1[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1] + A_upper_arrow_blocks[n_i] @ A_lower_arrow_blocks[n_i + 1] - ) + ) # C: MM(bbb) + MM(bab) B2[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_upper_arrow_blocks[n_i + 1] - + A_upper_arrow_blocks[n_i] @ A_arrow_tip_block[:, :] - ) + + A_upper_arrow_blocks[n_i] @ 
A_arrow_tip_block + ) # C: MM(bba) + MM(baa) C1[:, :] = ( A_diagonal_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] + A_upper_arrow_blocks[n_i + 1] @ A_lower_arrow_blocks[n_i] - ) + ) # C: MM(bbb) + MM(bab) C2[:, :] = ( A_lower_arrow_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] - + A_arrow_tip_block[:, :] @ A_lower_arrow_blocks[n_i] - ) + + A_arrow_tip_block @ A_lower_arrow_blocks[n_i] + ) # C: MM(abb) + MM(aab) # --- Xl --- temp_B_12[:, :] = B_upper_diagonal_blocks[n_i] @@ -461,53 +479,55 @@ def _ddbtasci_quadratic( @ ( A_upper_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i + 1] + A_upper_arrow_blocks[n_i] @ B_lower_arrow_blocks[n_i + 1] - ) - - B_diagonal_blocks[n_i] @ C1[:, :].T + ) # C: 2xMM(bbb) + MM(bab) + - B_diagonal_blocks[n_i] @ C1.conj().T # C: MM(bbb) + A_diagonal_blocks[n_i] @ ( - B_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1].T - + B_upper_arrow_blocks[n_i] @ A_upper_arrow_blocks[n_i + 1].T - ) + B_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1].conj().T + + B_upper_arrow_blocks[n_i] @ A_upper_arrow_blocks[n_i + 1].conj().T + ) # C: 2xMM(bbb) + MM(bab) ) + B_upper_arrow_blocks[n_i] = ( -A_diagonal_blocks[n_i] @ ( A_upper_diagonal_blocks[n_i] @ B_upper_arrow_blocks[n_i + 1] - + A_upper_arrow_blocks[n_i] @ B_arrow_tip_block[:, :] - ) - - B_diagonal_blocks[n_i] @ C2[:, :].T + + A_upper_arrow_blocks[n_i] @ B_arrow_tip_block + ) # C: 2xMM(bba) + MM(baa) + - B_diagonal_blocks[n_i] @ C2.conj().T # C: MM(bba) + A_diagonal_blocks[n_i] @ ( - temp_B_12[:, :] @ A_lower_arrow_blocks[n_i + 1].T - + B_upper_arrow_blocks[n_i] @ A_arrow_tip_block[:, :].T - ) + temp_B_12 @ A_lower_arrow_blocks[n_i + 1].conj().T + + B_upper_arrow_blocks[n_i] @ A_arrow_tip_block.conj().T + ) # C: 2xMM(bba) + MM(baa) ) B_lower_diagonal_blocks[n_i] = ( -( - B_diagonal_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].T - + B_upper_arrow_blocks[n_i + 1] @ A_upper_arrow_blocks[n_i].T + B_diagonal_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].conj().T + + 
B_upper_arrow_blocks[n_i + 1] @ A_upper_arrow_blocks[n_i].conj().T ) - @ A_diagonal_blocks[n_i].T - - (C1[:, :]) @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T # C: 2xMM(bbb) + MM(bab) + - C1 @ B_diagonal_blocks[n_i] # C: MM(bbb) + ( A_diagonal_blocks[n_i + 1] @ B_lower_diagonal_blocks[n_i] + A_upper_arrow_blocks[n_i + 1] @ B_lower_arrow_blocks[n_i] ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i].conj().T # C: 2xMM(bbb) + MM(bab) ) + B_lower_arrow_blocks[n_i] = ( -( - B_lower_arrow_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].T - + B_arrow_tip_block[:, :] @ A_upper_arrow_blocks[n_i].T + B_lower_arrow_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].conj().T + + B_arrow_tip_block @ A_upper_arrow_blocks[n_i].conj().T ) - @ A_diagonal_blocks[n_i].T - - (C2[:, :]) @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T # C: MM(abb) + MM(aab) + - C2 @ B_diagonal_blocks[n_i] # C: MM(abb) + ( - A_lower_arrow_blocks[n_i + 1] @ temp_B_21[:, :] - + A_arrow_tip_block[:, :] @ B_lower_arrow_blocks[n_i] + A_lower_arrow_blocks[n_i + 1] @ temp_B_21 + + A_arrow_tip_block @ B_lower_arrow_blocks[n_i] ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i].conj().T # C: 2xMM(abb) + MM(aab) ) B_diagonal_blocks[n_i] = ( @@ -518,61 +538,60 @@ def _ddbtasci_quadratic( A_upper_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i + 1] + A_upper_arrow_blocks[n_i] @ B_lower_arrow_blocks[n_i + 1] ) - @ A_upper_diagonal_blocks[n_i].T + @ A_upper_diagonal_blocks[n_i].conj().T + ( A_upper_diagonal_blocks[n_i] @ B_upper_arrow_blocks[n_i + 1] - + A_upper_arrow_blocks[n_i] @ B_arrow_tip_block[:, :] + + A_upper_arrow_blocks[n_i] @ B_arrow_tip_block ) - @ A_upper_arrow_blocks[n_i].T + @ A_upper_arrow_blocks[n_i].conj().T ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i] + .conj() + .T # C: 4xMM(bbb) + 2xMM(bab) + MM(bba) + M(baa) + A_diagonal_blocks[n_i] + @ (B1 @ A_lower_diagonal_blocks[n_i] + B2 @ A_lower_arrow_blocks[n_i]) + @ B_diagonal_blocks[n_i] # C: 3xMM(bbb) + 
MM(bab) + + B_diagonal_blocks[n_i] @ ( - (B1[:, :]) @ A_lower_diagonal_blocks[n_i] - + (B2[:, :]) @ A_lower_arrow_blocks[n_i] + C1.conj().T @ A_upper_diagonal_blocks[n_i].conj().T + + C2.conj().T @ A_upper_arrow_blocks[n_i].conj().T ) - @ B_diagonal_blocks[n_i] - + B_diagonal_blocks[n_i].T - @ ( - C1[:, :].T @ A_upper_diagonal_blocks[n_i].T - + C2[:, :].T @ A_upper_arrow_blocks[n_i].T - ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i].conj().T # C: 3xMM(bbb) + MM(bab) - A_diagonal_blocks[n_i] - @ ((B1[:, :]) @ temp_B_21 + (B2[:, :]) @ temp_B_31) - @ A_diagonal_blocks[n_i].T + @ (B1 @ temp_B_21 + B2 @ temp_B_31) + @ A_diagonal_blocks[n_i].conj().T # C: 3xMM(bbb) + MM(bab) - A_diagonal_blocks[n_i] @ ( ( - temp_B_12 @ A_diagonal_blocks[n_i + 1].T - + temp_B_13 @ A_upper_arrow_blocks[n_i + 1].T + temp_B_12 @ A_diagonal_blocks[n_i + 1].conj().T + + temp_B_13 @ A_upper_arrow_blocks[n_i + 1].conj().T ) - @ A_upper_diagonal_blocks[n_i].T + @ A_upper_diagonal_blocks[n_i].conj().T + ( - temp_B_12 @ A_lower_arrow_blocks[n_i + 1].T - + temp_B_13 @ A_arrow_tip_block[:, :].T + temp_B_12 @ A_lower_arrow_blocks[n_i + 1].conj().T + + temp_B_13 @ A_arrow_tip_block.conj().T ) - @ A_upper_arrow_blocks[n_i].T + @ A_upper_arrow_blocks[n_i].conj().T ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i] + .conj() + .T # C: 4xMM(bbb) + 2xMM(bab) + MM(bba) + MM(baa) ) # --- Xr --- - A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1[:, :] - A_upper_arrow_blocks[n_i] = -A_diagonal_blocks[n_i] @ B2[:, :] + A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1 # C: MM(bbb) + A_upper_arrow_blocks[n_i] = -A_diagonal_blocks[n_i] @ B2 # C: MM(bba) D1[:, :] = A_lower_diagonal_blocks[n_i] D2[:, :] = A_lower_arrow_blocks[n_i] - A_lower_diagonal_blocks[n_i] = -C1[:, :] @ A_diagonal_blocks[n_i] - A_lower_arrow_blocks[n_i] = -C2[:, :] @ A_diagonal_blocks[n_i] + A_lower_diagonal_blocks[n_i] = -C1 @ A_diagonal_blocks[n_i] # C: MM(bbb) + A_lower_arrow_blocks[n_i] = -C2 @ 
A_diagonal_blocks[n_i] # C: MM(abb) A_diagonal_blocks[n_i] = ( A_diagonal_blocks[n_i] - + A_diagonal_blocks[n_i] - @ (B1[:, :] @ D1[:, :] + B2[:, :] @ D2[:, :]) - @ A_diagonal_blocks[n_i] - ) + + A_diagonal_blocks[n_i] @ (B1 @ D1 + B2 @ D2) @ A_diagonal_blocks[n_i] + ) # C: 3xMM(bbb) + MM(bab) def _ddbtasci_quadratic_permuted( @@ -593,6 +612,16 @@ def _ddbtasci_quadratic_permuted( B_lower_buffer_blocks: ArrayLike, B_upper_buffer_blocks: ArrayLike, ): + """ + Operations Counts: + ------------------ + 72xMM(bbb) + 14xMM(bba) + 5xMM(baa) + 10xMM(abb) + 3xMM(aab) + 22xMM(bab) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) B1 = xp.empty_like(A_lower_diagonal_blocks[0]) @@ -620,37 +649,37 @@ def _ddbtasci_quadratic_permuted( A_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1] + A_upper_arrow_blocks[n_i] @ A_lower_arrow_blocks[n_i + 1] + A_upper_buffer_blocks[n_i - 1] @ A_lower_buffer_blocks[n_i] - ) + ) # C: 2xMM(bbb) + MM(bab) B2[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_upper_buffer_blocks[n_i] + A_upper_buffer_blocks[n_i - 1] @ A_diagonal_blocks[0] + A_upper_arrow_blocks[n_i] @ A_lower_arrow_blocks[0] - ) + ) # C: 2xMM(bbb) + MM(bab) B3[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_upper_arrow_blocks[n_i + 1] - + A_upper_arrow_blocks[n_i] @ A_arrow_tip_block[:, :] + + A_upper_arrow_blocks[n_i] @ A_arrow_tip_block + A_upper_buffer_blocks[n_i - 1] @ A_upper_arrow_blocks[0] - ) + ) # C: 2xMM(bba) + MM(baa) C1[:, :] = ( A_diagonal_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] + A_upper_arrow_blocks[n_i + 1] @ A_lower_arrow_blocks[n_i] + A_upper_buffer_blocks[n_i] @ A_lower_buffer_blocks[n_i - 1] - ) + ) # C: 2xMM(bbb) + MM(bab) C2[:, :] = ( A_lower_buffer_blocks[n_i] @ A_lower_diagonal_blocks[n_i] + A_diagonal_blocks[0] @ A_lower_buffer_blocks[n_i - 1] + A_upper_arrow_blocks[0] @ A_lower_arrow_blocks[n_i] - ) + ) # C: 2xMM(bbb) + MM(bab) C3[:, :] = ( A_lower_arrow_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] - + A_arrow_tip_block[:, :] @ 
A_lower_arrow_blocks[n_i] + + A_arrow_tip_block @ A_lower_arrow_blocks[n_i] + A_lower_arrow_blocks[0] @ A_lower_buffer_blocks[n_i - 1] - ) + ) # C: 2xMM(abb) + MM(aab) # --- Xl --- temp_B_12[:, :] = B_upper_diagonal_blocks[n_i] @@ -667,90 +696,94 @@ def _ddbtasci_quadratic_permuted( A_upper_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i + 1] + A_upper_arrow_blocks[n_i] @ B_lower_arrow_blocks[n_i + 1] + A_upper_buffer_blocks[n_i - 1] @ B_lower_buffer_blocks[n_i] - ) - - B_diagonal_blocks[n_i] @ C1[:, :].T + ) # C: 3xMM(bbb) + MM(bab) + - B_diagonal_blocks[n_i] @ C1.conj().T # C: MM(bbb) + A_diagonal_blocks[n_i] @ ( - temp_B_12[:, :] @ A_diagonal_blocks[n_i + 1].T - + temp_B_13[:, :] @ A_upper_arrow_blocks[n_i + 1].T - + temp_B_14[:, :] @ A_upper_buffer_blocks[n_i].T - ) + temp_B_12 @ A_diagonal_blocks[n_i + 1].conj().T + + temp_B_13 @ A_upper_arrow_blocks[n_i + 1].conj().T + + temp_B_14 @ A_upper_buffer_blocks[n_i].conj().T + ) # C: 3xMM(bbb) + MM(bab) ) + B_upper_buffer_blocks[n_i - 1] = ( -A_diagonal_blocks[n_i] @ ( A_upper_diagonal_blocks[n_i] @ B_upper_buffer_blocks[n_i] + A_upper_arrow_blocks[n_i] @ B_lower_arrow_blocks[0] + A_upper_buffer_blocks[n_i - 1] @ B_diagonal_blocks[0] - ) - - B_diagonal_blocks[n_i] @ C2[:, :].T + ) # C: 3xMM(bbb) + MM(bab) + - B_diagonal_blocks[n_i] @ C2.conj().T # C: MM(bbb) + A_diagonal_blocks[n_i] @ ( - temp_B_12[:, :] @ A_lower_buffer_blocks[n_i].T - + temp_B_13[:, :] @ A_upper_arrow_blocks[0].T - + temp_B_14[:, :] @ A_diagonal_blocks[0].T - ) + temp_B_12 @ A_lower_buffer_blocks[n_i].conj().T + + temp_B_13 @ A_upper_arrow_blocks[0].conj().T + + temp_B_14 @ A_diagonal_blocks[0].conj().T + ) # C: 3xMM(bbb) + MM(bab) ) + B_upper_arrow_blocks[n_i] = ( -A_diagonal_blocks[n_i] @ ( A_upper_diagonal_blocks[n_i] @ B_upper_arrow_blocks[n_i + 1] - + A_upper_arrow_blocks[n_i] @ B_arrow_tip_block[:, :] + + A_upper_arrow_blocks[n_i] @ B_arrow_tip_block + A_upper_buffer_blocks[n_i - 1] @ B_upper_arrow_blocks[0] - ) - - B_diagonal_blocks[n_i] @ 
C3[:, :].T + ) # C: 3xMM(bba) + MM(baa) + - B_diagonal_blocks[n_i] @ C3.conj().T # C: MM(bba) + A_diagonal_blocks[n_i] @ ( - temp_B_12[:, :] @ A_lower_arrow_blocks[n_i + 1].T - + temp_B_13[:, :] @ A_arrow_tip_block[:, :].T - + temp_B_14[:, :] @ A_lower_arrow_blocks[0].T - ) + temp_B_12 @ A_lower_arrow_blocks[n_i + 1].conj().T + + temp_B_13 @ A_arrow_tip_block.conj().T + + temp_B_14 @ A_lower_arrow_blocks[0].conj().T + ) # C: 3xMM(bba) + MM(baa) ) B_lower_diagonal_blocks[n_i] = ( -( - B_diagonal_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].T - + B_upper_arrow_blocks[n_i + 1] @ A_upper_arrow_blocks[n_i].T - + B_upper_buffer_blocks[n_i] @ A_upper_buffer_blocks[n_i - 1].T + B_diagonal_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].conj().T + + B_upper_arrow_blocks[n_i + 1] @ A_upper_arrow_blocks[n_i].conj().T + + B_upper_buffer_blocks[n_i] @ A_upper_buffer_blocks[n_i - 1].conj().T ) - @ A_diagonal_blocks[n_i].T - - C1[:, :] @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T # C: 3xMM(bbb) + MM(bab) + - C1 @ B_diagonal_blocks[n_i] # C: MM(bbb) + ( - A_diagonal_blocks[n_i + 1] @ temp_B_21[:, :] - + A_upper_arrow_blocks[n_i + 1] @ temp_B_31[:, :] - + A_upper_buffer_blocks[n_i] @ temp_B_41[:, :] + A_diagonal_blocks[n_i + 1] @ temp_B_21 + + A_upper_arrow_blocks[n_i + 1] @ temp_B_31 + + A_upper_buffer_blocks[n_i] @ temp_B_41 ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i].conj().T # C: 3xMM(bbb) + MM(bab) ) + B_lower_buffer_blocks[n_i - 1] = ( -( - B_lower_buffer_blocks[n_i] @ A_upper_diagonal_blocks[n_i].T - + B_diagonal_blocks[0] @ A_upper_buffer_blocks[n_i - 1].T - + B_upper_arrow_blocks[0] @ A_upper_arrow_blocks[n_i].T + B_lower_buffer_blocks[n_i] @ A_upper_diagonal_blocks[n_i].conj().T + + B_diagonal_blocks[0] @ A_upper_buffer_blocks[n_i - 1].conj().T + + B_upper_arrow_blocks[0] @ A_upper_arrow_blocks[n_i].conj().T ) - @ A_diagonal_blocks[n_i].T - - C2[:, :] @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T # C: 3xMM(bbb) + MM(bab) + - 
C2 @ B_diagonal_blocks[n_i] # C: MM(bbb) + ( - A_lower_buffer_blocks[n_i] @ temp_B_21[:, :] - + A_upper_arrow_blocks[0] @ temp_B_31[:, :] - + A_diagonal_blocks[0] @ temp_B_41[:, :] + A_lower_buffer_blocks[n_i] @ temp_B_21 + + A_upper_arrow_blocks[0] @ temp_B_31 + + A_diagonal_blocks[0] @ temp_B_41 ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i].conj().T # C: 3xMM(bbb) + MM(bab) ) + B_lower_arrow_blocks[n_i] = ( -( - B_lower_arrow_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].T - + B_arrow_tip_block[:, :] @ A_upper_arrow_blocks[n_i].T - + B_lower_arrow_blocks[0] @ A_upper_buffer_blocks[n_i - 1].T + B_lower_arrow_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].conj().T + + B_arrow_tip_block @ A_upper_arrow_blocks[n_i].conj().T + + B_lower_arrow_blocks[0] @ A_upper_buffer_blocks[n_i - 1].conj().T ) - @ A_diagonal_blocks[n_i].T - - C3[:, :] @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T # C: 3xMM(abb) + MM(aab) + - C3 @ B_diagonal_blocks[n_i] # C: MM(abb) + ( - A_lower_arrow_blocks[n_i + 1] @ temp_B_21[:, :] - + A_arrow_tip_block[:, :] @ temp_B_31[:, :] - + A_lower_arrow_blocks[0] @ temp_B_41[:, :] + A_lower_arrow_blocks[n_i + 1] @ temp_B_21 + + A_arrow_tip_block @ temp_B_31 + + A_lower_arrow_blocks[0] @ temp_B_41 ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i].conj().T # C: 3xMM(abb) + MM(aab) ) B_diagonal_blocks[n_i] = ( @@ -762,81 +795,85 @@ def _ddbtasci_quadratic_permuted( + A_upper_arrow_blocks[n_i] @ B_lower_arrow_blocks[n_i + 1] + A_upper_buffer_blocks[n_i - 1] @ B_lower_buffer_blocks[n_i] ) - @ A_upper_diagonal_blocks[n_i].T + @ A_upper_diagonal_blocks[n_i].conj().T + ( A_upper_diagonal_blocks[n_i] @ B_upper_buffer_blocks[n_i] + A_upper_buffer_blocks[n_i - 1] @ B_diagonal_blocks[0] + A_upper_arrow_blocks[n_i] @ B_lower_arrow_blocks[0] ) - @ A_upper_buffer_blocks[n_i - 1].T + @ A_upper_buffer_blocks[n_i - 1].conj().T + ( A_upper_diagonal_blocks[n_i] @ B_upper_arrow_blocks[n_i + 1] - + A_upper_arrow_blocks[n_i] @ 
B_arrow_tip_block[:, :] + + A_upper_arrow_blocks[n_i] @ B_arrow_tip_block + A_upper_buffer_blocks[n_i - 1] @ B_upper_arrow_blocks[0] ) - @ A_upper_arrow_blocks[n_i].T + @ A_upper_arrow_blocks[n_i].conj().T ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i] + .conj() + .T # C: 8xMM(bbb) + 3xMM(bab) + 2xMM(bba) + MM(baa) + A_diagonal_blocks[n_i] @ ( - B1[:, :] @ A_lower_diagonal_blocks[n_i] - + B2[:, :] @ A_lower_buffer_blocks[n_i - 1] - + B3[:, :] @ A_lower_arrow_blocks[n_i] + B1 @ A_lower_diagonal_blocks[n_i] + + B2 @ A_lower_buffer_blocks[n_i - 1] + + B3 @ A_lower_arrow_blocks[n_i] ) - @ B_diagonal_blocks[n_i] - + B_diagonal_blocks[n_i].T + @ B_diagonal_blocks[n_i] # C: 4xMM(bbb) + MM(bab) + + B_diagonal_blocks[n_i] @ ( - C1[:, :].T @ A_upper_diagonal_blocks[n_i].T - + C2[:, :].T @ A_upper_buffer_blocks[n_i - 1].T - + C3[:, :].T @ A_upper_arrow_blocks[n_i].T + C1.conj().T @ A_upper_diagonal_blocks[n_i].conj().T + + C2.conj().T @ A_upper_buffer_blocks[n_i - 1].conj().T + + C3.conj().T @ A_upper_arrow_blocks[n_i].conj().T ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i].conj().T # C: 4xMM(bbb) + MM(bab) - A_diagonal_blocks[n_i] - @ (B1[:, :] @ temp_B_21 + B2[:, :] @ temp_B_41 + B3[:, :] @ temp_B_31) - @ A_diagonal_blocks[n_i].T + @ (B1 @ temp_B_21 + B2 @ temp_B_41 + B3 @ temp_B_31) + @ A_diagonal_blocks[n_i].conj().T # C: 4xMM(bbb) + MM(bab) - A_diagonal_blocks[n_i] @ ( ( - temp_B_12[:, :] @ A_diagonal_blocks[n_i + 1].T - + temp_B_13[:, :] @ A_upper_arrow_blocks[n_i + 1].T - + temp_B_14[:, :] @ A_upper_buffer_blocks[n_i].T + temp_B_12 @ A_diagonal_blocks[n_i + 1].conj().T + + temp_B_13 @ A_upper_arrow_blocks[n_i + 1].conj().T + + temp_B_14 @ A_upper_buffer_blocks[n_i].conj().T ) - @ A_upper_diagonal_blocks[n_i].T + @ A_upper_diagonal_blocks[n_i].conj().T + ( - temp_B_12[:, :] @ A_lower_buffer_blocks[n_i].T - + temp_B_14[:, :] @ A_diagonal_blocks[0].T - + temp_B_13[:, :] @ A_upper_arrow_blocks[0].T + temp_B_12 @ A_lower_buffer_blocks[n_i].conj().T + + 
temp_B_14 @ A_diagonal_blocks[0].conj().T + + temp_B_13 @ A_upper_arrow_blocks[0].conj().T ) - @ A_upper_buffer_blocks[n_i - 1].T + @ A_upper_buffer_blocks[n_i - 1].conj().T + ( - temp_B_12[:, :] @ A_lower_arrow_blocks[n_i + 1].T - + temp_B_13[:, :] @ A_arrow_tip_block[:, :].T - + temp_B_14[:, :] @ A_lower_arrow_blocks[0].T + temp_B_12 @ A_lower_arrow_blocks[n_i + 1].conj().T + + temp_B_13 @ A_arrow_tip_block.conj().T + + temp_B_14 @ A_lower_arrow_blocks[0].conj().T ) - @ A_upper_arrow_blocks[n_i].T + @ A_upper_arrow_blocks[n_i].conj().T ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i] + .conj() + .T # C: 8xMM(bbb) + 3xMM(bab) + 2xMM(bba) + MM(baa) ) # --- Xr --- - A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1[:, :] - A_upper_buffer_blocks[n_i - 1] = -A_diagonal_blocks[n_i] @ B2[:, :] - A_upper_arrow_blocks[n_i] = -A_diagonal_blocks[n_i] @ B3[:, :] + A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1 # C: MM(bbb) + A_upper_buffer_blocks[n_i - 1] = -A_diagonal_blocks[n_i] @ B2 # C: MM(bbb) + A_upper_arrow_blocks[n_i] = -A_diagonal_blocks[n_i] @ B3 # C: MM(bba) D1[:, :] = A_lower_diagonal_blocks[n_i] D2[:, :] = A_lower_buffer_blocks[n_i - 1] D3[:, :] = A_lower_arrow_blocks[n_i] - A_lower_diagonal_blocks[n_i] = -C1[:, :] @ A_diagonal_blocks[n_i] - A_lower_buffer_blocks[n_i - 1] = -C2[:, :] @ A_diagonal_blocks[n_i] - A_lower_arrow_blocks[n_i] = -C3[:, :] @ A_diagonal_blocks[n_i] + A_lower_diagonal_blocks[n_i] = -C1 @ A_diagonal_blocks[n_i] # C: MM(bbb) + A_lower_buffer_blocks[n_i - 1] = -C2 @ A_diagonal_blocks[n_i] # C: MM(bbb) + A_lower_arrow_blocks[n_i] = -C3 @ A_diagonal_blocks[n_i] # C: MM(abb) A_diagonal_blocks[n_i] = ( A_diagonal_blocks[n_i] + A_diagonal_blocks[n_i] - @ (B1[:, :] @ D1[:, :] + B2[:, :] @ D2[:, :] + B3[:, :] @ D3[:, :]) + @ (B1 @ D1 + B2 @ D2 + B3 @ D3) @ A_diagonal_blocks[n_i] - ) + ) # C: 4xMM(bbb) + MM(bab) A_lower_diagonal_blocks[0] = A_upper_buffer_blocks[0] A_upper_diagonal_blocks[0] = A_lower_buffer_blocks[0] 
diff --git a/src/serinv/algs/ddbtsc.py b/src/serinv/algs/ddbtsc.py index cad76660..6a743fae 100644 --- a/src/serinv/algs/ddbtsc.py +++ b/src/serinv/algs/ddbtsc.py @@ -183,11 +183,20 @@ def _ddbtsc( A_upper_diagonal_blocks: ArrayLike, invert_last_block: bool, ): + """ + Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 2xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) for n_i in range(0, A_diagonal_blocks.shape[0] - 1): + # C: LU(b) + 2xTRSM(b) A_diagonal_blocks[n_i] = xp.linalg.inv(A_diagonal_blocks[n_i]) + # C: 2xMM(bbb) A_diagonal_blocks[n_i + 1] = ( A_diagonal_blocks[n_i + 1] - A_lower_diagonal_blocks[n_i] @@ -205,11 +214,20 @@ def _ddbtsc_upward( A_upper_diagonal_blocks: ArrayLike, invert_last_block: bool, ): + """ + Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 2xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) for n_i in range(A_diagonal_blocks.shape[0] - 1, 0, -1): + # C: LU(b) + 2xTRSM(b) A_diagonal_blocks[n_i] = xp.linalg.inv(A_diagonal_blocks[n_i]) + # C: 2xMM(bbb) A_diagonal_blocks[n_i - 1] = ( A_diagonal_blocks[n_i - 1] - A_upper_diagonal_blocks[n_i - 1] @@ -228,6 +246,13 @@ def _ddbtsc_permuted( A_lower_buffer_blocks: ArrayLike, A_upper_buffer_blocks: ArrayLike, ): + """ + Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 6xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) A_lower_buffer_blocks[0] = A_upper_diagonal_blocks[0] @@ -235,36 +260,30 @@ def _ddbtsc_permuted( for n_i in range(1, A_diagonal_blocks.shape[0] - 1): # Inverse current diagonal block - A_diagonal_blocks[n_i] = xp.linalg.inv(A_diagonal_blocks[n_i]) + A_diagonal_blocks[n_i] = xp.linalg.inv( + A_diagonal_blocks[n_i] + ) # C: LU(b) + 2xTRSM(b) # Update next diagonal block + temp_1 = A_lower_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i] # C: MM(bbb) A_diagonal_blocks[n_i + 1] = ( - A_diagonal_blocks[n_i + 1] - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ 
A_upper_diagonal_blocks[n_i] - ) + A_diagonal_blocks[n_i + 1] - temp_1 @ A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) # Update lower buffer block + temp_2 = A_lower_buffer_blocks[n_i - 1] @ A_diagonal_blocks[n_i] # C: MM(bbb) A_lower_buffer_blocks[n_i] = ( - -A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ A_upper_diagonal_blocks[n_i] - ) + -temp_2 @ A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) # Update upper buffer block A_upper_buffer_blocks[n_i] = ( - -A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] - ) + -temp_1 @ A_upper_buffer_blocks[n_i - 1] + ) # C: MM(bbb) # Update 0-block (first) A_diagonal_blocks[0] = ( - A_diagonal_blocks[0] - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] + A_diagonal_blocks[0] - temp_2 @ A_upper_buffer_blocks[n_i - 1] # C: MM(bbb) ) @@ -277,37 +296,49 @@ def _ddbtsc_quadratic( B_upper_diagonal_blocks: ArrayLike, invert_last_block: bool, ): + """ + Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 8xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) - temp_1 = xp.empty_like(A_diagonal_blocks[0]) - temp_2 = xp.empty_like(A_diagonal_blocks[0]) - for n_i in range(0, A_diagonal_blocks.shape[0] - 1): - A_diagonal_blocks[n_i] = xp.linalg.inv(A_diagonal_blocks[n_i]) + A_diagonal_blocks[n_i] = xp.linalg.inv( + A_diagonal_blocks[n_i] + ) # C: LU(b) + 2xTRSM(b) B_diagonal_blocks[n_i] = ( - A_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i].T - ) + A_diagonal_blocks[n_i] + @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T + ) # C: 2xMM(bbb) - temp_1[:, :] = A_lower_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i] - temp_2[:, :] = A_diagonal_blocks[n_i].T @ A_lower_diagonal_blocks[n_i].T + temp_1 = A_lower_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i] # C: MM(bbb) + temp_2 = ( + temp_1.conj().T + ) # A_diagonal_blocks[n_i].conj().T @ A_lower_diagonal_blocks[n_i].conj().T 
A_diagonal_blocks[n_i + 1] = ( - A_diagonal_blocks[n_i + 1] - temp_1[:, :] @ A_upper_diagonal_blocks[n_i] - ) + A_diagonal_blocks[n_i + 1] - temp_1 @ A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) B_diagonal_blocks[n_i + 1] = ( B_diagonal_blocks[n_i + 1] + A_lower_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i] - @ A_lower_diagonal_blocks[n_i].T - - B_lower_diagonal_blocks[n_i] @ temp_2[:, :] - - temp_1[:, :] @ B_upper_diagonal_blocks[n_i] - ) + @ A_lower_diagonal_blocks[n_i].conj().T + - B_lower_diagonal_blocks[n_i] @ temp_2 + - temp_1 @ B_upper_diagonal_blocks[n_i] + ) # C: 4xMM(bbb) if invert_last_block: A_diagonal_blocks[-1] = xp.linalg.inv(A_diagonal_blocks[-1]) B_diagonal_blocks[-1] = ( - A_diagonal_blocks[-1] @ B_diagonal_blocks[-1] @ A_diagonal_blocks[-1].T + A_diagonal_blocks[-1] + @ B_diagonal_blocks[-1] + @ A_diagonal_blocks[-1].conj().T ) @@ -320,37 +351,45 @@ def _ddbtsc_upward_quadratic( B_upper_diagonal_blocks: ArrayLike, invert_last_block: bool, ): + """ + Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 8xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) - temp_1 = xp.empty_like(A_diagonal_blocks[0]) - temp_2 = xp.empty_like(A_diagonal_blocks[0]) - for n_i in range(A_diagonal_blocks.shape[0] - 1, 0, -1): - A_diagonal_blocks[n_i] = xp.linalg.inv(A_diagonal_blocks[n_i]) + A_diagonal_blocks[n_i] = xp.linalg.inv( + A_diagonal_blocks[n_i] + ) # C: LU(b) + 2xTRSM(b) B_diagonal_blocks[n_i] = ( - A_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i].T - ) + A_diagonal_blocks[n_i] + @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T + ) # C: 2xMM(bbb) - temp_1[:, :] = A_upper_diagonal_blocks[n_i - 1] @ A_diagonal_blocks[n_i] - temp_2[:, :] = A_diagonal_blocks[n_i].T @ A_upper_diagonal_blocks[n_i - 1].T + temp_1 = A_upper_diagonal_blocks[n_i - 1] @ A_diagonal_blocks[n_i] # C: MM(bbb) + temp_2 = temp_1.conj().T A_diagonal_blocks[n_i - 1] = ( - A_diagonal_blocks[n_i - 1] - temp_1[:, :] @ 
A_lower_diagonal_blocks[n_i - 1] - ) + A_diagonal_blocks[n_i - 1] - temp_1 @ A_lower_diagonal_blocks[n_i - 1] + ) # C: MM(bbb) B_diagonal_blocks[n_i - 1] = ( B_diagonal_blocks[n_i - 1] + A_upper_diagonal_blocks[n_i - 1] @ B_diagonal_blocks[n_i] - @ A_upper_diagonal_blocks[n_i - 1].T - - B_upper_diagonal_blocks[n_i - 1] @ temp_2[:, :] - - temp_1[:, :] @ B_lower_diagonal_blocks[n_i - 1] - ) + @ A_upper_diagonal_blocks[n_i - 1].conj().T + - B_upper_diagonal_blocks[n_i - 1] @ temp_2 + - temp_1 @ B_lower_diagonal_blocks[n_i - 1] + ) # C: 4xMM(bbb) if invert_last_block: A_diagonal_blocks[0] = xp.linalg.inv(A_diagonal_blocks[0]) B_diagonal_blocks[0] = ( - A_diagonal_blocks[0] @ B_diagonal_blocks[0] @ A_diagonal_blocks[0].T + A_diagonal_blocks[0] @ B_diagonal_blocks[0] @ A_diagonal_blocks[0].conj().T ) @@ -366,6 +405,13 @@ def _ddbtsc_quadratic_permuted( B_lower_buffer_blocks: ArrayLike, B_upper_buffer_blocks: ArrayLike, ): + """ + Operations Counts: + ------------------ + 1xLU(b) + 2xTRSM(b) + 22xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) A_lower_buffer_blocks[0] = A_upper_diagonal_blocks[0] @@ -376,91 +422,67 @@ def _ddbtsc_quadratic_permuted( for n_i in range(1, A_diagonal_blocks.shape[0] - 1): # Inverse current diagonal block - A_diagonal_blocks[n_i] = xp.linalg.inv(A_diagonal_blocks[n_i]) + A_diagonal_blocks[n_i] = xp.linalg.inv( + A_diagonal_blocks[n_i] + ) # C: LU(b) + 2xTRSM(b) # Update next diagonal block + temp_1 = A_lower_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i] # C: MM(bbb) A_diagonal_blocks[n_i + 1] = ( - A_diagonal_blocks[n_i + 1] - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_diagonal_blocks[n_i] - ) + A_diagonal_blocks[n_i + 1] - temp_1 @ A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) # Update lower buffer block + temp_2 = A_lower_buffer_blocks[n_i - 1] @ A_diagonal_blocks[n_i] # C: MM(bbb) A_lower_buffer_blocks[n_i] = ( - -A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ 
A_upper_diagonal_blocks[n_i] - ) + -temp_2 @ A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) # Update upper buffer block A_upper_buffer_blocks[n_i] = ( - -A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] - ) + -temp_1 @ A_upper_buffer_blocks[n_i - 1] + ) # C: MM(bbb) # Update 0-block (first) A_diagonal_blocks[0] = ( - A_diagonal_blocks[0] - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ A_upper_buffer_blocks[n_i - 1] - ) + A_diagonal_blocks[0] - temp_2 @ A_upper_buffer_blocks[n_i - 1] + ) # C: MM(bbb) # --- Xl --- B_diagonal_blocks[n_i] = ( - A_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i].T - ) + A_diagonal_blocks[n_i] + @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T + ) # C: 2xMM(bbb) + temp_1_conjt = temp_1.conj().T + temp_3 = A_lower_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i] # C: MM(bbb) B_diagonal_blocks[n_i + 1] = ( B_diagonal_blocks[n_i + 1] - + A_lower_diagonal_blocks[n_i] - @ B_diagonal_blocks[n_i] - @ A_lower_diagonal_blocks[n_i].T - - B_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i].T - @ A_lower_diagonal_blocks[n_i].T - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ B_upper_diagonal_blocks[n_i] - ) + + temp_3 @ A_lower_diagonal_blocks[n_i].conj().T + - B_lower_diagonal_blocks[n_i] @ temp_1_conjt + - temp_1 @ B_upper_diagonal_blocks[n_i] + ) # C: 3xMM(bbb) + temp_2_conjt = temp_2.conj().T B_upper_buffer_blocks[n_i] = ( B_upper_buffer_blocks[n_i] - + A_lower_diagonal_blocks[n_i] - @ B_diagonal_blocks[n_i] - @ A_lower_buffer_blocks[n_i - 1].T - - B_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i].T - @ A_lower_buffer_blocks[n_i - 1].T - - A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - @ B_upper_buffer_blocks[n_i - 1] - ) + + temp_3 @ A_lower_buffer_blocks[n_i - 1].conj().T + - B_lower_diagonal_blocks[n_i] @ temp_2_conjt + - temp_1 @ B_upper_buffer_blocks[n_i - 1] + ) # C: 3xMM(bbb) + temp_4 = A_lower_buffer_blocks[n_i - 
1] @ B_diagonal_blocks[n_i] # C: MM(bbb) B_lower_buffer_blocks[n_i] = ( B_lower_buffer_blocks[n_i] - + A_lower_buffer_blocks[n_i - 1] - @ B_diagonal_blocks[n_i] - @ A_lower_diagonal_blocks[n_i].T - - B_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i].T - @ A_lower_diagonal_blocks[n_i].T - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ B_upper_diagonal_blocks[n_i] - ) + + temp_4 @ A_lower_diagonal_blocks[n_i].conj().T + - B_lower_buffer_blocks[n_i - 1] @ temp_1_conjt + - temp_2 @ B_upper_diagonal_blocks[n_i] + ) # C: 3xMM(bbb) B_diagonal_blocks[0] = ( B_diagonal_blocks[0] - + A_lower_buffer_blocks[n_i - 1] - @ B_diagonal_blocks[n_i] - @ A_lower_buffer_blocks[n_i - 1].T - - B_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i].T - @ A_lower_buffer_blocks[n_i - 1].T - - A_lower_buffer_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - @ B_upper_buffer_blocks[n_i - 1] - ) + + temp_4 @ A_lower_buffer_blocks[n_i - 1].conj().T + - B_lower_buffer_blocks[n_i - 1] @ temp_2_conjt + - temp_2 @ B_upper_buffer_blocks[n_i - 1] + ) # C: 3xMM(bbb) diff --git a/src/serinv/algs/ddbtsci.py b/src/serinv/algs/ddbtsci.py index 909db341..a44cef0d 100644 --- a/src/serinv/algs/ddbtsci.py +++ b/src/serinv/algs/ddbtsci.py @@ -178,31 +178,27 @@ def _ddbtsci( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, ): - xp, _ = _get_module_from_array(A_diagonal_blocks) - - if A_diagonal_blocks.shape[0] > 1: - # If there is only a single diagonal block, we don't need these buffers. 
- temp_lower = xp.empty_like(A_lower_diagonal_blocks[0]) - + """ + Operations Counts: + ------------------ + 5xMM(bbb) + """ for n_i in range(A_diagonal_blocks.shape[0] - 2, -1, -1): - temp_lower[:, :] = A_lower_diagonal_blocks[n_i] + temp_lower = A_lower_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i] # C: MM(bbb) A_lower_diagonal_blocks[n_i] = ( - -A_diagonal_blocks[n_i + 1] - @ A_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i] - ) + -A_diagonal_blocks[n_i + 1] @ temp_lower + ) # C: MM(bbb) A_upper_diagonal_blocks[n_i] = ( -A_diagonal_blocks[n_i] @ A_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1] - ) + ) # C: 2xMM(bbb) A_diagonal_blocks[n_i] = ( - A_diagonal_blocks[n_i] - - A_upper_diagonal_blocks[n_i] @ temp_lower @ A_diagonal_blocks[n_i] - ) + A_diagonal_blocks[n_i] - A_upper_diagonal_blocks[n_i] @ temp_lower + ) # C: MM(bbb) def _ddbtsci_upward( @@ -210,31 +206,29 @@ def _ddbtsci_upward( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, ): - xp, _ = _get_module_from_array(A_diagonal_blocks) - - if A_diagonal_blocks.shape[0] > 1: - # If there is only a single diagonal block, we don't need these buffers. 
- temp_upper = xp.empty_like(A_upper_diagonal_blocks[0]) - + """ + Operations Counts: + ------------------ + 5xMM(bbb) + """ for n_i in range(1, A_diagonal_blocks.shape[0]): - temp_upper[:, :] = A_upper_diagonal_blocks[n_i - 1] + temp_upper = ( + A_upper_diagonal_blocks[n_i - 1] @ A_diagonal_blocks[n_i] + ) # C: MM(bbb) A_lower_diagonal_blocks[n_i - 1] = ( -A_diagonal_blocks[n_i] @ A_lower_diagonal_blocks[n_i - 1] @ A_diagonal_blocks[n_i - 1] - ) + ) # C: 2xMM(bbb) A_upper_diagonal_blocks[n_i - 1] = ( - -A_diagonal_blocks[n_i - 1] - @ A_upper_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i] - ) + -A_diagonal_blocks[n_i - 1] @ temp_upper + ) # C: MM(bbb) A_diagonal_blocks[n_i] = ( - A_diagonal_blocks[n_i] - - A_lower_diagonal_blocks[n_i - 1] @ temp_upper @ A_diagonal_blocks[n_i] - ) + A_diagonal_blocks[n_i] - A_lower_diagonal_blocks[n_i - 1] @ temp_upper + ) # C: MM(bbb) def _ddbtsci_permuted( @@ -244,6 +238,11 @@ def _ddbtsci_permuted( A_lower_buffer_blocks: ArrayLike, A_upper_buffer_blocks: ArrayLike, ): + """ + Operations Counts: + ------------------ + 16xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) B1 = xp.empty_like(A_lower_diagonal_blocks[0]) @@ -259,38 +258,36 @@ def _ddbtsci_permuted( B1[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1] + A_upper_buffer_blocks[n_i - 1] @ A_lower_buffer_blocks[n_i] - ) + ) # C: 2xMM(bbb) B2[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_upper_buffer_blocks[n_i] + A_upper_buffer_blocks[n_i - 1] @ A_diagonal_blocks[0] - ) + ) # C: 2xMM(bbb) C1[:, :] = ( A_diagonal_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] + A_upper_buffer_blocks[n_i] @ A_lower_buffer_blocks[n_i - 1] - ) + ) # C: 2xMM(bbb) C2[:, :] = ( A_lower_buffer_blocks[n_i] @ A_lower_diagonal_blocks[n_i] + A_diagonal_blocks[0] @ A_lower_buffer_blocks[n_i - 1] - ) + ) # C: 2xMM(bbb) - A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1[:, :] - A_upper_buffer_blocks[n_i - 1] = -A_diagonal_blocks[n_i] @ B2[:, :] + 
A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1 # C: MM(bbb) + A_upper_buffer_blocks[n_i - 1] = -A_diagonal_blocks[n_i] @ B2 # C: MM(bbb) D1[:, :] = A_lower_diagonal_blocks[n_i] D2[:, :] = A_lower_buffer_blocks[n_i - 1] - A_lower_diagonal_blocks[n_i] = -C1[:, :] @ A_diagonal_blocks[n_i] - A_lower_buffer_blocks[n_i - 1] = -C2[:, :] @ A_diagonal_blocks[n_i] + A_lower_diagonal_blocks[n_i] = -C1 @ A_diagonal_blocks[n_i] # C: MM(bbb) + A_lower_buffer_blocks[n_i - 1] = -C2 @ A_diagonal_blocks[n_i] # C: MM(bbb) A_diagonal_blocks[n_i] = ( A_diagonal_blocks[n_i] - + A_diagonal_blocks[n_i] - @ (B1[:, :] @ D1[:, :] + B2[:, :] @ D2[:, :]) - @ A_diagonal_blocks[n_i] - ) + + A_diagonal_blocks[n_i] @ (B1 @ D1 + B2 @ D2) @ A_diagonal_blocks[n_i] + ) # C: 4xMM(bbb) A_lower_diagonal_blocks[0] = A_upper_buffer_blocks[0] A_upper_diagonal_blocks[0] = A_lower_buffer_blocks[0] @@ -304,6 +301,11 @@ def _ddbtsci_quadratic( B_lower_diagonal_blocks: ArrayLike, B_upper_diagonal_blocks: ArrayLike, ): + """ + Operations Counts: + ------------------ + 19xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) if A_diagonal_blocks.shape[0] > 1: @@ -315,58 +317,59 @@ def _ddbtsci_quadratic( temp_1 = xp.empty_like(A_diagonal_blocks[0]) temp_2 = xp.empty_like(A_diagonal_blocks[0]) temp_3 = xp.empty_like(A_diagonal_blocks[0]) - temp_4 = xp.empty_like(A_diagonal_blocks[0]) + temp_3_b = xp.empty_like(A_diagonal_blocks[0]) for n_i in range(A_diagonal_blocks.shape[0] - 2, -1, -1): - temp_upper_lesser[:, :] = B_upper_diagonal_blocks[n_i] - temp_1[:, :] = A_diagonal_blocks[n_i] @ A_upper_diagonal_blocks[n_i] - temp_4[:, :] = A_lower_diagonal_blocks[n_i].T @ A_diagonal_blocks[n_i + 1].T + temp_1[:, :] = ( + A_diagonal_blocks[n_i] @ A_upper_diagonal_blocks[n_i] + ) # C: MM(bbb) + temp_2[:, :] = temp_1.conj().T + temp_2_b = B_diagonal_blocks[n_i + 1] @ temp_2 # C: MM(bbb) + temp_3[:, :] = ( + A_diagonal_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] + ) # C: MM(bbb) + temp_3_b[:, :] = 
B_diagonal_blocks[n_i] @ temp_3.conj().T # C: MM(bbb) + temp_4 = temp_3 @ B_diagonal_blocks[n_i] # C: MM(bbb) - B_upper_diagonal_blocks[n_i] = ( - -temp_1[:, :] @ B_diagonal_blocks[n_i + 1] - - B_diagonal_blocks[n_i] @ temp_4[:, :] - + A_diagonal_blocks[n_i] + temp_upper_lesser[:, :] = ( + A_diagonal_blocks[n_i] @ B_upper_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i + 1].T - ) + @ A_diagonal_blocks[n_i + 1].conj().T + ) # C: 2xMM(bbb) - temp_lower_lesser[:, :] = B_lower_diagonal_blocks[n_i] - temp_2[:, :] = A_upper_diagonal_blocks[n_i].T @ A_diagonal_blocks[n_i].T - temp_3[:, :] = A_diagonal_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] + B_upper_diagonal_blocks[n_i] = ( + -temp_1 @ B_diagonal_blocks[n_i + 1] - temp_3_b + temp_upper_lesser + ) # C: MM(bbb) - B_lower_diagonal_blocks[n_i] = ( - -B_diagonal_blocks[n_i + 1] @ temp_2[:, :] - - temp_3[:, :] @ B_diagonal_blocks[n_i] - + A_diagonal_blocks[n_i + 1] + temp_lower_lesser[:, :] = ( + A_diagonal_blocks[n_i + 1] @ B_lower_diagonal_blocks[n_i] - @ A_diagonal_blocks[n_i].T - ) + @ A_diagonal_blocks[n_i].conj().T + ) # C: 2xMM(bbb) + + B_lower_diagonal_blocks[n_i] = -temp_2_b - temp_4 + temp_lower_lesser B_diagonal_blocks[n_i] = ( B_diagonal_blocks[n_i] - + temp_1[:, :] @ B_diagonal_blocks[n_i + 1] @ temp_2[:, :] - + temp_1[:, :] @ temp_3[:, :] @ B_diagonal_blocks[n_i] - + B_diagonal_blocks[n_i].T @ temp_4[:, :] @ temp_2[:, :] - - temp_1[:, :] - @ A_diagonal_blocks[n_i + 1] - @ temp_lower_lesser[:, :] - @ A_diagonal_blocks[n_i].T - - A_diagonal_blocks[n_i] - @ temp_upper_lesser[:, :] - @ A_diagonal_blocks[n_i + 1].T - @ temp_2[:, :] - ) + + temp_1 @ temp_2_b + + temp_1 @ temp_4 + + temp_3_b @ temp_2 + - temp_1 @ temp_lower_lesser + - temp_upper_lesser @ temp_2 + ) # C: 5xMM(bbb) temp_lower_retarded[:, :] = A_lower_diagonal_blocks[n_i] - A_lower_diagonal_blocks[n_i] = -temp_3[:, :] @ A_diagonal_blocks[n_i] - A_upper_diagonal_blocks[n_i] = -temp_1[:, :] @ A_diagonal_blocks[n_i + 1] + A_lower_diagonal_blocks[n_i] = 
-temp_3 @ A_diagonal_blocks[n_i] # C: MM(bbb) + A_upper_diagonal_blocks[n_i] = ( + -temp_1 @ A_diagonal_blocks[n_i + 1] + ) # C: MM(bbb) A_diagonal_blocks[n_i] = ( A_diagonal_blocks[n_i] - A_upper_diagonal_blocks[n_i] - @ temp_lower_retarded[:, :] + @ temp_lower_retarded @ A_diagonal_blocks[n_i] - ) + ) # C: 2xMM(bbb) def _ddbtsci_upward_quadratic( @@ -377,6 +380,11 @@ def _ddbtsci_upward_quadratic( B_lower_diagonal_blocks: ArrayLike, B_upper_diagonal_blocks: ArrayLike, ): + """ + Operations Counts: + ------------------ + 19xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) if A_diagonal_blocks.shape[0] > 1: @@ -391,56 +399,59 @@ def _ddbtsci_upward_quadratic( temp_4 = xp.empty_like(A_diagonal_blocks[0]) for n_i in range(1, A_diagonal_blocks.shape[0]): - - temp_upper_lesser[:, :] = B_upper_diagonal_blocks[n_i - 1] - temp_1[:, :] = A_diagonal_blocks[n_i - 1] @ A_upper_diagonal_blocks[n_i - 1] - temp_4[:, :] = A_lower_diagonal_blocks[n_i - 1].T @ A_diagonal_blocks[n_i].T - - B_upper_diagonal_blocks[n_i - 1] = ( - -temp_1[:, :] @ B_diagonal_blocks[n_i] - - B_diagonal_blocks[n_i - 1] @ temp_4[:, :] - + A_diagonal_blocks[n_i - 1] + temp_1[:, :] = ( + A_diagonal_blocks[n_i - 1] @ A_upper_diagonal_blocks[n_i - 1] + ) # C: MM(bbb) + temp_1_b = temp_1 @ B_diagonal_blocks[n_i] # C: MM(bbb) + temp_2[:, :] = temp_1.conj().T + temp_2_b = B_diagonal_blocks[n_i] @ temp_2 # C: MM(bbb) + temp_3[:, :] = ( + A_diagonal_blocks[n_i] @ A_lower_diagonal_blocks[n_i - 1] + ) # C: MM(bbb) + temp_4[:, :] = temp_3.conj().T + temp_4_b = B_diagonal_blocks[n_i - 1] @ temp_4 # C: MM(bbb) + + temp_upper_lesser[:, :] = ( + A_diagonal_blocks[n_i - 1] @ B_upper_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i].T - ) + @ A_diagonal_blocks[n_i].conj().T + ) # C: 2xMM(bbb) - temp_lower_lesser[:, :] = B_lower_diagonal_blocks[n_i - 1] - temp_2[:, :] = A_upper_diagonal_blocks[n_i - 1].T @ A_diagonal_blocks[n_i - 1].T - temp_3[:, :] = A_diagonal_blocks[n_i] @ A_lower_diagonal_blocks[n_i - 1] 
+ B_upper_diagonal_blocks[n_i - 1] = -temp_1_b - temp_4_b + temp_upper_lesser - B_lower_diagonal_blocks[n_i - 1] = ( - -B_diagonal_blocks[n_i] @ temp_2[:, :] - - temp_3[:, :] @ B_diagonal_blocks[n_i - 1] - + A_diagonal_blocks[n_i] + temp_lower_lesser[:, :] = ( + A_diagonal_blocks[n_i] @ B_lower_diagonal_blocks[n_i - 1] - @ A_diagonal_blocks[n_i - 1].T - ) + @ A_diagonal_blocks[n_i - 1].conj().T + ) # C: 2xMM(bbb) + + B_lower_diagonal_blocks[n_i - 1] = ( + -temp_2_b - temp_3 @ B_diagonal_blocks[n_i - 1] + temp_lower_lesser + ) # C: MM(bbb) B_diagonal_blocks[n_i] = ( B_diagonal_blocks[n_i] - + temp_3[:, :] @ B_diagonal_blocks[n_i - 1] @ temp_4[:, :] - + temp_3[:, :] @ temp_1[:, :] @ B_diagonal_blocks[n_i] - + B_diagonal_blocks[n_i].T @ temp_2[:, :] @ temp_4[:, :] - - temp_3[:, :] - @ A_diagonal_blocks[n_i - 1] - @ temp_upper_lesser[:, :] - @ A_diagonal_blocks[n_i].T - - A_diagonal_blocks[n_i] - @ temp_lower_lesser[:, :] - @ A_diagonal_blocks[n_i - 1].T - @ temp_4[:, :] - ) + + temp_3 @ temp_4_b + + temp_3 @ temp_1_b + + temp_2_b @ temp_4 + - temp_3 @ temp_upper_lesser + - temp_lower_lesser @ temp_4 + ) # C: 5xMM(bbb) temp_upper_retarded[:, :] = A_upper_diagonal_blocks[n_i - 1] - A_lower_diagonal_blocks[n_i - 1] = -temp_3[:, :] @ A_diagonal_blocks[n_i - 1] - A_upper_diagonal_blocks[n_i - 1] = -temp_1[:, :] @ A_diagonal_blocks[n_i] + A_lower_diagonal_blocks[n_i - 1] = ( + -temp_3 @ A_diagonal_blocks[n_i - 1] + ) # C: MM(bbb) + A_upper_diagonal_blocks[n_i - 1] = ( + -temp_1 @ A_diagonal_blocks[n_i] + ) # C: MM(bbb) A_diagonal_blocks[n_i] = ( A_diagonal_blocks[n_i] - A_lower_diagonal_blocks[n_i - 1] - @ temp_upper_retarded[:, :] + @ temp_upper_retarded @ A_diagonal_blocks[n_i] - ) + ) # C: 2xMM(bbb) def _ddbtsci_quadratic_permuted( @@ -455,6 +466,11 @@ def _ddbtsci_quadratic_permuted( B_lower_buffer_blocks: ArrayLike, B_upper_buffer_blocks: ArrayLike, ): + """ + Operations Counts: + ------------------ + 66xMM(bbb) + """ xp, _ = _get_module_from_array(A_diagonal_blocks) 
B1 = xp.empty_like(A_lower_diagonal_blocks[0]) @@ -475,162 +491,135 @@ def _ddbtsci_quadratic_permuted( B1[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1] + A_upper_buffer_blocks[n_i - 1] @ A_lower_buffer_blocks[n_i] - ) + ) # C: 2xMM(bbb) B2[:, :] = ( A_upper_diagonal_blocks[n_i] @ A_upper_buffer_blocks[n_i] + A_upper_buffer_blocks[n_i - 1] @ A_diagonal_blocks[0] - ) + ) # C: 2xMM(bbb) C1[:, :] = ( A_diagonal_blocks[n_i + 1] @ A_lower_diagonal_blocks[n_i] + A_upper_buffer_blocks[n_i] @ A_lower_buffer_blocks[n_i - 1] - ) + ) # C: 2xMM(bbb) C2[:, :] = ( A_lower_buffer_blocks[n_i] @ A_lower_diagonal_blocks[n_i] + A_diagonal_blocks[0] @ A_lower_buffer_blocks[n_i - 1] - ) + ) # C: 2xMM(bbb) # --- Xl --- temp_B_12[:, :] = B_upper_diagonal_blocks[n_i] + temp_B_12_buffer = temp_B_12 @ A_lower_buffer_blocks[n_i].conj().T # C: MM(bbb) temp_B_13[:, :] = B_upper_buffer_blocks[n_i - 1] temp_B_21[:, :] = B_lower_diagonal_blocks[n_i] temp_B_31[:, :] = B_lower_buffer_blocks[n_i - 1] + D1 = ( + A_upper_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i + 1] + + A_upper_buffer_blocks[n_i - 1] @ B_lower_buffer_blocks[n_i] + ) # C: 2xMM(bbb) + + D2 = ( + A_upper_diagonal_blocks[n_i] @ B_upper_buffer_blocks[n_i] + + A_upper_buffer_blocks[n_i - 1] @ B_diagonal_blocks[0] + ) # C: 2xMM(bbb) + B_upper_diagonal_blocks[n_i] = ( - -A_diagonal_blocks[n_i] - @ ( - A_upper_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i + 1] - + A_upper_buffer_blocks[n_i - 1] @ B_lower_buffer_blocks[n_i] - ) - - B_diagonal_blocks[n_i] - @ ( - A_lower_diagonal_blocks[n_i].T @ A_diagonal_blocks[n_i + 1].T - + A_lower_buffer_blocks[n_i - 1].T @ A_upper_buffer_blocks[n_i].T - ) + -A_diagonal_blocks[n_i] @ D1 + - B_diagonal_blocks[n_i] @ C1.conj().T + A_diagonal_blocks[n_i] @ ( - B_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1].T - + B_upper_buffer_blocks[n_i - 1] @ A_upper_buffer_blocks[n_i].T + B_upper_diagonal_blocks[n_i] @ A_diagonal_blocks[n_i + 1].conj().T + + B_upper_buffer_blocks[n_i - 1] 
@ A_upper_buffer_blocks[n_i].conj().T ) - ) + ) # C: 5xMM(bbb) + B_upper_buffer_blocks[n_i - 1] = ( - -A_diagonal_blocks[n_i] - @ ( - A_upper_diagonal_blocks[n_i] @ B_upper_buffer_blocks[n_i] - + A_upper_buffer_blocks[n_i - 1] @ B_diagonal_blocks[0] - ) - - B_diagonal_blocks[n_i] - @ ( - A_lower_diagonal_blocks[n_i].T @ A_lower_buffer_blocks[n_i].T - + A_lower_buffer_blocks[n_i - 1].T @ A_diagonal_blocks[0].T - ) + -A_diagonal_blocks[n_i] @ D2 + - B_diagonal_blocks[n_i] @ C2.conj().T + A_diagonal_blocks[n_i] @ ( - temp_B_12[:, :] @ A_lower_buffer_blocks[n_i].T - + B_upper_buffer_blocks[n_i - 1] @ A_diagonal_blocks[0].T + temp_B_12_buffer + + B_upper_buffer_blocks[n_i - 1] @ A_diagonal_blocks[0].conj().T ) - ) + ) # C: 4xMM(bbb) B_lower_diagonal_blocks[n_i] = ( -( - B_diagonal_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].T - + B_upper_buffer_blocks[n_i] @ A_upper_buffer_blocks[n_i - 1].T + B_diagonal_blocks[n_i + 1] @ A_upper_diagonal_blocks[n_i].conj().T + + B_upper_buffer_blocks[n_i] @ A_upper_buffer_blocks[n_i - 1].conj().T ) - @ A_diagonal_blocks[n_i].T - - (C1[:, :]) @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T + - C1 @ B_diagonal_blocks[n_i] + ( A_diagonal_blocks[n_i + 1] @ B_lower_diagonal_blocks[n_i] + A_upper_buffer_blocks[n_i] @ B_lower_buffer_blocks[n_i - 1] ) - @ A_diagonal_blocks[n_i].T - ) + @ A_diagonal_blocks[n_i].conj().T + ) # C: 7xMM(bbb) + B_lower_buffer_blocks[n_i - 1] = ( -( - B_lower_buffer_blocks[n_i] @ A_upper_diagonal_blocks[n_i].T - + B_diagonal_blocks[0] @ A_upper_buffer_blocks[n_i - 1].T + B_lower_buffer_blocks[n_i] @ A_upper_diagonal_blocks[n_i].conj().T + + B_diagonal_blocks[0] @ A_upper_buffer_blocks[n_i - 1].conj().T ) - @ A_diagonal_blocks[n_i].T - - (C2[:, :]) @ B_diagonal_blocks[n_i] + @ A_diagonal_blocks[n_i].conj().T + - C2 @ B_diagonal_blocks[n_i] + ( - A_lower_buffer_blocks[n_i] @ temp_B_21[:, :] + A_lower_buffer_blocks[n_i] @ temp_B_21 + A_diagonal_blocks[0] @ B_lower_buffer_blocks[n_i - 1] ) - @ 
A_diagonal_blocks[n_i].T - ) + @ A_diagonal_blocks[n_i].conj().T + ) # C: 7xMM(bbb) B_diagonal_blocks[n_i] = ( B_diagonal_blocks[n_i] + A_diagonal_blocks[n_i] @ ( - ( - A_upper_diagonal_blocks[n_i] @ B_diagonal_blocks[n_i + 1] - + A_upper_buffer_blocks[n_i - 1] @ B_lower_buffer_blocks[n_i] - ) - @ A_upper_diagonal_blocks[n_i].T - + ( - A_upper_diagonal_blocks[n_i] @ B_upper_buffer_blocks[n_i] - + A_upper_buffer_blocks[n_i - 1] @ B_diagonal_blocks[0] - ) - @ A_upper_buffer_blocks[n_i - 1].T + D1 @ A_upper_diagonal_blocks[n_i].conj().T + + D2 @ A_upper_buffer_blocks[n_i - 1].conj().T ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i].conj().T + A_diagonal_blocks[n_i] - @ ( - (B1[:, :]) @ A_lower_diagonal_blocks[n_i] - + (B2[:, :]) @ A_lower_buffer_blocks[n_i - 1] - ) + @ (B1 @ A_lower_diagonal_blocks[n_i] + B2 @ A_lower_buffer_blocks[n_i - 1]) @ B_diagonal_blocks[n_i] - + B_diagonal_blocks[n_i].T + + B_diagonal_blocks[n_i] @ ( - ( - A_lower_diagonal_blocks[n_i].T @ A_diagonal_blocks[n_i + 1].T - + A_lower_buffer_blocks[n_i - 1].T @ A_upper_buffer_blocks[n_i].T - ) - @ A_upper_diagonal_blocks[n_i].T - + ( - A_lower_diagonal_blocks[n_i].T @ A_lower_buffer_blocks[n_i].T - + A_lower_buffer_blocks[n_i - 1].T @ A_diagonal_blocks[0].T - ) - @ A_upper_buffer_blocks[n_i - 1].T + C1.conj().T @ A_upper_diagonal_blocks[n_i].conj().T + + C2.conj().T @ A_upper_buffer_blocks[n_i - 1].conj().T ) - @ A_diagonal_blocks[n_i].T + @ A_diagonal_blocks[n_i].conj().T - A_diagonal_blocks[n_i] - @ ((B1[:, :]) @ temp_B_21 + (B2[:, :]) @ temp_B_31) - @ A_diagonal_blocks[n_i].T + @ (B1 @ temp_B_21 + B2 @ temp_B_31) + @ A_diagonal_blocks[n_i].conj().T - A_diagonal_blocks[n_i] @ ( ( - temp_B_12 @ A_diagonal_blocks[n_i + 1].T - + temp_B_13 @ A_upper_buffer_blocks[n_i].T - ) - @ A_upper_diagonal_blocks[n_i].T - + ( - temp_B_12 @ A_lower_buffer_blocks[n_i].T - + temp_B_13 @ A_diagonal_blocks[0].T + temp_B_12 @ A_diagonal_blocks[n_i + 1].conj().T + + temp_B_13 @ 
A_upper_buffer_blocks[n_i].conj().T ) - @ A_upper_buffer_blocks[n_i - 1].T + @ A_upper_diagonal_blocks[n_i].conj().T + + (temp_B_12_buffer + temp_B_13 @ A_diagonal_blocks[0].conj().T) + @ A_upper_buffer_blocks[n_i - 1].conj().T ) - @ A_diagonal_blocks[n_i].T - ) + @ A_diagonal_blocks[n_i].conj().T + ) # C: 23xMM(bbb) # --- Xr --- - A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1[:, :] - A_upper_buffer_blocks[n_i - 1] = -A_diagonal_blocks[n_i] @ B2[:, :] + A_upper_diagonal_blocks[n_i] = -A_diagonal_blocks[n_i] @ B1 # C: MM(bbb) + A_upper_buffer_blocks[n_i - 1] = -A_diagonal_blocks[n_i] @ B2 # C: MM(bbb) D1[:, :] = A_lower_diagonal_blocks[n_i] D2[:, :] = A_lower_buffer_blocks[n_i - 1] - A_lower_diagonal_blocks[n_i] = -C1[:, :] @ A_diagonal_blocks[n_i] - A_lower_buffer_blocks[n_i - 1] = -C2[:, :] @ A_diagonal_blocks[n_i] + A_lower_diagonal_blocks[n_i] = -C1 @ A_diagonal_blocks[n_i] # C: MM(bbb) + A_lower_buffer_blocks[n_i - 1] = -C2 @ A_diagonal_blocks[n_i] # C: MM(bbb) A_diagonal_blocks[n_i] = ( A_diagonal_blocks[n_i] - + A_diagonal_blocks[n_i] - @ (B1[:, :] @ D1[:, :] + B2[:, :] @ D2[:, :]) - @ A_diagonal_blocks[n_i] - ) + + A_diagonal_blocks[n_i] @ (B1 @ D1 + B2 @ D2) @ A_diagonal_blocks[n_i] + ) # C: 4xMM(bbb) A_lower_diagonal_blocks[0] = A_upper_buffer_blocks[0] A_upper_diagonal_blocks[0] = A_lower_buffer_blocks[0] diff --git a/src/serinv/wrappers/ddbtars.py b/src/serinv/wrappers/ddbtars.py index 46be5972..722582a6 100644 --- a/src/serinv/wrappers/ddbtars.py +++ b/src/serinv/wrappers/ddbtars.py @@ -15,9 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - def allocate_ddbtars( A_diagonal_blocks: ArrayLike, @@ -30,7 +27,14 @@ def allocate_ddbtars( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() 
comm_size = comm.Get_size() @@ -80,7 +84,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -144,7 +148,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -224,8 +228,15 @@ def map_ddbtasc_to_ddbtars( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str, + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +371,14 @@ def aggregate_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -455,7 +473,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. 
if comm_rank == 0: @@ -564,11 +582,12 @@ def aggregate_ddbtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -576,9 +595,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -586,9 +605,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -596,9 +615,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -606,9 +625,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype 
= _get_nccl_parameters( - arr=_A_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_arrow_blocks_comm.data.ptr, count=count, @@ -616,17 +635,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -662,11 +682,12 @@ def aggregate_ddbtars( ddbtars["A_upper_arrow_blocks"] = _A_upper_arrow_blocks[1:] if quadratic: - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -674,9 +695,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -684,9 +705,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) 
count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -694,9 +715,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_arrow_blocks_comm.data.ptr, count=count, @@ -704,9 +725,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_arrow_blocks_comm.data.ptr, count=count, @@ -714,17 +735,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_B_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_arrow_tip_block_comm.data.ptr, recvbuf=_B_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -763,7 +785,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not 
_use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -787,8 +809,18 @@ def scatter_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() + _A_diagonal_blocks: ArrayLike = ddbtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = ddbtars.get("A_lower_diagonal_blocks", None) _A_upper_diagonal_blocks: ArrayLike = ddbtars.get("A_upper_diagonal_blocks", None) @@ -857,8 +889,15 @@ def map_ddbtars_to_ddbtasci( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ddbtrs.py b/src/serinv/wrappers/ddbtrs.py index 15393ae9..937d8def 100644 --- a/src/serinv/wrappers/ddbtrs.py +++ b/src/serinv/wrappers/ddbtrs.py @@ -24,7 +24,14 @@ def allocate_ddbtrs( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -57,7 +64,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. 
@@ -98,7 +105,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -150,8 +157,15 @@ def map_ddbtsc_to_ddbtrs( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -241,7 +255,14 @@ def aggregate_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -304,7 +325,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. 
if comm_rank == 0: @@ -379,11 +400,12 @@ def aggregate_ddbtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -391,9 +413,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -401,9 +423,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -411,6 +433,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -433,11 +456,12 @@ def aggregate_ddbtrs( ddbtrs["A_upper_diagonal_blocks"] = _A_upper_diagonal_blocks[1:-2] if quadratic: - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) 
- comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -445,9 +469,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -455,9 +479,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -465,6 +489,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -492,7 +517,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -510,8 +535,15 @@ def scatter_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -565,8 +597,15 @@ def map_ddbtrs_to_ddbtsci( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + 
communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index da2fcff4..27335195 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -22,6 +24,7 @@ def pddbtasc( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -115,9 +118,10 @@ def pddbtasc( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) - buffers: dict = kwargs.get("buffers", None) ddbtars: dict = kwargs.get("ddbtars", None) strategy: str = kwargs.get("strategy", "allgather") @@ -200,14 +204,23 @@ def pddbtasc( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_ddbtars( ddbtars=ddbtars, quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic ddbtars["A_arrow_tip_block"][:] += A_arrow_tip_initial if quadratic: @@ -226,3 +239,5 @@ def pddbtasc( ) comm.Barrier() + + return elapsed diff --git a/src/serinv/wrappers/pddbtasci.py b/src/serinv/wrappers/pddbtasci.py index f86b6a9c..f235d76b 100644 --- a/src/serinv/wrappers/pddbtasci.py +++ b/src/serinv/wrappers/pddbtasci.py @@ -22,6 +22,7 @@ def pddbtasci( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: 
MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. @@ -159,6 +160,7 @@ def pddbtasci( ddbtars=ddbtars, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtars_to_ddbtasci( @@ -179,6 +181,7 @@ def pddbtasci( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index fc0a3765..e9f1eb9e 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -19,6 +21,7 @@ def pddbtsc( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -82,6 +85,7 @@ def pddbtsc( - _B_upper_diagonal_blocks : ArrayLike The upper diagonal blocks of the reduced system. 
""" + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -161,14 +165,23 @@ def pddbtsc( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_ddbtrs( ddbtrs=ddbtrs, quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Perform Schur complement on the reduced system ddbtsc( @@ -180,3 +193,5 @@ def pddbtsc( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/pddbtsci.py b/src/serinv/wrappers/pddbtsci.py index 144054f8..63d94c84 100644 --- a/src/serinv/wrappers/pddbtsci.py +++ b/src/serinv/wrappers/pddbtsci.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import ddbtsci @@ -19,6 +18,7 @@ def pddbtsci( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. 
@@ -89,8 +89,6 @@ def pddbtsci( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") - xp, _ = _get_module_from_array(arr=A_diagonal_blocks) - rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) buffers: dict = kwargs.get("buffers", None) @@ -125,6 +123,7 @@ def pddbtsci( ddbtrs=ddbtrs, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtrs_to_ddbtsci( @@ -139,6 +138,7 @@ def pddbtsci( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pobtars.py b/src/serinv/wrappers/pobtars.py index 98cfdf46..ee6f26bb 100644 --- a/src/serinv/wrappers/pobtars.py +++ b/src/serinv/wrappers/pobtars.py @@ -15,9 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - def allocate_pobtars( A_diagonal_blocks: ArrayLike, @@ -29,6 +26,7 @@ def allocate_pobtars( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PPOBTARX algorithms. @@ -56,6 +54,12 @@ def allocate_pobtars( pobtars : dict Dictionary containing the reduced system arrays. 
""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -106,7 +110,7 @@ def allocate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -152,6 +156,7 @@ def map_ppobtax_to_pobtars( buffer: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the the boundary blocks of the PPOBTAX algorithm to the reduced system. @@ -178,6 +183,12 @@ def map_ppobtax_to_pobtars( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -228,9 +239,16 @@ def map_ppobtas_to_pobtarss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -252,6 +270,7 @@ def aggregate_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -269,7 +288,14 @@ def aggregate_pobtars( strategy : str, optional Communication strategy to use. 
(default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -306,7 +332,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. _A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -317,11 +343,13 @@ def aggregate_pobtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -329,9 +357,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -339,9 +367,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + 
communicator.allGather( sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -349,17 +377,20 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -374,7 +405,7 @@ def aggregate_pobtars( ) comm.Allreduce(MPI.IN_PLACE, _A_arrow_tip_block_comm, op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -432,7 +463,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -449,8 +480,15 @@ def aggregate_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -474,7 +512,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. 
@@ -484,10 +522,11 @@ def aggregate_pobtarss( if strategy == "allgather": if _use_nccl(comm): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[:-a], comm=comm, op="allgather" + arr=_B_comm[:-a], comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm[:-a].data.ptr + displacement, recvbuf=_B_comm[:-a].data.ptr, count=count, @@ -495,24 +534,25 @@ def aggregate_pobtarss( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[-a:], comm=comm, op="allreduce" + arr=_B_comm[-a:], comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_comm[-a:].data.ptr, recvbuf=_B_comm[-a:].data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm[:-a], ) comm.Allreduce(MPI.IN_PLACE, _B_comm[-a:], op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." 
) @@ -547,7 +587,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -559,10 +599,18 @@ def scatter_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -598,6 +646,11 @@ def scatter_pobtars( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -607,7 +660,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -652,7 +705,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -671,9 +724,17 @@ def scatter_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() b = A_diagonal_blocks[0].shape[0] a = A_arrow_tip_block.shape[0] @@ -695,6 +756,11 @@ def 
scatter_pobtarss( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -703,7 +769,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -727,7 +793,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -747,6 +813,7 @@ def map_pobtars_to_ppobtax( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -772,6 +839,12 @@ def map_pobtars_to_ppobtax( strategy : str, optional Communication strategy to use. 
(default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -824,9 +897,16 @@ def map_pobtarss_to_ppobtas( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pobtrs.py b/src/serinv/wrappers/pobtrs.py index c716317a..391953d4 100644 --- a/src/serinv/wrappers/pobtrs.py +++ b/src/serinv/wrappers/pobtrs.py @@ -24,6 +24,7 @@ def allocate_pobtrs( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PpobtRX algorithms. @@ -47,6 +48,12 @@ def allocate_pobtrs( pobtrs : dict Dictionary containing the reduced system arrays. """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -92,7 +99,7 @@ def allocate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -126,6 +133,7 @@ def map_ppobtx_to_pobtrs( comm: MPI.Comm, buffer: ArrayLike, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: """Map the the boundary blocks of the PpobtX algorithm to the reduced system. @@ -144,6 +152,12 @@ def map_ppobtx_to_pobtrs( strategy : str, optional Communication strategy to use. 
(default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -183,9 +197,16 @@ def map_ppobts_to_pobtrss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -204,6 +225,7 @@ def aggregate_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -215,6 +237,12 @@ def aggregate_pobtrs( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -241,7 +269,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. 
_A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -250,11 +278,13 @@ def aggregate_pobtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -262,16 +292,19 @@ def aggregate_pobtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, datatype=datatype, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -281,7 +314,7 @@ def aggregate_pobtrs( _A_lower_diagonal_blocks_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." 
) @@ -322,7 +355,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -336,8 +369,15 @@ def aggregate_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +400,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. @@ -369,11 +409,12 @@ def aggregate_pobtrss( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm, comm=comm, op="allgather" + arr=_B_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm.data.ptr + displacement, recvbuf=_B_comm.data.ptr, count=count, @@ -381,12 +422,13 @@ def aggregate_pobtrss( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." 
) @@ -415,7 +457,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -427,9 +469,16 @@ def scatter_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -456,6 +505,11 @@ def scatter_pobtrs( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -465,7 +519,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -496,7 +550,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -513,8 +567,15 @@ def scatter_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -536,6 +597,11 @@ def scatter_pobtrss( if strategy == "allgather": ... 
elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -544,7 +610,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -564,7 +630,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -580,6 +646,7 @@ def map_pobtrs_to_ppobtx( _A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -597,6 +664,12 @@ def map_pobtrs_to_ppobtx( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -638,9 +711,16 @@ def map_pobtrss_to_ppobts( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ppobtaf.py b/src/serinv/wrappers/ppobtaf.py index 64946a3d..2122b88a 100644 --- a/src/serinv/wrappers/ppobtaf.py +++ b/src/serinv/wrappers/ppobtaf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. 
+import time + from mpi4py import MPI from serinv import ( @@ -20,6 +22,7 @@ def ppobtaf( A_lower_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -62,6 +65,8 @@ def ppobtaf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -133,14 +138,23 @@ def ppobtaf( buffer=buffer, strategy=strategy, comm=comm, + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_pobtars( pobtars=pobtars, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- pobtars["A_arrow_tip_block"][:] += A_arrow_tip_initial @@ -165,3 +179,5 @@ def ppobtaf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtas.py b/src/serinv/wrappers/ppobtas.py index 132a9d08..52896530 100644 --- a/src/serinv/wrappers/ppobtas.py +++ b/src/serinv/wrappers/ppobtas.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -23,6 +25,7 @@ def ppobtas( L_arrow_tip_block: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). 
@@ -138,16 +141,25 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Agregate reduced RHS + comm.Barrier() + tic = time.perf_counter() aggregate_pobtarss( A_diagonal_blocks=L_diagonal_blocks, A_arrow_tip_block=L_arrow_tip_block, pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Add the tip block of the RHS to the aggregated update _B[-a:] += B_tip_initial @@ -180,6 +192,7 @@ def ppobtas( pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -190,6 +203,7 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -213,3 +227,5 @@ def ppobtas( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtasi.py b/src/serinv/wrappers/ppobtasi.py index 12e23e14..96b67c20 100644 --- a/src/serinv/wrappers/ppobtasi.py +++ b/src/serinv/wrappers/ppobtasi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtasi @@ -20,6 +19,7 @@ def ppobtasi( L_lower_arrow_blocks: ArrayLike, L_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). 
@@ -114,6 +114,7 @@ def ppobtasi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -129,6 +130,7 @@ def ppobtasi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system diff --git a/src/serinv/wrappers/ppobtf.py b/src/serinv/wrappers/ppobtf.py index 2d680c96..f1b44956 100644 --- a/src/serinv/wrappers/ppobtf.py +++ b/src/serinv/wrappers/ppobtf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -18,6 +20,7 @@ def ppobtf( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -56,6 +59,8 @@ def ppobtf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -106,14 +111,23 @@ def ppobtf( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_pobtrs( pobtrs=pobtrs, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- if strategy == "gather-scatter": @@ -134,3 +148,5 @@ def ppobtf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobts.py b/src/serinv/wrappers/ppobts.py index 396577f1..86906883 100644 --- a/src/serinv/wrappers/ppobts.py +++ b/src/serinv/wrappers/ppobts.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. 
All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -21,6 +23,7 @@ def ppobts( L_lower_diagonal_blocks: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -119,14 +122,24 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + # Agregate reduced RHS + comm.Barrier() + tic = time.perf_counter() aggregate_pobtrss( A_diagonal_blocks=L_diagonal_blocks, pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Solve RHS FWD/BWD if strategy == "allgather": @@ -151,6 +164,7 @@ def ppobts( pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -160,6 +174,7 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -179,3 +194,5 @@ def ppobts( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtsi.py b/src/serinv/wrappers/ppobtsi.py index fde320fc..72c7354d 100644 --- a/src/serinv/wrappers/ppobtsi.py +++ b/src/serinv/wrappers/ppobtsi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtsi @@ -18,6 +17,7 @@ def ppobtsi( L_diagonal_blocks: ArrayLike, L_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). 
@@ -100,6 +100,7 @@ def ppobtsi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -111,6 +112,7 @@ def ppobtsi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system diff --git a/tests/tests_algs/permuted/test_bt/test_ddbtsc_permuted.py b/tests/tests_algs/permuted/test_bt/test_ddbtsc_permuted.py index aa3f4f12..96f66f2f 100644 --- a/tests/tests_algs/permuted/test_bt/test_ddbtsc_permuted.py +++ b/tests/tests_algs/permuted/test_bt/test_ddbtsc_permuted.py @@ -4,7 +4,7 @@ import pytest from serinv import _get_module_from_array -from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize +from ....testing_utils import bt_dense_to_arrays, dd_bt from serinv.algs import ddbtsc from serinv.utils import allocate_ddbtx_permutation_buffers @@ -49,8 +49,6 @@ def test_ddbtsc_permuted( dtype=dtype, ) - symmetrize(B) - ( B_diagonal_blocks, B_lower_diagonal_blocks, diff --git a/tests/tests_algs/permuted/test_bt/test_ddbtsci_permuted.py b/tests/tests_algs/permuted/test_bt/test_ddbtsci_permuted.py index f061145c..2f01c3bf 100644 --- a/tests/tests_algs/permuted/test_bt/test_ddbtsci_permuted.py +++ b/tests/tests_algs/permuted/test_bt/test_ddbtsci_permuted.py @@ -12,7 +12,7 @@ @pytest.mark.mpi_skip() @pytest.mark.parametrize("type_of_equation", ["AX=I", "AXA.T=B"]) -def test_ddbtsc_permuted( +def test_ddbtsci_permuted( diagonal_blocksize: int, n_diag_blocks: int, array_type: str, diff --git a/tests/tests_algs/permuted/test_bta/test_ddbtasc_permuted.py b/tests/tests_algs/permuted/test_bta/test_ddbtasc_permuted.py index 8a7e24bb..bcacaeca 100644 --- a/tests/tests_algs/permuted/test_bta/test_ddbtasc_permuted.py +++ b/tests/tests_algs/permuted/test_bta/test_ddbtasc_permuted.py @@ -4,7 +4,7 @@ import pytest from serinv import _get_module_from_array -from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize +from 
....testing_utils import bta_dense_to_arrays, dd_bta from serinv.algs import ddbtasc @@ -61,8 +61,6 @@ def test_ddbtasc_permuted( dtype=dtype, ) - symmetrize(B) - ( B_diagonal_blocks, B_lower_diagonal_blocks, diff --git a/tests/tests_algs/regular/tests_bt/test_ddbtsc.py b/tests/tests_algs/regular/tests_bt/test_ddbtsc.py index c4204b8a..2f311143 100644 --- a/tests/tests_algs/regular/tests_bt/test_ddbtsc.py +++ b/tests/tests_algs/regular/tests_bt/test_ddbtsc.py @@ -44,8 +44,6 @@ def test_ddbtsc( dtype=dtype, ) - symmetrize(B) - ( B_diagonal_blocks, B_lower_diagonal_blocks, diff --git a/tests/tests_algs/regular/tests_bta/test_ddbtasc.py b/tests/tests_algs/regular/tests_bta/test_ddbtasc.py index e51738b6..32702a6b 100644 --- a/tests/tests_algs/regular/tests_bta/test_ddbtasc.py +++ b/tests/tests_algs/regular/tests_bta/test_ddbtasc.py @@ -4,7 +4,7 @@ import pytest from serinv import _get_module_from_array -from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize +from ....testing_utils import bta_dense_to_arrays, dd_bta from serinv.algs import ddbtasc @@ -45,8 +45,6 @@ def test_ddbtasc( dtype=dtype, ) - symmetrize(B) - ( B_diagonal_blocks, B_lower_diagonal_blocks, diff --git a/tests/tests_wrappers/test_bt/test_pddbtsc.py b/tests/tests_wrappers/test_bt/test_pddbtsc.py index 5f527354..8224a80c 100644 --- a/tests/tests_wrappers/test_bt/test_pddbtsc.py +++ b/tests/tests_wrappers/test_bt/test_pddbtsc.py @@ -48,8 +48,6 @@ def test_pddbtsc( dtype=dtype, ) - symmetrize(A) - xp, _ = _get_module_from_array(A) ( diff --git a/tests/tests_wrappers/test_bt/test_pddbtsci.py b/tests/tests_wrappers/test_bt/test_pddbtsci.py index b4e0dc10..65207423 100644 --- a/tests/tests_wrappers/test_bt/test_pddbtsci.py +++ b/tests/tests_wrappers/test_bt/test_pddbtsci.py @@ -48,8 +48,6 @@ def test_pddbtsc( dtype=dtype, ) - symmetrize(A) - xp, _ = _get_module_from_array(A) (