diff --git a/src/serinv/__init__.py b/src/serinv/__init__.py index ca7f21dd..a8003816 100644 --- a/src/serinv/__init__.py +++ b/src/serinv/__init__.py @@ -25,7 +25,7 @@ # In the case of CuPy, we want to use the lowerfill version # tweaked in serinv. (More performances) - from serinv.cupyfix.cholesky_lowerfill import cholesky_lowerfill as cu_cholesky + from serinv.cupyfix.cholesky_lowerfill import cholesky as cu_cholesky # Check if cupy is actually working. This could still raise # a cudaErrorInsufficientDriver error or something. @@ -163,7 +163,7 @@ def _use_nccl(comm): return False -def _get_nccl_parameters(arr, comm, op: str): +def _get_nccl_parameters(arr, comm, rank, op: str): """Get the NCCL parameters for the given operation.""" if np.iscomplexobj(arr): factor = 2 @@ -172,8 +172,8 @@ def _get_nccl_parameters(arr, comm, op: str): if backend_flags["nccl_avail"]: if op == "allgather": - count = (arr.size // comm.size) * factor - displacement = count * comm.rank * arr.dtype.itemsize + count = (arr.size // comm.size()) * factor + displacement = count * rank * (arr.dtype.itemsize // factor) elif op == "allreduce": count = arr.size * factor displacement = 0 diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 05f58d3b..4993a8b6 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -7,6 +7,7 @@ _get_cholesky, ) +from serinv.block_primitive import trsm, gemm, syherk def pobtaf( A_diagonal_blocks: ArrayLike, @@ -114,22 +115,20 @@ def _pobtaf( # Forward block-Cholesky for i in range(0, n_diag_blocks - 1): # L_{i, i} = chol(A_{i, i}) - L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :]) - + L_diagonal_blocks[i, :, :] = cholesky(A_diagonal_blocks[i, :, :], lower=True) # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) + # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, @@ -141,30 +140,38 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + syherk( + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) - # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block[:, :] = ( - A_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + syherk( + L_lower_arrow_blocks[i, :, :], + A_arrow_tip_block[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) - L_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :]) + L_diagonal_blocks[-1, :, :] = cholesky(A_diagonal_blocks[-1, :, :], lower=True) # L_{ndb+1, ndb} = A_{ndb+1, ndb} @ L_{ndb, ndb}^{-T} L_lower_arrow_blocks[-1, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[-1, :, :], A_lower_arrow_blocks[-1, :, :].conj().T, lower=True, @@ -175,12 +182,16 @@ def _pobtaf( # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block[:, :] = ( - A_arrow_tip_block[:, :] - - L_lower_arrow_blocks[-1, :, :] @ L_lower_arrow_blocks[-1, :, :].conj().T + syherk( + L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) + # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) - L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :]) + L_arrow_tip_block[:, :] = cholesky(A_arrow_tip_block[:, :], lower=True) def _pobtaf_permuted( @@ -210,18 +221,16 @@ def _pobtaf_permuted( # Compute lower factors # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} buffer[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, lower=True, @@ -232,7 +241,7 @@ def _pobtaf_permuted( # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, @@ -244,40 +253,60 @@ def _pobtaf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + syherk( + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) - # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - L_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + syherk( + L_lower_arrow_blocks[i, :, :], + L_arrow_tip_block[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + syherk( + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - A_lower_arrow_blocks[0, :, :] - - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + buffer[i, :, :], + A_lower_arrow_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) @@ -391,13 +420,11 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -418,7 +445,7 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -446,24 +473,31 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) compute_lower_h2d_events[i % 2].record(stream=compute_stream) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) @@ -488,7 +522,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -509,9 +543,11 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] - @ L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T + syherk( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + A_arrow_tip_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) @@ -631,7 +667,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) L_diagonal_blocks_d[i % 2, :, :] = cholesky( - A_diagonal_blocks_d[i % 2, :, :] + A_diagonal_blocks_d[i % 2, :, :], lower=True ) cp_diagonal_events[i % 2].record(stream=compute_stream) @@ -654,13 +690,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -681,7 +715,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -701,7 +735,7 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, @@ -732,48 +766,67 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + # gemm instead of syherk because this somehow kept failing tests in a very weird way + # probably because both sides of the diagonal matrix are used somwhere in a relevant way + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) + # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - -L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + trans_b='C', alpha=-1.0 + ) ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=False + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_arrow_bottom_top_block_d[:, :] = ( - A_arrow_bottom_top_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_arrow_bottom_top_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) cp_arrow_events_h2d_release[i % 2].record(stream=compute_stream) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_top_block_d[:, :] = ( - A_diagonal_top_block_d[:, :] - - L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + syherk( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_diagonal_top_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=False + ) ) # --- Device 2 Host transfers --- diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bab2a911..4d0f4c8f 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -4,8 +4,10 @@ from serinv import ( ArrayLike, _get_module_from_array, + _get_module_from_str, ) +from serinv.block_primitive import trsm, gemm, syherk def pobtas( L_diagonal_blocks: ArrayLike, @@ -47,8 +49,14 @@ def pobtas( else: # Natural arrowhead if device_streaming: - raise NotImplementedError( - "Streaming is not implemented for the natural arrowhead." + _pobtas_streaming( + L_diagonal_blocks, + L_lower_diagonal_blocks, + L_lower_arrow_blocks, + L_arrow_tip_block, + B, + trans, + partial, ) else: _pobtas( @@ -80,27 +88,41 @@ def _pobtas( if trans == "N": # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True + ) ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( - L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + #B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( + # L_lower_diagonal_blocks[i] + # @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + #) + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) - B[-arrow_blocksize:] -= ( - L_lower_arrow_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[-arrow_blocksize:] = ( + gemm( + L_lower_arrow_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[-arrow_blocksize:], + alpha=-1.0, beta=1.0 + ) ) if not partial: # In the case of the partial solve, we do not solve the last block and # arrow tip block of the RHS. B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[n_diag_blocks - 1], B[ (n_diag_blocks - 1) @@ -111,24 +133,33 @@ def _pobtas( ) ) - B[-arrow_blocksize:] -= ( - L_lower_arrow_blocks[-1] - @ B[ - (n_diag_blocks - 1) - * diag_blocksize : n_diag_blocks - * diag_blocksize - ] + B[-arrow_blocksize:] = ( + gemm( + L_lower_arrow_blocks[-1], + B[ + (n_diag_blocks - 1) + * diag_blocksize : n_diag_blocks + * diag_blocksize + ], + B[-arrow_blocksize:], + alpha=-1.0, beta=1.0 + ) ) # Y_{ndb+1} = L_{ndb+1,ndb+1}^{-1} (B_{ndb+1} - \Sigma_{i=1}^{ndb} L_{ndb+1,i} Y_{i) - B[-arrow_blocksize:] = la.solve_triangular( - L_arrow_tip_block[:], B[-arrow_blocksize:], lower=True + B[-arrow_blocksize:] = ( + trsm( + L_arrow_tip_block[:], + B[-arrow_blocksize:], + lower=True + ) ) + elif trans == "T" or trans == "C": # ----- Backward substitution ----- if not partial: # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) - B[-arrow_blocksize:] = la.solve_triangular( + B[-arrow_blocksize:] = trsm( L_arrow_tip_block[:], B[-arrow_blocksize:], lower=True, @@ -137,10 +168,18 @@ def _pobtas( # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) B[-arrow_blocksize - diag_blocksize : -arrow_blocksize] = ( - la.solve_triangular( + gemm( + L_lower_arrow_blocks[-1], + B[-arrow_blocksize:], + B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[-arrow_blocksize - diag_blocksize : -arrow_blocksize] = ( + trsm( L_diagonal_blocks[-1], - B[-arrow_blocksize - diag_blocksize : -arrow_blocksize] - - L_lower_arrow_blocks[-1].conj().T @ B[-arrow_blocksize:], + B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], lower=True, trans="C", ) @@ -148,14 +187,31 @@ def _pobtas( for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize] - - L_lower_diagonal_blocks[i].conj().T - @ B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] - - L_lower_arrow_blocks[i].conj().T @ B[-arrow_blocksize:], - lower=True, - trans="C", + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_arrow_blocks[i], + B[-arrow_blocksize:], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + trans="C", + ) ) else: raise ValueError(f"Invalid transpose argument: {trans}.") @@ -179,40 +235,533 @@ def _pobtas_permuted( if trans == "N": # ----- Forward substitution ----- for i in range(1, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + ) ) # Update the next RHS block - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( - L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) # Update the first RHS block (permutation-linked) - B[:diag_blocksize] -= ( - buffer[i] @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[:diag_blocksize] = ( + gemm( + buffer[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[:diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) # Update the tip RHS block - B[-arrow_blocksize:] -= ( - L_lower_arrow_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[-arrow_blocksize:] = ( + gemm( + L_lower_arrow_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[-arrow_blocksize:], + alpha=-1.0, beta=1.0 + ) ) elif trans == "T" or trans == "C": # ----- Backward substitution ----- for i in range(n_diag_blocks - 2, 0, -1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize] - - L_lower_diagonal_blocks[i].conj().T - @ B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] - - L_lower_arrow_blocks[i].conj().T @ B[-arrow_blocksize:] - - buffer[i].conj().T @ B[:diag_blocksize], - lower=True, - trans="C", + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_arrow_blocks[i], + B[-arrow_blocksize:], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + buffer[i], + B[:diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + trans="C", + ) ) else: raise ValueError(f"Invalid transpose argument: {trans}.") + + +def _pobtas_streaming( + L_diagonal_blocks: ArrayLike, + L_lower_diagonal_blocks: ArrayLike, + L_lower_arrow_blocks: ArrayLike, + L_arrow_tip_block: ArrayLike, + B: ArrayLike, + trans: str, + partial: bool, +): + arr_module, _ = _get_module_from_array(arr=L_diagonal_blocks) + if arr_module.__name__ != "numpy": + raise TypeError( + "Host<->Device streaming only works when host-arrays are given." + ) + + cp, cu_la = _get_module_from_str(module_str="cupy") + + # Vars + diag_blocksize = L_diagonal_blocks.shape[1] + arrow_blocksize = L_lower_arrow_blocks.shape[1] + n_diag_blocks = L_diagonal_blocks.shape[0] + + # Streams + compute_stream = cp.cuda.Stream(non_blocking=True) + h2d_stream = cp.cuda.Stream(non_blocking=True) + d2h_stream = cp.cuda.Stream(non_blocking=True) + + # Device Buffers + # B Buffers + B_shape = B[-arrow_blocksize:] # block template + B_arrow_tip_d = cp.empty_like(B_shape) + + B_shape = B[0:diag_blocksize] + B_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) + + # L Buffers + L_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_arrow_blocks_d = cp.empty( + (2, *L_lower_arrow_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_arrow_tip_block_d = cp.empty_like(L_arrow_tip_block) + + if trans == "N": + # ----- Forward substitution ----- + # Delete helper variable + del B_shape + + # Events + h2d_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_lower_diagonal_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_arrow_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_B_events = [cp.cuda.Event(), cp.cuda.Event()] + + d2h_B_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_tip_events = [cp.cuda.Event(), cp.cuda.Event()] + + compute_current_B_events = [cp.cuda.Event(), cp.cuda.Event()] + compute_next_B_events = [cp.cuda.Event(), cp.cuda.Event()] + compute_arrow_B_events = [cp.cuda.Event(), cp.cuda.Event()] + + compute_partial_events = [cp.cuda.Event(), cp.cuda.Event()] + + # --- C: events + transfers --- + compute_current_B_events[1].record(stream=compute_stream) + compute_next_B_events[1].record(stream=compute_stream) + compute_arrow_B_events[1].record(stream=compute_stream) + + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) + L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) + + # --- H2D: transfers --- + B_d[0].set(arr=B[0:diag_blocksize], stream=h2d_stream) + h2d_B_events[0].record(stream=h2d_stream) + + L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) + h2d_diagonal_events[0].record(stream=h2d_stream) + + L_lower_arrow_blocks_d[0].set(arr=L_lower_arrow_blocks[0], stream=h2d_stream) + h2d_arrow_events[0].record(stream=h2d_stream) + + # --- D2H: event --- + d2h_B_events[1].record(stream=d2h_stream) + + n_diag_blocks: int = L_diagonal_blocks.shape[0] + + if n_diag_blocks > 1: + + L_lower_diagonal_blocks_d[0].set( + arr=L_lower_diagonal_blocks[0], stream=h2d_stream + ) + h2d_lower_diagonal_events[0].record(stream=h2d_stream) + + # --- Computations --- + for i in range(0, n_diag_blocks - 1): + # pass next B block + h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) + + B_d[(i + 1) % 2].set( + arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=h2d_stream, + ) + + h2d_B_events[(i + 1) % 2].record(stream=h2d_stream) + + if i + 1 < n_diag_blocks - 1: + # pass next diagonal block + h2d_stream.wait_event(compute_current_B_events[(i + 1) % 2]) + L_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_diagonal_blocks[i + 1], stream=h2d_stream + ) + + h2d_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) + + with compute_stream: + # Solve current B block + compute_stream.wait_event(h2d_diagonal_events[i % 2]) + + B_d[i % 2] = ( + trsm( + L_diagonal_blocks_d[i % 2], + B_d[i % 2], + lower=True, + ) + ) + + compute_current_B_events[i % 2].record(stream=compute_stream) + + # Pass current B block back + + if i + 1 < n_diag_blocks - 1: + # Pass next lower diagonal block + h2d_stream.wait_event(compute_next_B_events[(i + 1) % 2]) + L_lower_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_lower_diagonal_blocks[i + 1], stream=h2d_stream + ) + + h2d_lower_diagonal_events[(i + 1) % 2].record(stream=h2d_stream) + + d2h_stream.wait_event(compute_current_B_events[i % 2]) + d2h_stream.wait_event(h2d_lower_diagonal_events[(i + 1) % 2]) + + B_d[i % 2].get( + out=B[i * diag_blocksize : (i + 1) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + + d2h_B_events[i % 2].record(stream=d2h_stream) + + with compute_stream: + # Update next B block + compute_stream.wait_event(h2d_B_events[(i + 1) % 2]) + + B_d[(i + 1) % 2] = ( + gemm( + L_lower_diagonal_blocks_d[i % 2], + B_d[i % 2], + B_d[(i + 1) % 2], + alpha=-1.0, beta=1.0 + ) + ) + + compute_next_B_events[i % 2].record(stream=compute_stream) + + if i + 1 < n_diag_blocks - 1: + # Pass next lower arrow block + h2d_stream.wait_event(compute_arrow_B_events[(i + 1) % 2]) + L_lower_arrow_blocks_d[(i + 1) % 2].set( + arr=L_lower_arrow_blocks[i + 1], stream=h2d_stream + ) + + h2d_arrow_events[(i + 1) % 2].record(stream=h2d_stream) + + with compute_stream: + # Update arrow tip + compute_stream.wait_event(h2d_arrow_events[i % 2]) + + B_arrow_tip_d = ( + gemm( + L_lower_arrow_blocks_d[i % 2], + B_d[i % 2], + B_arrow_tip_d, + alpha=-1.0, beta=1.0 + ) + ) + + compute_arrow_B_events[i % 2].record(stream=compute_stream) + + # Pass arrow tip back + d2h_stream.wait_event(compute_arrow_B_events[n_diag_blocks % 2]) + + B_arrow_tip_d.get( + out=B[-arrow_blocksize:], + stream=d2h_stream, + blocking=False, + ) + + d2h_tip_events[n_diag_blocks % 2].record(stream=d2h_stream) + + if not partial: + # Pass last blocks + h2d_stream.wait_event(d2h_tip_events[n_diag_blocks % 2]) + + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_diagonal_blocks[n_diag_blocks - 1], stream=h2d_stream + ) + + h2d_diagonal_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_lower_arrow_blocks[-1], stream=h2d_stream + ) + + h2d_arrow_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + with compute_stream: + # Solve last B block + compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) + + B_d[(n_diag_blocks - 1) % 2] = ( + trsm( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, + ) + ) + + compute_partial_events[0].record(stream=compute_stream) + + # Pass last B block back + d2h_stream.wait_event(compute_partial_events[0]) + + B_d[(n_diag_blocks - 1) % 2].get( + out=B[ + (n_diag_blocks - 1) + * diag_blocksize : n_diag_blocks + * diag_blocksize + ], + stream=d2h_stream, + blocking=False, + ) + + d2h_B_events[0].record(stream=d2h_stream) + + with compute_stream: + # Solve arrow tip + compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) + + B_arrow_tip_d = ( + gemm( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + B_arrow_tip_d, + alpha=-1.0, beta=1.0 + ) + ) + B_arrow_tip_d = ( + trsm( + L_arrow_tip_block_d, + B_arrow_tip_d, + lower=True + ) + ) + + compute_partial_events[1].record(stream=compute_stream) + + d2h_stream.wait_event(compute_partial_events[1]) + + B_arrow_tip_d.get( + out=B[-arrow_blocksize:], + stream=d2h_stream, + blocking=False, + ) + + elif trans == "T" or trans == "C": + # ----- Backward substitution ----- + + # Buffers + B_previous_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) + + # Delete helper variable + del B_shape + + # Events + compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_events = [cp.cuda.Event(), cp.cuda.Event()] + + # --- H2D: transfers --- + B_arrow_tip_d.set(arr=B[-arrow_blocksize:], stream=h2d_stream) + L_arrow_tip_block_d.set(arr=L_arrow_tip_block[:], stream=h2d_stream) + B_d[(n_diag_blocks - 1) % 2].set( + arr=B[-arrow_blocksize - diag_blocksize : -arrow_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_diagonal_blocks[-1], stream=h2d_stream + ) + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_lower_arrow_blocks[-1], stream=h2d_stream + ) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + # ----- Backward substitution ----- + if not partial: + + with compute_stream: + # X_{ndb+1} = L_{ndb+1,ndb+1}^{-T} (Y_{ndb+1}) + compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) + B_arrow_tip_d = ( + trsm( + L_arrow_tip_block_d, + B_arrow_tip_d, + lower=True, + trans="C", + ) + ) + + # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) + B_previous_d[(n_diag_blocks - 1) % 2] = ( + gemm( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2], + B_arrow_tip_d, + B_d[(n_diag_blocks - 1) % 2], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B_previous_d[(n_diag_blocks - 1) % 2] = ( + trsm( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, + trans="C", + ) + ) + + compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) + + # Pass arrow tip back + d2h_stream.wait_event(compute_B_events[(n_diag_blocks - 1) % 2]) + + B_arrow_tip_d.get( + out=B[-arrow_blocksize:], + stream=d2h_stream, + blocking=False, + ) + + if n_diag_blocks > 1: + + B_d[n_diag_blocks % 2].set( + arr=B[ + -arrow_blocksize + - (2 * diag_blocksize) : -arrow_blocksize + - diag_blocksize + ], + stream=h2d_stream, + ) + L_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_diagonal_blocks[-2], stream=h2d_stream + ) + L_lower_arrow_blocks_d[n_diag_blocks % 2].set( + arr=L_lower_arrow_blocks[-2], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_lower_diagonal_blocks[-1], stream=h2d_stream + ) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + for i in range(n_diag_blocks - 2, -1, -1): + + if i > 0: + # Pass new blocks + h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_d[(i - 1) % 2].set( + arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_diagonal_blocks[i - 1], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream + ) + L_lower_arrow_blocks_d[(i - 1) % 2].set( + arr=L_lower_arrow_blocks[i - 1], stream=h2d_stream + ) + + h2d_events[i % 2].record(stream=h2d_stream) + + with compute_stream: + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + compute_stream.wait_event(h2d_events[(i - 1) % 2]) + compute_stream.wait_event(d2h_events[(i - 1) % 2]) + + B_d[i % 2] = ( + gemm( + L_lower_diagonal_blocks_d[i % 2], + B_previous_d[(i - 1) % 2], + B_d[i % 2], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B_d[i % 2] = ( + gemm( + L_lower_arrow_blocks_d[i % 2], + B_arrow_tip_d, + B_d[i % 2], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B_previous_d[i % 2] = ( + trsm( + L_diagonal_blocks_d[i % 2], + B_d[i % 2], + lower=True, + trans="C", + ) + ) + + compute_B_events[i % 2].record(compute_stream) + + # Pass previous B block back + d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_previous_d[(i - 1) % 2].get( + out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + d2h_events[i % 2].record(stream=d2h_stream) + + # Pass last B block back + d2h_stream.wait_event(compute_B_events[0]) + + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + + else: + raise ValueError(f"Invalid transpose argument: {trans}.") + + cp.cuda.Device().synchronize() diff --git a/src/serinv/algs/pobtasi.py b/src/serinv/algs/pobtasi.py index 7e572d19..4e6c2ab2 100644 --- a/src/serinv/algs/pobtasi.py +++ b/src/serinv/algs/pobtasi.py @@ -6,6 +6,8 @@ _get_module_from_str, ) +from serinv.block_primitive import trsm + def pobtasi( L_diagonal_blocks: ArrayLike, @@ -112,7 +114,7 @@ def _pobtasi( Identity = xp.eye(L_diagonal_blocks.shape[1]) if invert_last_block: - L_last_blk_inv = la.solve_triangular( + L_last_blk_inv = trsm( L_arrow_tip_block[:, :], xp.eye(L_arrow_tip_block.shape[0]), lower=True ) @@ -121,7 +123,7 @@ def _pobtasi( # Backward block-selected inversion L_lower_arrow_blocks_i[:, :] = L_lower_arrow_blocks[-1, :, :] - L_blk_inv = la.solve_triangular( + L_blk_inv = trsm( L_diagonal_blocks[-1, :, :], Identity, lower=True, @@ -142,7 +144,7 @@ def _pobtasi( L_lower_diagonal_blocks_i[:, :] = L_lower_diagonal_blocks[i, :, :] L_lower_arrow_blocks_i[:, :] = L_lower_arrow_blocks[i, :, :] - L_blk_inv = la.solve_triangular( + L_blk_inv = trsm( L_diagonal_blocks[i, :, :], Identity, lower=True, @@ -201,7 +203,7 @@ def _pobtasi_permuted( L_lower_arrow_blocks_temp[:, :] = L_lower_arrow_blocks[i, :, :] buffer_temp[:, :] = buffer[i, :, :] - L_inv_temp[:, :] = la.solve_triangular( + L_inv_temp[:, :] = trsm( L_diagonal_blocks[i, :, :], xp.eye(diag_blocksize), lower=True, @@ -321,7 +323,7 @@ def _pobtasi_streaming( with compute_stream: if invert_last_block: # X_{ndb+1, ndb+1} = L_{ndb+1, ndb}^{-T} L_{ndb+1, ndb}^{-1} - L_last_blk_inv_d = cu_la.solve_triangular( + L_last_blk_inv_d = trsm( L_arrow_tip_block_d[:, :], cp.eye(L_arrow_tip_block.shape[0]), lower=True, @@ -356,7 +358,7 @@ def _pobtasi_streaming( compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) if invert_last_block: # X_{ndb+1, ndb} = -X_{ndb+1, ndb+1} L_{ndb+1, ndb} L_{ndb, ndb}^{-1} - L_blk_inv_d = cu_la.solve_triangular( + L_blk_inv_d = trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], Identity, lower=True, @@ -434,7 +436,7 @@ def _pobtasi_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_blk_inv_d = cu_la.solve_triangular( + L_blk_inv_d = trsm( L_diagonal_blocks_d[i % 2, :, :], Identity, lower=True, @@ -632,7 +634,7 @@ def _pobtasi_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_inv_temp_d[:, :] = cu_la.solve_triangular( + L_inv_temp_d[:, :] = trsm( L_diagonal_blocks_d[i % 2, :, :], cp.eye(diag_blocksize), lower=True, diff --git a/src/serinv/algs/pobtf.py b/src/serinv/algs/pobtf.py index ed70433b..becc3009 100644 --- a/src/serinv/algs/pobtf.py +++ b/src/serinv/algs/pobtf.py @@ -7,6 +7,7 @@ _get_cholesky, ) +from serinv.block_primitive import trsm, gemm, syherk def pobtf( A_diagonal_blocks: ArrayLike, @@ -100,21 +101,21 @@ def _pobtf( # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + syherk( + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) if factorize_last_block: @@ -145,18 +146,16 @@ def _pobtf_permuted( # Compute lower factors # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} buffer[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, lower=True, @@ -168,20 +167,30 @@ def _pobtf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + syherk( + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + syherk( + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) ) @@ -276,13 +285,11 @@ def _pobtf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -304,9 +311,11 @@ def _pobtf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + syherk( + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=True + ) ) compute_lower_h2d_events[i % 2].record(stream=compute_stream) @@ -436,13 +445,11 @@ def _pobtf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -456,7 +463,7 @@ def _pobtf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, @@ -487,24 +494,34 @@ def _pobtf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + # gemm instead of syherk because this somehow kept failing tests in a very weird way + # probably because both sides of the diagonal matrix are used somwhere in a relevant way + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - -L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + trans_b='C', alpha=-1.0 + ) ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_top_block_d[:, :] = ( - A_diagonal_top_block_d[:, :] - - L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + syherk( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_diagonal_top_block_d[:, :], + alpha=-1.0, beta=1.0, lower=True, cu_chol=False + ) ) # --- Device 2 Host transfers --- diff --git a/src/serinv/algs/pobts.py b/src/serinv/algs/pobts.py index 99aebc82..684034d7 100644 --- a/src/serinv/algs/pobts.py +++ b/src/serinv/algs/pobts.py @@ -4,8 +4,10 @@ from serinv import ( ArrayLike, _get_module_from_array, + _get_module_from_str, ) +from serinv.block_primitive import trsm, gemm def pobts( L_diagonal_blocks: ArrayLike, @@ -41,8 +43,11 @@ def pobts( else: # Natural arrowhead if device_streaming: - raise NotImplementedError( - "Streaming is not implemented for the natural arrowhead." + _pobts_streaming( + L_diagonal_blocks, + L_lower_diagonal_blocks, + B, + trans, ) else: _pobts( @@ -70,20 +75,26 @@ def _pobts( # ----- Forward substitution ----- for i in range(0, n_diag_blocks - 1): # Y_{i} = L_{i,i}^{-1} (B_{i} - L_{i,i-1} Y_{i-1}) - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True + ) ) - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( - L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) if not partial: B[(n_diag_blocks - 1) * diag_blocksize : n_diag_blocks * diag_blocksize] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[n_diag_blocks - 1], B[ (n_diag_blocks - 1) @@ -97,22 +108,33 @@ def _pobts( # ----- Backward substitution ----- if not partial: # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) - B[-diag_blocksize:] = la.solve_triangular( - L_diagonal_blocks[-1], - B[-diag_blocksize:], - lower=True, - trans="C", + B[-diag_blocksize:] = ( + trsm( + L_diagonal_blocks[-1], + B[-diag_blocksize:], + lower=True, + trans="C", + ) ) for i in range(n_diag_blocks - 2, -1, -1): # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize] - - L_lower_diagonal_blocks[i].conj().T - @ B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], - lower=True, - trans="C", + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + trans="C", + ) ) else: raise ValueError(f"Invalid transpose argument: {trans}.") @@ -133,33 +155,304 @@ def _pobts_permuted( if trans == "N": # ----- Forward substitution ----- for i in range(1, n_diag_blocks - 1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + ) ) # Update the next RHS block - B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( - L_lower_diagonal_blocks[i] - @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) # Update the first RHS block (permutation-linked) - B[:diag_blocksize] -= ( - buffer[i] @ B[i * diag_blocksize : (i + 1) * diag_blocksize] + B[:diag_blocksize] = ( + gemm( + buffer[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + B[:diag_blocksize], + alpha=-1.0, beta=1.0 + ) ) + elif trans == "T" or trans == "C": # ----- Backward substitution ----- for i in range(n_diag_blocks - 2, 0, -1): - B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( - L_diagonal_blocks[i], - B[i * diag_blocksize : (i + 1) * diag_blocksize] - - L_lower_diagonal_blocks[i].conj().T - @ B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] - - buffer[i].conj().T @ B[:diag_blocksize], - lower=True, - trans="C", + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + L_lower_diagonal_blocks[i], + B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + gemm( + buffer[i], + B[:diag_blocksize], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B[i * diag_blocksize : (i + 1) * diag_blocksize] = ( + trsm( + L_diagonal_blocks[i], + B[i * diag_blocksize : (i + 1) * diag_blocksize], + lower=True, + trans="C", + ) + ) + else: + raise ValueError(f"Invalid transpose argument: {trans}.") + + +def _pobts_streaming( + L_diagonal_blocks: ArrayLike, + L_lower_diagonal_blocks: ArrayLike, + B: ArrayLike, + trans: str, +): + arr_module, _ = _get_module_from_array(arr=L_diagonal_blocks) + if arr_module.__name__ != "numpy": + raise TypeError( + "Host<->Device streaming only works when host-arrays are given." + ) + + cp, cu_la = _get_module_from_str(module_str="cupy") + + # Vars + diag_blocksize = L_diagonal_blocks.shape[1] + n_diag_blocks = L_diagonal_blocks.shape[0] + + # Streams + compute_stream = cp.cuda.Stream(non_blocking=True) + h2d_stream = cp.cuda.Stream(non_blocking=True) + d2h_stream = cp.cuda.Stream(non_blocking=True) + + # Device Buffers + # B Buffers + B_shape = B[0:diag_blocksize] + B_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) + B_previous_d = cp.empty((2, *B_shape.shape), dtype=B_shape.dtype) + del B_shape + + # L Buffers + L_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + L_lower_diagonal_blocks_d = cp.empty( + (2, *L_diagonal_blocks.shape[1:]), dtype=L_diagonal_blocks.dtype + ) + + # Events + compute_B_events = [cp.cuda.Event(), cp.cuda.Event()] + h2d_events = [cp.cuda.Event(), cp.cuda.Event()] + d2h_events = [cp.cuda.Event(), cp.cuda.Event()] + + if trans == "N": + # ----- Forward substitution ----- + + # --- H2D: transfers --- + B_d[0].set(arr=B[:diag_blocksize], stream=h2d_stream) + L_diagonal_blocks_d[0].set(arr=L_diagonal_blocks[0], stream=h2d_stream) + + h2d_events[1].record(stream=h2d_stream) + + if n_diag_blocks > 1: + B_d[1].set(arr=B[diag_blocksize : (2 * diag_blocksize)], stream=h2d_stream) + L_diagonal_blocks_d[1].set(arr=L_diagonal_blocks[1], stream=h2d_stream) + L_lower_diagonal_blocks_d[1].set( + arr=L_lower_diagonal_blocks[0], stream=h2d_stream + ) + + h2d_events[0].record(stream=h2d_stream) + + with compute_stream: + # Solve first B block + compute_stream.wait_event(h2d_events[1]) + + B_previous_d[0] = ( + trsm( + L_diagonal_blocks_d[0], + B_d[0], + lower=True, + ) + ) + + compute_B_events[0].record(stream=compute_stream) + + for i in range(1, n_diag_blocks): + + if i + 1 < n_diag_blocks: + # Pass next blocks + h2d_stream.wait_event(compute_B_events[(i + 1) % 2]) + + B_d[(i + 1) % 2].set( + arr=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_diagonal_blocks[i + 1], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[(i + 1) % 2].set( + arr=L_lower_diagonal_blocks[i], stream=h2d_stream + ) + + h2d_events[i % 2].record(stream=h2d_stream) + + with compute_stream: + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + compute_stream.wait_event(h2d_events[(i + 1) % 2]) + compute_stream.wait_event(d2h_events[(i + 1) % 2]) + + B_d[i % 2] = ( + gemm( + L_lower_diagonal_blocks_d[i % 2], + B_previous_d[(i + 1) % 2], + B_d[i % 2], + alpha=-1.0, beta=1.0 + ) + ) + + B_previous_d[i % 2] = ( + trsm( + L_diagonal_blocks_d[i % 2], + B_d[i % 2], + lower=True, + ) + ) + + compute_B_events[i % 2].record(compute_stream) + + # Pass previous B block back + d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_previous_d[(i + 1) % 2].get( + out=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + + d2h_events[i % 2].record(stream=d2h_stream) + + # Pass last B block back + d2h_stream.wait_event(compute_B_events[(n_diag_blocks + 1) % 2]) + + B_previous_d[(n_diag_blocks + 1) % 2].get( + out=B[-diag_blocksize:], stream=d2h_stream, blocking=False + ) + + elif trans == "T" or trans == "C": + # ----- Backward substitution ----- + + # --- H2D: transfers --- + B_d[(n_diag_blocks - 1) % 2].set(arr=B[-diag_blocksize:], stream=h2d_stream) + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2].set( + arr=L_diagonal_blocks[-1], stream=h2d_stream + ) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + if n_diag_blocks > 1: + + B_d[n_diag_blocks % 2].set( + arr=B[-(2 * diag_blocksize) : -diag_blocksize], stream=h2d_stream + ) + L_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_diagonal_blocks[-2], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[n_diag_blocks % 2].set( + arr=L_lower_diagonal_blocks[-1], stream=h2d_stream + ) + + h2d_events[(n_diag_blocks - 1) % 2].record(stream=h2d_stream) + + with compute_stream: + # X_{ndb} = L_{ndb,ndb}^{-T} (Y_{ndb} - L_{ndb+1,ndb}^{T} X_{ndb+1}) + compute_stream.wait_event(h2d_events[(n_diag_blocks - 1) % 2]) + + B_previous_d[(n_diag_blocks - 1) % 2] = ( + trsm( + L_diagonal_blocks_d[(n_diag_blocks - 1) % 2], + B_d[(n_diag_blocks - 1) % 2], + lower=True, + trans="C", + ) ) + + compute_B_events[(n_diag_blocks - 1) % 2].record(stream=compute_stream) + + for i in range(n_diag_blocks - 2, -1, -1): + + if i > 0: + # pass next blocks + h2d_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_d[(i - 1) % 2].set( + arr=B[(i - 1) * diag_blocksize : i * diag_blocksize], + stream=h2d_stream, + ) + L_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_diagonal_blocks[i - 1], stream=h2d_stream + ) + L_lower_diagonal_blocks_d[(i - 1) % 2].set( + arr=L_lower_diagonal_blocks[i - 1], stream=h2d_stream + ) + + h2d_events[i % 2].record(stream=h2d_stream) + + with compute_stream: + # X_{i} = L_{i,i}^{-T} (Y_{i} - L_{i+1,i}^{T} X_{i+1}) - L_{ndb+1,i}^T X_{ndb+1} + compute_stream.wait_event(h2d_events[(i - 1) % 2]) + compute_stream.wait_event(d2h_events[(i - 1) % 2]) + + B_d[i % 2] = ( + gemm( + L_lower_diagonal_blocks_d[i % 2], + B_previous_d[(i - 1) % 2], + B_d[i % 2], + trans_a='C', alpha=-1.0, beta=1.0 + ) + ) + + B_previous_d[i % 2] = ( + trsm( + L_diagonal_blocks_d[i % 2], + B_d[i % 2], + lower=True, + trans="C", + ) + ) + + compute_B_events[i % 2].record(compute_stream) + + # Pass previous B block back + d2h_stream.wait_event(compute_B_events[(i - 1) % 2]) + + B_previous_d[(i - 1) % 2].get( + out=B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize], + stream=d2h_stream, + blocking=False, + ) + + d2h_events[i % 2].record(stream=d2h_stream) + + # Pass last B block back + d2h_stream.wait_event(compute_B_events[0]) + + B_previous_d[0].get(out=B[:diag_blocksize], stream=d2h_stream, blocking=False) + else: raise ValueError(f"Invalid transpose argument: {trans}.") + + cp.cuda.Device().synchronize() diff --git a/src/serinv/algs/pobtsi.py b/src/serinv/algs/pobtsi.py index 0d2d1e6a..ec7294f3 100644 --- a/src/serinv/algs/pobtsi.py +++ b/src/serinv/algs/pobtsi.py @@ -6,6 +6,7 @@ _get_module_from_str, ) +from serinv.block_primitive import trsm def pobtsi( L_diagonal_blocks: ArrayLike, @@ -92,7 +93,7 @@ def _pobtsi( Identity = xp.eye(L_diagonal_blocks.shape[1]) if invert_last_block: - L_blk_inv = la.solve_triangular( + L_blk_inv = trsm( L_diagonal_blocks[-1, :, :], Identity, lower=True, @@ -104,7 +105,7 @@ def _pobtsi( for i in range(n_diag_blocks - 2, -1, -1): L_lower_diagonal_blocks_i[:, :] = L_lower_diagonal_blocks[i, :, :] - L_blk_inv = la.solve_triangular( + L_blk_inv = trsm( L_diagonal_blocks[i, :, :], Identity, lower=True, @@ -148,7 +149,7 @@ def _pobtsi_permuted( L_lower_diagonal_blocks_temp[:, :] = L_lower_diagonal_blocks[i, :, :] buffer_temp[:, :] = buffer[i, :, :] - L_inv_temp[:, :] = la.solve_triangular( + L_inv_temp[:, :] = trsm( L_diagonal_blocks[i, :, :], xp.eye(diag_blocksize), lower=True, @@ -245,7 +246,7 @@ def _pobtsi_streaming( compute_stream.wait_event(h2d_diagonal_events[(n_diag_blocks - 1) % 2]) if invert_last_block: # X_{ndb+1, ndb} = -X_{ndb+1, ndb+1} L_{ndb+1, ndb} L_{ndb, ndb}^{-1} - L_blk_inv_d = cu_la.solve_triangular( + L_blk_inv_d = trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], Identity, lower=True, @@ -289,7 +290,7 @@ def _pobtsi_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_blk_inv_d = cu_la.solve_triangular( + L_blk_inv_d = trsm( L_diagonal_blocks_d[i % 2, :, :], Identity, lower=True, @@ -435,7 +436,7 @@ def _pobtsi_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_diagonal_events[i % 2]) - L_inv_temp_d[:, :] = cu_la.solve_triangular( + L_inv_temp_d[:, :] = trsm( L_diagonal_blocks_d[i % 2, :, :], cp.eye(diag_blocksize), lower=True, diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py new file mode 100644 index 00000000..687c9c8f --- /dev/null +++ b/src/serinv/block_primitive/__init__.py @@ -0,0 +1,11 @@ +# Copyright 2023-2025 ETH Zurich. All rights reserved. + +from serinv.block_primitive.gemm import gemm +from serinv.block_primitive.trsm import trsm +from serinv.block_primitive.syherk import syherk + +__all__ = [ + "gemm", + "trsm", + "syherk" +] \ No newline at end of file diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py new file mode 100644 index 00000000..c474c72a --- /dev/null +++ b/src/serinv/block_primitive/gemm.py @@ -0,0 +1,239 @@ +# Copyright 2023-2025 ETH Zurich. All rights reserved. +# Forked and modified from cupy.cublas.gemm: https://github.com/cupy/cupy/blob/3a2c950d64ee707096bc7ca1bf0b953a08206384/cupy/cublas.py#L689 +# and scipy.linal.solve_triangular: https://github.com/scipy/scipy/blob/v1.15.3/scipy/linalg/_basic.py#L411 + +from serinv import _get_module_from_array + +import numpy as np + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied +from scipy.linalg._decomp import _asarray_validated + +try: + import cupy as cp + from cupy_backends.cuda.libs import cublas + from cupy import _core + from cupy.cuda import device +except (ImportError, ImportWarning, ModuleNotFoundError): + pass + +def gemm (a, b, c=None, alpha=1.0, beta=0.0, trans_a ='N', trans_b ='N'): + """Wrapper to call GeMM for host or device""" + xp, la = _get_module_from_array(a) + + if xp == np: + return matmul_gemm_host(a, b, alpha, beta, c, trans_a, trans_b) + elif xp == cp: + return matmul_gemm_device(trans_a, trans_b, a, b, c, alpha, beta) + else: + ModuleNotFoundError("Unknown Module") + + +def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0, overwrite_c=0, check_finite=False): + """Computes out = alpha * op(a) @ op(b) + beta * out + + op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', + op(a) = a.T.conj() if transa is 'C'. + op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', + op(b) = b.T.conj() if transb is 'C'. + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + b1 = _asarray_validated(b, check_finite=check_finite) + if c is None: + c1 = None + else: + c1 = _asarray_validated(c, check_finite=check_finite) + + transa = True + transb = True + if trans_a == 'N': + transa = False + if trans_b == 'N': + transb = False + + if not transa and not transb: + if a1.shape[1] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (1,0)') + + elif transa and not transb: + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (0,0)') + + elif not transa and transb: + if a1.shape[1] != b1.shape[1]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (1,1)') + + else: + if a1.shape[0] != b1.shape[1]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible (0,1)') + + if beta != 0 and c1 is None: + raise ValueError('expected C matrix') + + # accommodate empty arrays + if b1.size == 0: + dt_nonempty = matmul_gemm_host( + np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) + ).dtype + return np.empty_like(b1, dtype=dt_nonempty) + + if c1 is not None: + overwrite_c = overwrite_c or _datacopied(c1, c) + + x = _matmul_gemm(a1, b1, alpha, beta, c1, trans_a, trans_b, overwrite_c) + return x + + +# gemm without the input validation +def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, overwrite_c=0): + + trans_a = {'N': 0, 'T': 1, 'C': 2}.get(trans_a, trans_a) + trans_b = {'N': 0, 'T': 1, 'C': 2}.get(trans_b, trans_b) + gemm, = get_blas_funcs(('gemm',), (a1, b1)) + + if beta == 0: + out = gemm(alpha, a1, b1, beta=beta, trans_a=trans_a, trans_b=trans_b, overwrite_c=overwrite_c) + else: + out = gemm(alpha, a1, b1, beta, c1, trans_a, trans_b, overwrite_c) + + + return out + + +# Util functions for cupy gemm +def _trans_to_cublas_op(trans): + if trans == 'N' or trans == cublas.CUBLAS_OP_N: + trans = cublas.CUBLAS_OP_N + elif trans == 'T' or trans == cublas.CUBLAS_OP_T: + trans = cublas.CUBLAS_OP_T + elif trans == 'C' or trans == cublas.CUBLAS_OP_C: + trans = cublas.CUBLAS_OP_C + else: + raise TypeError('invalid trans (actual: {})'.format(trans)) + return trans + +def _decide_ld_and_trans(a, trans): + ld = None + if trans in (cublas.CUBLAS_OP_N, cublas.CUBLAS_OP_T): + if a._f_contiguous: + ld = a.shape[0] + elif a._c_contiguous: + ld = a.shape[1] + trans = 1 - trans + return ld, trans + + +def _change_order_if_necessary(a, lda): + if lda is None: + lda = a.shape[0] + if not a._f_contiguous: + a = a.copy(order='F') + return a, lda + +def _get_scalar_ptr(a, dtype): + if isinstance(a, cp.ndarray): + if a.dtype != dtype: + a = cp.array(a, dtype=dtype) + a_ptr = a.data.ptr + else: + if not (isinstance(a, np.ndarray) and a.dtype == dtype): + a = np.array(a, dtype=dtype) + a_ptr = a.ctypes.data + return a, a_ptr +# Util functions for cupy gemm end + + +def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0): + """Computes out = alpha * op(a) @ op(b) + beta * out + + op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', + op(a) = a.T.conj() if transa is 'C'. + op(b) = b if transb is 'N', op(b) = b.T if transb is 'T', + op(b) = b.T.conj() if transb is 'C'. + """ + assert a.ndim == b.ndim == 2 + assert a.dtype == b.dtype + dtype = a.dtype.char + if dtype == 'f': + func = cublas.sgemm + elif dtype == 'd': + func = cublas.dgemm + elif dtype == 'F': + func = cublas.cgemm + elif dtype == 'D': + func = cublas.zgemm + else: + raise TypeError('invalid dtype') + + + transa = _trans_to_cublas_op(transa) + transb = _trans_to_cublas_op(transb) + if transa == cublas.CUBLAS_OP_N: + m, k = a.shape + else: + k, m = a.shape + if transb == cublas.CUBLAS_OP_N: + n = b.shape[1] + assert b.shape[0] == k + else: + n = b.shape[0] + assert b.shape[1] == k + if out is None: + out = cp.empty((m, n), dtype=dtype, order='F') + beta = 0.0 + else: + assert out.ndim == 2 + assert out.shape == (m, n) + assert out.dtype == dtype + + alpha, alpha_ptr = _get_scalar_ptr(alpha, a.dtype) + beta, beta_ptr = _get_scalar_ptr(beta, a.dtype) + handle = device.get_cublas_handle() + orig_mode = cublas.getPointerMode(handle) + if isinstance(alpha, cp.ndarray) or isinstance(beta, cp.ndarray): + if not isinstance(alpha, cp.ndarray): + alpha = cp.array(alpha) + alpha_ptr = alpha.data.ptr + if not isinstance(beta, cp.ndarray): + beta = cp.array(beta) + beta_ptr = beta.data.ptr + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_DEVICE) + else: + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_HOST) + + lda, transa = _decide_ld_and_trans(a, transa) + ldb, transb = _decide_ld_and_trans(b, transb) + if not (lda is None or ldb is None): + if out._f_contiguous: + try: + func(handle, transa, transb, m, n, k, alpha_ptr, + a.data.ptr, lda, b.data.ptr, ldb, beta_ptr, out.data.ptr, + m) + finally: + cublas.setPointerMode(handle, orig_mode) + return out + elif out._c_contiguous: + # Computes out.T = alpha * b.T @ a.T + beta * out.T + try: + func(handle, 1 - transb, 1 - transa, n, m, k, alpha_ptr, + b.data.ptr, ldb, a.data.ptr, lda, beta_ptr, out.data.ptr, + n) + finally: + cublas.setPointerMode(handle, orig_mode) + return out + + a, lda = _change_order_if_necessary(a, lda) + b, ldb = _change_order_if_necessary(b, ldb) + c = out + if not out._f_contiguous: + c = out.copy(order='F') + try: + func(handle, transa, transb, m, n, k, alpha_ptr, a.data.ptr, lda, + b.data.ptr, ldb, beta_ptr, c.data.ptr, m) + finally: + cublas.setPointerMode(handle, orig_mode) + if not out._f_contiguous: + _core.elementwise_copy(c, out) + return out \ No newline at end of file diff --git a/src/serinv/block_primitive/syherk.py b/src/serinv/block_primitive/syherk.py new file mode 100644 index 00000000..75367b3c --- /dev/null +++ b/src/serinv/block_primitive/syherk.py @@ -0,0 +1,204 @@ +# Copyright 2023-2025 ETH Zurich. All rights reserved. +# Forked and modified from cupy.cublas.syrk: https://github.com/cupy/cupy/blob/3a2c950d64ee707096bc7ca1bf0b953a08206384/cupy/cublas.py#L930 +# and scipy.linal.solve_triangular: https://github.com/scipy/scipy/blob/v1.15.3/scipy/linalg/_basic.py#L411 + +from serinv import _get_module_from_array + +from serinv.block_primitive import gemm + +import numpy as np + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied +from scipy.linalg._decomp import _asarray_validated + +try: + import cupy as cp + from cupy_backends.cuda.libs import cublas + from cupy.cuda import device +except (ImportError, ImportWarning, ModuleNotFoundError): + pass + +def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower=False, cu_chol=False): + """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device + + For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept + plus the side parameter which can either be 0 or 1 for left or right hand side + """ + xp, la = _get_module_from_array(a) + if xp == np: + return matmul_syherk_host(a, c, alpha, beta, trans, lower) + elif xp == cp: + return matmul_syherk_device(a, trans, c, alpha, beta, lower, cu_chol) + else: + ModuleNotFoundError("Unknown Module") + +def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, + overwrite_c=False, check_finite=True,): + """Computes out = alpha * op(a) @ op(a)^T + beta * b + + op(a) = a if transa is 'N', op(a) = a.T if transa is 'T', + op(a) = a.T.conj() if transa is 'C'. + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + if c is None: + c1 = None + else: + c1 = _asarray_validated(c, check_finite=check_finite) + + overwrite_c = overwrite_c or _datacopied(c1, c) + + x = _syherk(a1, c1, alpha, beta, trans, lower, overwrite_c) + return x + + +# syherk without the input validation +def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False, + overwrite_c=False): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + + if np.iscomplexobj(a1): + syherk = get_blas_funcs(('herk'), (a1, a1)) + else: + syherk = get_blas_funcs(('syrk'), (a1, a1)) + + out = syherk(alpha, a1, beta, c1, trans, lower, overwrite_c) + + return out + + + +# Util functions for cupy gemm +def _trans_to_cublas_op(trans): + if trans == 'N' or trans == cublas.CUBLAS_OP_N: + trans = cublas.CUBLAS_OP_N + elif trans == 'T' or trans == cublas.CUBLAS_OP_T: + trans = cublas.CUBLAS_OP_T + elif trans == 'C' or trans == cublas.CUBLAS_OP_C: + trans = cublas.CUBLAS_OP_C + else: + raise TypeError('invalid trans (actual: {})'.format(trans)) + return trans + +def _decide_ld_and_trans(a, trans): + ld = None + if trans in (cublas.CUBLAS_OP_N, cublas.CUBLAS_OP_T): + if a._f_contiguous: + ld = a.shape[0] + elif a._c_contiguous: + ld = a.shape[1] + trans = 1 - trans + return ld, trans + +def _get_scalar_ptr(a, dtype): + if isinstance(a, cp.ndarray): + if a.dtype != dtype: + a = cp.array(a, dtype=dtype) + a_ptr = a.data.ptr + else: + if not (isinstance(a, np.ndarray) and a.dtype == dtype): + a = np.array(a, dtype=dtype) + a_ptr = a.ctypes.data + return a, a_ptr +# Util functions for cupy gemm end + +def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=False, cu_chol=False): + """Computes out := alpha*op1(a)*op2(a) + beta*out + + op1(a) = a if trans is 'N', op2(a) = a.T if transa is 'N' + op1(a) = a.T if trans is 'T', op2(a) = a if transa is 'T' + lower specifies whether the upper or lower triangular + part of the array out is to be referenced + """ + assert a.ndim == 2 + dtype = a.dtype.char + if dtype == 'f': + func = cublas.ssyrk + elif dtype == 'd': + func = cublas.dsyrk + elif dtype == 'F': + try: + func = cublas.cherk + except(AttributeError): + out = gemm(a, a, out, trans_b='C', alpha=alpha, beta=beta) + return out + elif dtype == 'D': + try: + func = cublas.zherk + except(AttributeError): + out = gemm(a, a, out, trans_b='C', alpha=alpha, beta=beta) + return out + else: + raise TypeError('invalid dtype') + + # If this is run in combination with cholesky, it will be necessary to flip lower + if cu_chol: + lower = not lower + + trans = _trans_to_cublas_op(trans) + if trans == cublas.CUBLAS_OP_N: + n, k = a.shape + else: + k, n = a.shape + if out is None: + out = cp.zeros((n, n), dtype=dtype, order='F') + beta = 0.0 + else: + assert out.ndim == 2 + assert out.shape == (n, n) + assert out.dtype == dtype + + if lower: + uplo = cublas.CUBLAS_FILL_MODE_LOWER + else: + uplo = cublas.CUBLAS_FILL_MODE_UPPER + + alpha, alpha_ptr = _get_scalar_ptr(alpha, a.dtype) + beta, beta_ptr = _get_scalar_ptr(beta, a.dtype) + handle = device.get_cublas_handle() + orig_mode = cublas.getPointerMode(handle) + if isinstance(alpha, cp.ndarray) or isinstance(beta, cp.ndarray): + if not isinstance(alpha, cp.ndarray): + alpha = cp.array(alpha) + alpha_ptr = alpha.data.ptr + if not isinstance(beta, cp.ndarray): + beta = cp.array(beta) + beta_ptr = beta.data.ptr + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_DEVICE) + else: + cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_HOST) + + lda, trans = _decide_ld_and_trans(a, trans) + ldo, _ = _decide_ld_and_trans(out, trans) + + if out._c_contiguous: + if not a._c_contiguous: + a = a.copy(order='C') + trans = 1 - trans + lda = a.shape[1] + try: + func(handle, 1 - uplo, trans, n, k, + alpha_ptr, a.data.ptr, lda, + beta_ptr, out.data.ptr, ldo) + finally: + cublas.setPointerMode(handle, orig_mode) + + else: + if not a._f_contiguous: + a = a.copy(order='F') + lda = a.shape[0] + trans = 1 - trans + c = out + if not out._f_contiguous: + c = out.copy(order='F') + try: + func(handle, uplo, trans, n, k, + alpha_ptr, a.data.ptr, lda, + beta_ptr, out.data.ptr, ldo) + finally: + cublas.setPointerMode(handle, orig_mode) + if not out._f_contiguous: + out[...] = c + return out \ No newline at end of file diff --git a/src/serinv/block_primitive/trsm.py b/src/serinv/block_primitive/trsm.py new file mode 100644 index 00000000..2e660e39 --- /dev/null +++ b/src/serinv/block_primitive/trsm.py @@ -0,0 +1,323 @@ +# Copyright 2023-2025 ETH Zurich. All rights reserved. +# Forked and modified from cupyx.linalg.solve_triangular: https://github.com/cupy/cupy/blob/3a2c950d64ee707096bc7ca1bf0b953a08206384/cupyx/scipy/linalg/_solve_triangular.py#L12 +# and scipy.linal.solve_triangular: https://github.com/scipy/scipy/blob/v1.15.3/scipy/linalg/_basic.py#L411 + +import numpy as np + +from serinv import _get_module_from_array + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied +from scipy.linalg._decomp import _asarray_validated + +try: + import cupy as cp + from cupy.cuda import cublas + from cupy.cuda import device + from cupy.linalg import _util +except (ImportError, ImportWarning, ModuleNotFoundError): + pass + +def trsm(a, b, trans=0, lower = False, unit_diagonal=False, + overwrite_b=True, check_finite=False, side=0): + """Wrapper for the trsm function to call depending on wheter the solve happens on the host or the device + + For Compatibility this function accepts exactly the same parameters as what the scipy and cupy implementations accept + plus the side parameter which can either be 0 or 1 for left or right hand side + """ + xp, la = _get_module_from_array(a) + if xp == np: + return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) + elif xp == cp: + return solve_triangular_device(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side) + else: + ModuleNotFoundError("Unknown Module") + + + +def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=False, side=0): + """Solve the equation a x = b for x, assuming a is a triangular matrix. + + Args: + a (cupy.ndarray): The matrix with dimension ``(M, M)``. + b (cupy.ndarray): The matrix with dimension ``(M,)`` or + ``(M, N)``. + lower (bool): Use only data contained in the lower triangle of ``a``. + Default is to use upper triangle. + trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve: + + - *'0'* or *'N'* -- :math:`a x = b` + - *'1'* or *'T'* -- :math:`a^T x = b` + - *'2'* or *'C'* -- :math:`a^H x = b` + + unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are + assumed to be 1 and will not be referenced. + overwrite_b (bool): Allow overwriting data in b (may enhance + performance) + check_finite (bool): Whether to check that the input matrices contain + only finite numbers. Disabling may give a performance gain, but may + result in problems (crashes, non-termination) if the inputs do + contain infinities or NaNs. + + Returns: + cupy.ndarray: + The matrix with dimension ``(M,)`` or ``(M, N)``. + + .. seealso:: :func:`scipy.linalg.solve_triangular` + """ + + _util._assert_cupy_array(a, b) + + if a.ndim == 2: + if a.shape[0] != a.shape[1]: + raise ValueError('expected square matrix') + if len(a) != len(b): + raise ValueError('incompatible dimensions') + batch_count = 0 + elif a.ndim > 2: + if a.shape[-1] != a.shape[-2]: + raise ValueError('expected a batch of square matrices') + if a.shape[:-2] != b.shape[:a.ndim - 2]: + raise ValueError('incompatible batch count') + if b.ndim < a.ndim - 1 or a.shape[-2] != b.shape[a.ndim - 2]: + raise ValueError('incompatible dimensions') + batch_count = math.prod(a.shape[:-2]) + else: + raise ValueError( + 'expected one square matrix or a batch of square matrices') + + # Cast to float32 or float64 + if a.dtype.char in 'fd': + dtype = a.dtype + else: + dtype = np.promote_types(a.dtype.char, 'f') + + if check_finite: + if a.dtype.kind == 'f' and not cp.isfinite(a).all(): + raise ValueError( + 'array must not contain infs or NaNs') + if b.dtype.kind == 'f' and not cp.isfinite(b).all(): + raise ValueError( + 'array must not contain infs or NaNs') + + if batch_count: + m, n = b.shape[-2:] if b.ndim == a.ndim else (b.shape[-1], 1) + + a_new_shape = (batch_count, m, m) + b_shape = b.shape + b_data_ptr = b.data.ptr + # trsm receives Fortran array, but we want zero copy + if trans == 'N' or trans == cublas.CUBLAS_OP_N: + # normal Fortran upper == transpose C lower + trans = cublas.CUBLAS_OP_T + lower = not lower + a = cp.ascontiguousarray(a.reshape(*a_new_shape), dtype=dtype) + elif trans == 'T' or trans == cublas.CUBLAS_OP_T: + # transpose Fortran upper == normal C lower + trans = cublas.CUBLAS_OP_N + lower = not lower + a = cp.ascontiguousarray(a.reshape(*a_new_shape), dtype=dtype) + elif trans == 'C' or trans == cublas.CUBLAS_OP_C: + if dtype == 'f' or dtype == 'd': + # real numbers + # Hermitian Fortran upper == transpose Fortran upper + # == normal C lower + trans = cublas.CUBLAS_OP_N + lower = not lower + a = cp.ascontiguousarray(a.reshape(*a_new_shape), + dtype=dtype) + else: + # complex numbers + trans = cublas.CUBLAS_OP_C + a = cp.ascontiguousarray( + a.reshape(*a_new_shape).transpose(0, 2, 1), dtype=dtype) + else: # know nothing about `trans`, just convert C to Fortran + a = cp.ascontiguousarray( + a.reshape(*a_new_shape).transpose(0, 2, 1), dtype=dtype) + b = cp.ascontiguousarray( + b.reshape(batch_count, m, n).transpose(0, 2, 1), dtype=dtype) + if b.data.ptr == b_data_ptr and not overwrite_b: + b = b.copy() + + start = a.data.ptr + step = m * m * a.itemsize + stop = start + step * batch_count + a_array = cp.arange(start, stop, step, dtype=cp.uintp) + + start = b.data.ptr + step = m * n * b.itemsize + stop = start + step * batch_count + b_array = cp.arange(start, stop, step, dtype=cp.uintp) + else: + a = cp.array(a, dtype=dtype, order='F', copy=None) + b = cp.array(b, dtype=dtype, order='F', + copy=(None if overwrite_b else True)) + + m, n = (b.size, 1) if b.ndim == 1 else b.shape + + if trans == 'N': + trans = cublas.CUBLAS_OP_N + elif trans == 'T': + trans = cublas.CUBLAS_OP_T + elif trans == 'C': + trans = cublas.CUBLAS_OP_C + + cublas_handle = device.get_cublas_handle() + one = np.array(1, dtype=dtype) + + if lower: + uplo = cublas.CUBLAS_FILL_MODE_LOWER + else: + uplo = cublas.CUBLAS_FILL_MODE_UPPER + + if unit_diagonal: + diag = cublas.CUBLAS_DIAG_UNIT + else: + diag = cublas.CUBLAS_DIAG_NON_UNIT + + if side: + side = cublas.CUBLAS_SIDE_RIGHT + else: + side = cublas.CUBLAS_SIDE_LEFT + + if batch_count: + if dtype == 'f': + trsm = cublas.strsmBatched + elif dtype == 'd': + trsm = cublas.dtrsmBatched + elif dtype == 'F': + trsm = cublas.ctrsmBatched + else: # dtype == 'D' + trsm = cublas.ztrsmBatched + trsm( + cublas_handle, side, uplo, + trans, diag, + m, n, one.ctypes.data, a_array.data.ptr, m, + b_array.data.ptr, m, batch_count) + return b.transpose(0, 2, 1).reshape(b_shape) + else: + if dtype == 'f': + trsm = cublas.strsm + elif dtype == 'd': + trsm = cublas.dtrsm + elif dtype == 'F': + trsm = cublas.ctrsm + else: # dtype == 'D' + trsm = cublas.ztrsm + trsm( + cublas_handle, side, uplo, + trans, diag, + m, n, one.ctypes.data, a.data.ptr, m, b.data.ptr, m) + return b + + +def solve_triangular_host(a, b, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, check_finite=True, side=0): + """ + Solve the equation ``a x = b`` for `x`, assuming a is a triangular matrix. + + Parameters + ---------- + a : (M, M) array_like + A triangular matrix + b : (M,) or (M, N) array_like + Right-hand side matrix in ``a x = b`` + lower : bool, optional + Use only data contained in the lower triangle of `a`. + Default is to use upper triangle. + trans : {0, 1, 2, 'N', 'T', 'C'}, optional + Type of system to solve: + + ======== ========= + trans system + ======== ========= + 0 or 'N' a x = b + 1 or 'T' a^T x = b + 2 or 'C' a^H x = b + ======== ========= + unit_diagonal : bool, optional + If True, diagonal elements of `a` are assumed to be 1 and + will not be referenced. + overwrite_b : bool, optional + Allow overwriting data in `b` (may enhance performance) + check_finite : bool, optional + Whether to check that the input matrices contain only finite numbers. + Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. + + Returns + ------- + x : (M,) or (M, N) ndarray + Solution to the system ``a x = b``. Shape of return matches `b`. + + Raises + ------ + LinAlgError + If `a` is singular + + Notes + ----- + .. versionadded:: 0.9.0 + + Examples + -------- + Solve the lower triangular system a x = b, where:: + + [3 0 0 0] [4] + a = [2 1 0 0] b = [2] + [1 0 1 0] [4] + [1 1 1 1] [2] + + >>> import numpy as np + >>> from scipy.linalg import solve_triangular + >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) + >>> b = np.array([4, 2, 4, 2]) + >>> x = solve_triangular(a, b, lower=True) + >>> x + array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) + >>> a.dot(x) # Check the result + array([ 4., 2., 4., 2.]) + + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + b1 = _asarray_validated(b, check_finite=check_finite) + + if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: + raise ValueError('expected square matrix') + + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + # accommodate empty arrays + if b1.size == 0: + dt_nonempty = solve_triangular_host( + np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) + ).dtype + return np.empty_like(b1, dtype=dt_nonempty) + + overwrite_b = overwrite_b or _datacopied(b1, b) + + x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) + return x + + +# solve_triangular without the input validation +def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, side=0): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + trsm, = get_blas_funcs(('trsm',), (a1, b1)) + + alpha = 1.0 + + if a1.flags.f_contiguous or trans == 2: + x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, + trans_a=trans, diag=unit_diagonal, side=side) + else: + # transposed system is solved since trsm expects Fortran ordering + x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, + trans_a=not trans, diag=unit_diagonal, side=side) + + return x \ No newline at end of file diff --git a/src/serinv/cupyfix/__init__.py b/src/serinv/cupyfix/__init__.py index 60f6f426..37015a0c 100644 --- a/src/serinv/cupyfix/__init__.py +++ b/src/serinv/cupyfix/__init__.py @@ -1,7 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. -from serinv.cupyfix.cholesky_lowerfill import cholesky_lowerfill +from serinv.cupyfix.cholesky_lowerfill import cholesky __all__ = [ - "cholesky_lowerfill", + "cholesky", ] diff --git a/src/serinv/cupyfix/cholesky_lowerfill.py b/src/serinv/cupyfix/cholesky_lowerfill.py index e4778c35..a5d053e0 100644 --- a/src/serinv/cupyfix/cholesky_lowerfill.py +++ b/src/serinv/cupyfix/cholesky_lowerfill.py @@ -7,7 +7,7 @@ from cupy.linalg import _util -def cholesky_lowerfill(a: cupy.ndarray) -> cupy.ndarray: +def cholesky(a: cupy.ndarray, lower=True) -> cupy.ndarray: """Cholesky decomposition. Decompose a given two-dimensional square matrix into ``L * L.H``, @@ -49,6 +49,11 @@ def cholesky_lowerfill(a: cupy.ndarray) -> cupy.ndarray: handle = device.get_cusolver_handle() dev_info = cupy.empty(1, dtype=numpy.int32) + if lower: + lower = cublas.CUBLAS_FILL_MODE_LOWER + else: + lower = cublas.CUBLAS_FILL_MODE_UPPER + if dtype == "f": potrf = cusolver.spotrf potrf_bufferSize = cusolver.spotrf_bufferSize @@ -68,7 +73,7 @@ def cholesky_lowerfill(a: cupy.ndarray) -> cupy.ndarray: workspace = cupy.empty(buffersize, dtype=dtype) potrf( handle, - cublas.CUBLAS_FILL_MODE_LOWER, + lower, n, x.data.ptr, n, diff --git a/src/serinv/utils/__init__.py b/src/serinv/utils/__init__.py index 4079f81c..1c54d228 100644 --- a/src/serinv/utils/__init__.py +++ b/src/serinv/utils/__init__.py @@ -8,6 +8,8 @@ from serinv.utils.pobtx import allocate_pobtx_permutation_buffers from serinv.utils.pobtax import allocate_pobtax_permutation_buffers + + __all__ = [ "check_block_dd", "check_ddbta", diff --git a/src/serinv/wrappers/ddbtars.py b/src/serinv/wrappers/ddbtars.py index 46be5972..17a12993 100644 --- a/src/serinv/wrappers/ddbtars.py +++ b/src/serinv/wrappers/ddbtars.py @@ -15,10 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - - def allocate_ddbtars( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, @@ -30,7 +26,14 @@ def allocate_ddbtars( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -80,7 +83,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -144,7 +147,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -224,8 +227,15 @@ def map_ddbtasc_to_ddbtars( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str, + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +370,14 @@ def aggregate_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -455,7 +472,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. if comm_rank == 0: @@ -564,11 +581,12 @@ def aggregate_ddbtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -576,9 +594,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -586,9 +604,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -596,9 +614,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -606,9 +624,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_arrow_blocks_comm.data.ptr, count=count, @@ -616,17 +634,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -662,11 +681,12 @@ def aggregate_ddbtars( ddbtars["A_upper_arrow_blocks"] = _A_upper_arrow_blocks[1:] if quadratic: - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -674,9 +694,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -684,9 +704,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -694,9 +714,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_arrow_blocks_comm.data.ptr, count=count, @@ -704,9 +724,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_arrow_blocks_comm.data.ptr, count=count, @@ -714,17 +734,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_B_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_arrow_tip_block_comm.data.ptr, recvbuf=_B_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -763,7 +784,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -787,8 +808,18 @@ def scatter_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() + _A_diagonal_blocks: ArrayLike = ddbtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = ddbtars.get("A_lower_diagonal_blocks", None) _A_upper_diagonal_blocks: ArrayLike = ddbtars.get("A_upper_diagonal_blocks", None) @@ -857,8 +888,15 @@ def map_ddbtars_to_ddbtasci( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ddbtrs.py b/src/serinv/wrappers/ddbtrs.py index 15393ae9..937d8def 100644 --- a/src/serinv/wrappers/ddbtrs.py +++ b/src/serinv/wrappers/ddbtrs.py @@ -24,7 +24,14 @@ def allocate_ddbtrs( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -57,7 +64,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -98,7 +105,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -150,8 +157,15 @@ def map_ddbtsc_to_ddbtrs( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -241,7 +255,14 @@ def aggregate_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -304,7 +325,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. if comm_rank == 0: @@ -379,11 +400,12 @@ def aggregate_ddbtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -391,9 +413,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -401,9 +423,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -411,6 +433,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -433,11 +456,12 @@ def aggregate_ddbtrs( ddbtrs["A_upper_diagonal_blocks"] = _A_upper_diagonal_blocks[1:-2] if quadratic: - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -445,9 +469,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -455,9 +479,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -465,6 +489,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -492,7 +517,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -510,8 +535,15 @@ def scatter_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -565,8 +597,15 @@ def map_ddbtrs_to_ddbtsci( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index da2fcff4..27335195 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -22,6 +24,7 @@ def pddbtasc( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -115,9 +118,10 @@ def pddbtasc( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) - buffers: dict = kwargs.get("buffers", None) ddbtars: dict = kwargs.get("ddbtars", None) strategy: str = kwargs.get("strategy", "allgather") @@ -200,14 +204,23 @@ def pddbtasc( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_ddbtars( ddbtars=ddbtars, quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic ddbtars["A_arrow_tip_block"][:] += A_arrow_tip_initial if quadratic: @@ -226,3 +239,5 @@ def pddbtasc( ) comm.Barrier() + + return elapsed diff --git a/src/serinv/wrappers/pddbtasci.py b/src/serinv/wrappers/pddbtasci.py index f86b6a9c..3ad34f79 100644 --- a/src/serinv/wrappers/pddbtasci.py +++ b/src/serinv/wrappers/pddbtasci.py @@ -22,6 +22,7 @@ def pddbtasci( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. @@ -43,7 +44,7 @@ def pddbtasci( The arrow tip block of the block tridiagonal with arrowhead matrix. comm : MPI.Comm The MPI communicator. Default is MPI.COMM_WORLD. - + Keyword Arguments ----------------- rhs : dict @@ -159,6 +160,7 @@ def pddbtasci( ddbtars=ddbtars, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtars_to_ddbtasci( @@ -179,6 +181,7 @@ def pddbtasci( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index fc0a3765..80d18545 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -19,6 +21,7 @@ def pddbtsc( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -82,6 +85,7 @@ def pddbtsc( - _B_upper_diagonal_blocks : ArrayLike The upper diagonal blocks of the reduced system. """ + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -161,14 +165,23 @@ def pddbtsc( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_ddbtrs( ddbtrs=ddbtrs, quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Perform Schur complement on the reduced system ddbtsc( @@ -179,4 +192,4 @@ def pddbtsc( quadratic=quadratic, ) - comm.Barrier() + comm.Barrier() \ No newline at end of file diff --git a/src/serinv/wrappers/pddbtsci.py b/src/serinv/wrappers/pddbtsci.py index 144054f8..18d5ea00 100644 --- a/src/serinv/wrappers/pddbtsci.py +++ b/src/serinv/wrappers/pddbtsci.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import ddbtsci @@ -19,6 +18,7 @@ def pddbtsci( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. @@ -34,7 +34,7 @@ def pddbtsci( The upper diagonal blocks of the block tridiagonal matrix. comm : MPI.Comm The MPI communicator. Default is MPI.COMM_WORLD. - + Keyword Arguments ----------------- rhs : dict @@ -89,8 +89,6 @@ def pddbtsci( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") - xp, _ = _get_module_from_array(arr=A_diagonal_blocks) - rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) buffers: dict = kwargs.get("buffers", None) @@ -125,6 +123,7 @@ def pddbtsci( ddbtrs=ddbtrs, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtrs_to_ddbtsci( @@ -139,6 +138,7 @@ def pddbtsci( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pobtars.py b/src/serinv/wrappers/pobtars.py index 98cfdf46..ee6f26bb 100644 --- a/src/serinv/wrappers/pobtars.py +++ b/src/serinv/wrappers/pobtars.py @@ -15,9 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - def allocate_pobtars( A_diagonal_blocks: ArrayLike, @@ -29,6 +26,7 @@ def allocate_pobtars( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PPOBTARX algorithms. @@ -56,6 +54,12 @@ def allocate_pobtars( pobtars : dict Dictionary containing the reduced system arrays. """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -106,7 +110,7 @@ def allocate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -152,6 +156,7 @@ def map_ppobtax_to_pobtars( buffer: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the the boundary blocks of the PPOBTAX algorithm to the reduced system. @@ -178,6 +183,12 @@ def map_ppobtax_to_pobtars( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -228,9 +239,16 @@ def map_ppobtas_to_pobtarss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -252,6 +270,7 @@ def aggregate_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -269,7 +288,14 @@ def aggregate_pobtars( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -306,7 +332,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. _A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -317,11 +343,13 @@ def aggregate_pobtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -329,9 +357,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -339,9 +367,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -349,17 +377,20 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -374,7 +405,7 @@ def aggregate_pobtars( ) comm.Allreduce(MPI.IN_PLACE, _A_arrow_tip_block_comm, op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -432,7 +463,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -449,8 +480,15 @@ def aggregate_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -474,7 +512,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. @@ -484,10 +522,11 @@ def aggregate_pobtarss( if strategy == "allgather": if _use_nccl(comm): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[:-a], comm=comm, op="allgather" + arr=_B_comm[:-a], comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm[:-a].data.ptr + displacement, recvbuf=_B_comm[:-a].data.ptr, count=count, @@ -495,24 +534,25 @@ def aggregate_pobtarss( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[-a:], comm=comm, op="allreduce" + arr=_B_comm[-a:], comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_comm[-a:].data.ptr, recvbuf=_B_comm[-a:].data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm[:-a], ) comm.Allreduce(MPI.IN_PLACE, _B_comm[-a:], op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -547,7 +587,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -559,10 +599,18 @@ def scatter_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -598,6 +646,11 @@ def scatter_pobtars( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -607,7 +660,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -652,7 +705,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -671,9 +724,17 @@ def scatter_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() b = A_diagonal_blocks[0].shape[0] a = A_arrow_tip_block.shape[0] @@ -695,6 +756,11 @@ def scatter_pobtarss( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -703,7 +769,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -727,7 +793,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -747,6 +813,7 @@ def map_pobtars_to_ppobtax( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -772,6 +839,12 @@ def map_pobtars_to_ppobtax( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -824,9 +897,16 @@ def map_pobtarss_to_ppobtas( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pobtrs.py b/src/serinv/wrappers/pobtrs.py index c716317a..391953d4 100644 --- a/src/serinv/wrappers/pobtrs.py +++ b/src/serinv/wrappers/pobtrs.py @@ -24,6 +24,7 @@ def allocate_pobtrs( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PpobtRX algorithms. @@ -47,6 +48,12 @@ def allocate_pobtrs( pobtrs : dict Dictionary containing the reduced system arrays. """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -92,7 +99,7 @@ def allocate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -126,6 +133,7 @@ def map_ppobtx_to_pobtrs( comm: MPI.Comm, buffer: ArrayLike, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: """Map the the boundary blocks of the PpobtX algorithm to the reduced system. @@ -144,6 +152,12 @@ def map_ppobtx_to_pobtrs( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -183,9 +197,16 @@ def map_ppobts_to_pobtrss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -204,6 +225,7 @@ def aggregate_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -215,6 +237,12 @@ def aggregate_pobtrs( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -241,7 +269,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. _A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -250,11 +278,13 @@ def aggregate_pobtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -262,16 +292,19 @@ def aggregate_pobtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, datatype=datatype, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -281,7 +314,7 @@ def aggregate_pobtrs( _A_lower_diagonal_blocks_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -322,7 +355,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -336,8 +369,15 @@ def aggregate_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +400,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. @@ -369,11 +409,12 @@ def aggregate_pobtrss( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm, comm=comm, op="allgather" + arr=_B_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm.data.ptr + displacement, recvbuf=_B_comm.data.ptr, count=count, @@ -381,12 +422,13 @@ def aggregate_pobtrss( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -415,7 +457,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -427,9 +469,16 @@ def scatter_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -456,6 +505,11 @@ def scatter_pobtrs( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -465,7 +519,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -496,7 +550,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -513,8 +567,15 @@ def scatter_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -536,6 +597,11 @@ def scatter_pobtrss( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -544,7 +610,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -564,7 +630,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -580,6 +646,7 @@ def map_pobtrs_to_ppobtx( _A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -597,6 +664,12 @@ def map_pobtrs_to_ppobtx( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -638,9 +711,16 @@ def map_pobtrss_to_ppobts( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ppobtaf.py b/src/serinv/wrappers/ppobtaf.py index 64946a3d..2122b88a 100644 --- a/src/serinv/wrappers/ppobtaf.py +++ b/src/serinv/wrappers/ppobtaf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -20,6 +22,7 @@ def ppobtaf( A_lower_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -62,6 +65,8 @@ def ppobtaf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -133,14 +138,23 @@ def ppobtaf( buffer=buffer, strategy=strategy, comm=comm, + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_pobtars( pobtars=pobtars, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- pobtars["A_arrow_tip_block"][:] += A_arrow_tip_initial @@ -165,3 +179,5 @@ def ppobtaf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtas.py b/src/serinv/wrappers/ppobtas.py index 132a9d08..52896530 100644 --- a/src/serinv/wrappers/ppobtas.py +++ b/src/serinv/wrappers/ppobtas.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -23,6 +25,7 @@ def ppobtas( L_arrow_tip_block: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -138,16 +141,25 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Agregate reduced RHS + comm.Barrier() + tic = time.perf_counter() aggregate_pobtarss( A_diagonal_blocks=L_diagonal_blocks, A_arrow_tip_block=L_arrow_tip_block, pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Add the tip block of the RHS to the aggregated update _B[-a:] += B_tip_initial @@ -180,6 +192,7 @@ def ppobtas( pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -190,6 +203,7 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -213,3 +227,5 @@ def ppobtas( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtasi.py b/src/serinv/wrappers/ppobtasi.py index 12e23e14..96b67c20 100644 --- a/src/serinv/wrappers/ppobtasi.py +++ b/src/serinv/wrappers/ppobtasi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtasi @@ -20,6 +19,7 @@ def ppobtasi( L_lower_arrow_blocks: ArrayLike, L_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -114,6 +114,7 @@ def ppobtasi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -129,6 +130,7 @@ def ppobtasi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system diff --git a/src/serinv/wrappers/ppobtf.py b/src/serinv/wrappers/ppobtf.py index 2d680c96..f1b44956 100644 --- a/src/serinv/wrappers/ppobtf.py +++ b/src/serinv/wrappers/ppobtf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -18,6 +20,7 @@ def ppobtf( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -56,6 +59,8 @@ def ppobtf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -106,14 +111,23 @@ def ppobtf( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_pobtrs( pobtrs=pobtrs, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- if strategy == "gather-scatter": @@ -134,3 +148,5 @@ def ppobtf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobts.py b/src/serinv/wrappers/ppobts.py index 396577f1..86906883 100644 --- a/src/serinv/wrappers/ppobts.py +++ b/src/serinv/wrappers/ppobts.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -21,6 +23,7 @@ def ppobts( L_lower_diagonal_blocks: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -119,14 +122,24 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + # Agregate reduced RHS + comm.Barrier() + tic = time.perf_counter() aggregate_pobtrss( A_diagonal_blocks=L_diagonal_blocks, pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Solve RHS FWD/BWD if strategy == "allgather": @@ -151,6 +164,7 @@ def ppobts( pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -160,6 +174,7 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -179,3 +194,5 @@ def ppobts( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtsi.py b/src/serinv/wrappers/ppobtsi.py index fde320fc..72c7354d 100644 --- a/src/serinv/wrappers/ppobtsi.py +++ b/src/serinv/wrappers/ppobtsi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtsi @@ -18,6 +17,7 @@ def ppobtsi( L_diagonal_blocks: ArrayLike, L_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -100,6 +100,7 @@ def ppobtsi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -111,6 +112,7 @@ def ppobtsi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system diff --git a/tests/conftest.py b/tests/conftest.py index 3e624933..25b716a6 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,6 +1,5 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. # Global pytest fixtures for the Serinv tests. - import pytest from serinv import backend_flags @@ -15,7 +14,6 @@ ] ) - DTYPE = [ pytest.param("float64", id="float64"), pytest.param("complex128", id="complex128"), diff --git a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py index dc6e3a25..e4b9fc40 100644 --- a/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py +++ b/tests/tests_algs/permuted/test_bta/test_pobtasi_permuted.py @@ -3,15 +3,28 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize from serinv.utils import allocate_pobtax_permutation_buffers from serinv.algs import pobtaf, pobtasi +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + if backend_flags["cupy_avail"]: import cupyx as cpx +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtasi_permuted( diff --git a/tests/tests_algs/regular/tests_bt/test_pobtf.py b/tests/tests_algs/regular/tests_bt/test_pobtf.py index d1969b05..0ac3ae89 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobtf.py +++ b/tests/tests_algs/regular/tests_bt/test_pobtf.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize @@ -11,6 +13,16 @@ if backend_flags["cupy_avail"]: import cupyx as cpx + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + + @pytest.fixture(params=ARRAY_TYPE, autouse=True) + def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtf( diff --git a/tests/tests_algs/regular/tests_bt/test_pobts.py b/tests/tests_algs/regular/tests_bt/test_pobts.py index 8125df52..9011caa1 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobts.py +++ b/tests/tests_algs/regular/tests_bt/test_pobts.py @@ -3,11 +3,26 @@ import numpy as np import pytest -from serinv import _get_module_from_array +from ....conftest import ARRAY_TYPE + +from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize, rhs from serinv.algs import pobtf, pobts +if backend_flags["cupy_avail"]: + import cupyx as cpx + + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + + @pytest.fixture(params=ARRAY_TYPE, autouse=True) + def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) @@ -18,6 +33,7 @@ def test_pobts( array_type: str, dtype: np.dtype, ): + A = dd_bt( diagonal_blocksize, n_diag_blocks, @@ -47,9 +63,22 @@ def test_pobts( _, ) = bt_dense_to_arrays(A, diagonal_blocksize, n_diag_blocks) + if backend_flags["cupy_avail"] and array_type == "streaming": + A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks) + A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :] + A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks) + A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks[:, :, :] + B_pinned = cpx.zeros_like_pinned(B) + B_pinned[:, :] = B[:, :] + + A_diagonal_blocks = A_diagonal_blocks_pinned + A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned + B = B_pinned + pobtf( A_diagonal_blocks, A_lower_diagonal_blocks, + device_streaming=True if array_type == "streaming" else False, ) # Forward solve: Y=L^{-1}B @@ -58,6 +87,7 @@ def test_pobts( A_lower_diagonal_blocks, B, trans="N", + device_streaming=True if array_type == "streaming" else False, ) # Backward solve: X=L^{-T}Y @@ -66,6 +96,7 @@ def test_pobts( A_lower_diagonal_blocks, B, trans="C", + device_streaming=True if array_type == "streaming" else False, ) assert xp.allclose(B, X_ref) diff --git a/tests/tests_algs/regular/tests_bt/test_pobtsi.py b/tests/tests_algs/regular/tests_bt/test_pobtsi.py index 22ec1d3a..5463a7b9 100644 --- a/tests/tests_algs/regular/tests_bt/test_pobtsi.py +++ b/tests/tests_algs/regular/tests_bt/test_pobtsi.py @@ -3,6 +3,8 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bt_dense_to_arrays, dd_bt, symmetrize @@ -11,6 +13,16 @@ if backend_flags["cupy_avail"]: import cupyx as cpx + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + + @pytest.fixture(params=ARRAY_TYPE, autouse=True) + def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtsi( diff --git a/tests/tests_algs/regular/tests_bta/test_pobtaf.py b/tests/tests_algs/regular/tests_bta/test_pobtaf.py index a30b9094..b4182f78 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtaf.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtaf.py @@ -3,14 +3,31 @@ import numpy as np import pytest +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize from serinv.algs import pobtaf +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + if backend_flags["cupy_avail"]: import cupyx as cpx +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() def test_pobtaf( diff --git a/tests/tests_algs/regular/tests_bta/test_pobtas.py b/tests/tests_algs/regular/tests_bta/test_pobtas.py index 647a0168..61b81303 100644 --- a/tests/tests_algs/regular/tests_bta/test_pobtas.py +++ b/tests/tests_algs/regular/tests_bta/test_pobtas.py @@ -3,11 +3,28 @@ import numpy as np import pytest -from serinv import _get_module_from_array +from ....conftest import ARRAY_TYPE as ARRAY_TYPE + +from serinv import backend_flags, _get_module_from_array from ....testing_utils import bta_dense_to_arrays, dd_bta, symmetrize, rhs from serinv.algs import pobtaf, pobtas +if backend_flags["cupy_avail"]: + ARRAY_TYPE.extend( + [ + pytest.param("streaming", id="streaming"), + ] + ) + +if backend_flags["cupy_avail"]: + import cupyx as cpx + + +@pytest.fixture(params=ARRAY_TYPE, autouse=True) +def array_type(request: pytest.FixtureRequest) -> str: + return request.param + @pytest.mark.mpi_skip() @pytest.mark.parametrize("n_rhs", [1, 2, 3]) @@ -19,6 +36,7 @@ def test_pobtas( array_type: str, dtype: np.dtype, ): + A = dd_bta( diagonal_blocksize, arrowhead_blocksize, @@ -51,11 +69,30 @@ def test_pobtas( A_arrow_tip_block, ) = bta_dense_to_arrays(A, diagonal_blocksize, arrowhead_blocksize, n_diag_blocks) + if backend_flags["cupy_avail"] and array_type == "streaming": + A_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_diagonal_blocks) + A_diagonal_blocks_pinned[:, :, :] = A_diagonal_blocks[:, :, :] + A_lower_diagonal_blocks_pinned = cpx.zeros_like_pinned(A_lower_diagonal_blocks) + A_lower_diagonal_blocks_pinned[:, :, :] = A_lower_diagonal_blocks[:, :, :] + A_lower_arrow_blocks_pinned = cpx.zeros_like_pinned(A_lower_arrow_blocks) + A_lower_arrow_blocks_pinned[:, :, :] = A_lower_arrow_blocks[:, :, :] + A_arrow_tip_block_pinned = cpx.zeros_like_pinned(A_arrow_tip_block) + A_arrow_tip_block_pinned[:, :] = A_arrow_tip_block[:, :] + B_pinned = cpx.zeros_like_pinned(B) + B_pinned[:, :] = B[:, :] + + A_diagonal_blocks = A_diagonal_blocks_pinned + A_lower_diagonal_blocks = A_lower_diagonal_blocks_pinned + A_lower_arrow_blocks = A_lower_arrow_blocks_pinned + A_arrow_tip_block = A_arrow_tip_block_pinned + B = B_pinned + pobtaf( A_diagonal_blocks, A_lower_diagonal_blocks, A_lower_arrow_blocks, A_arrow_tip_block, + device_streaming=True if array_type == "streaming" else False, ) # Forward solve: Y=L^{-1}B @@ -66,6 +103,7 @@ def test_pobtas( A_arrow_tip_block, B, trans="N", + device_streaming=True if array_type == "streaming" else False, ) # Backward solve: X=L^{-T}Y @@ -76,6 +114,7 @@ def test_pobtas( A_arrow_tip_block, B, trans="C", + device_streaming=True if array_type == "streaming" else False, ) assert xp.allclose(B, X_ref)