diff --git a/setup.py b/setup.py new file mode 100644 index 00000000..90ac7218 --- /dev/null +++ b/setup.py @@ -0,0 +1,32 @@ +""" +from setuptools import setup, Extension +from Cython.Build import cythonize +import os + +CONDA_PREFIX = os.environ.get("CONDA_PREFIX", "") + +CUDA_INCLUDE = os.path.join(CONDA_PREFIX, "targets", "x86_64-linux", "include") + + + +ext = Extension( + name="cupyfix_backends.cuda.libs.cublas", + sources=[ + "src/serinv/cupyfix_backends/cuda/libs/cublas.pyx" + ], + include_dirs=["cupyfix_backends/cuda/libs", + "cupyfix_backends/hip", + "cupyfix_backends/stub", + "cupyfix_backends/cuda", + "cupyfix_backends", + CUDA_INCLUDE + ], + +) + +setup( + name="cupyfix_backends", + ext_modules=cythonize([ext]), + packages=["src/serinv/cupyfix_backends.cuda.libs"], +) +""" \ No newline at end of file diff --git a/src/serinv/__init__.py b/src/serinv/__init__.py index ca7f21dd..bad9c860 100644 --- a/src/serinv/__init__.py +++ b/src/serinv/__init__.py @@ -163,7 +163,7 @@ def _use_nccl(comm): return False -def _get_nccl_parameters(arr, comm, op: str): +def _get_nccl_parameters(arr, comm, rank, op: str): """Get the NCCL parameters for the given operation.""" if np.iscomplexobj(arr): factor = 2 @@ -172,8 +172,8 @@ def _get_nccl_parameters(arr, comm, op: str): if backend_flags["nccl_avail"]: if op == "allgather": - count = (arr.size // comm.size) * factor - displacement = count * comm.rank * arr.dtype.itemsize + count = (arr.size // comm.size()) * factor + displacement = count * rank * (arr.dtype.itemsize // factor) elif op == "allreduce": count = arr.size * factor displacement = 0 diff --git a/src/serinv/algs/pobtaf.py b/src/serinv/algs/pobtaf.py index 05f58d3b..b7fb8a06 100644 --- a/src/serinv/algs/pobtaf.py +++ b/src/serinv/algs/pobtaf.py @@ -7,6 +7,7 @@ _get_cholesky, ) +from serinv.block_primitive import trsm, gemm, syherk def pobtaf( A_diagonal_blocks: ArrayLike, @@ -118,18 +119,18 @@ def _pobtaf( # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} 
L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T + ) + # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, @@ -141,22 +142,34 @@ def _pobtaf( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) - + print(A_diagonal_blocks[i + 1, :, :]) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block[:, :] = ( - A_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + A_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) + print(A_arrow_tip_block[:, :]) if factorize_last_block: # L_{ndb, ndb} = chol(A_{ndb, ndb}) @@ -164,7 +177,7 @@ def _pobtaf( # L_{ndb+1, ndb} = A_{ndb+1, ndb} @ L_{ndb, ndb}^{-T} L_lower_arrow_blocks[-1, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[-1, :, :], A_lower_arrow_blocks[-1, :, :].conj().T, lower=True, @@ -174,9 +187,21 @@ def 
_pobtaf( ) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} + # A_arrow_tip_block[:, :] = ( + # gemm( + # L_lower_arrow_blocks[-1, :, :], + # L_lower_arrow_blocks[-1, :, :], + # A_arrow_tip_block[:, :], + # trans_b='C', alpha=-1.0, beta=1.0 + # ) + # ) + A_arrow_tip_block[:, :] = ( - A_arrow_tip_block[:, :] - - L_lower_arrow_blocks[-1, :, :] @ L_lower_arrow_blocks[-1, :, :].conj().T + syherk( + L_lower_arrow_blocks[-1, :, :], + A_arrow_tip_block[:, :], + alpha=-1.0, beta=1.0, lower=True + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) @@ -210,18 +235,16 @@ def _pobtaf_permuted( # Compute lower factors # L_{i+1, i} = A_{i+1, i} @ L_{i, i}^{-T} L_lower_diagonal_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], - A_lower_diagonal_blocks[i, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks[i, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) # L_{top, i} = A_{top, i} @ U{i, i}^{-1} buffer[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], buffer[i, :, :].conj().T, lower=True, @@ -232,7 +255,7 @@ def _pobtaf_permuted( # L_{ndb+1, i} = A_{ndb+1, i} @ L_{i, i}^{-T} L_lower_arrow_blocks[i, :, :] = ( - la.solve_triangular( + trsm( L_diagonal_blocks[i, :, :], A_lower_arrow_blocks[i, :, :].conj().T, lower=True, @@ -244,40 +267,63 @@ def _pobtaf_permuted( # Update next diagonal block # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks[i + 1, :, :] = ( - A_diagonal_blocks[i + 1, :, :] - - L_lower_diagonal_blocks[i, :, :] - @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + L_lower_diagonal_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_diagonal_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) - # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks[i + 1, :, :] = ( - A_lower_arrow_blocks[i + 1, :, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + 
L_lower_arrow_blocks[i, :, :], + L_lower_diagonal_blocks[i, :, :], + A_lower_arrow_blocks[i + 1, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T L_arrow_tip_block[:, :] = ( - L_arrow_tip_block[:, :] - - L_lower_arrow_blocks[i, :, :] @ L_lower_arrow_blocks[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + L_lower_arrow_blocks[i, :, :], + L_arrow_tip_block[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update top and next upper/lower blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_blocks[0, :, :] = ( - A_diagonal_blocks[0, :, :] - buffer[i, :, :] @ buffer[i, :, :].conj().T + gemm( + buffer[i, :, :], + buffer[i, :, :], + A_diagonal_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer[i + 1, :, :] = ( - -buffer[i, :, :] @ L_lower_diagonal_blocks[i, :, :].conj().T + gemm( + buffer[i, :, :], + L_lower_diagonal_blocks[i, :, :], + trans_b='C', alpha=-1.0 + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_lower_arrow_blocks[0, :, :] = ( - A_lower_arrow_blocks[0, :, :] - - L_lower_arrow_blocks[i, :, :] @ buffer[i, :, :].conj().T + gemm( + L_lower_arrow_blocks[i, :, :], + buffer[i, :, :], + A_lower_arrow_blocks[0, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) @@ -391,13 +437,11 @@ def _pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) compute_lower_events[i % 2].record(stream=compute_stream) @@ -418,7 +462,7 @@ def 
_pobtaf_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -446,24 +490,33 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) compute_lower_h2d_events[i % 2].record(stream=compute_stream) # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) compute_arrow_h2d_events[i % 2].record(stream=compute_stream) @@ -488,7 +541,7 @@ def _pobtaf_streaming( compute_stream.wait_event(h2d_arrow_events[(n_diag_blocks - 1) % 2]) if factorize_last_block: L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[(n_diag_blocks - 1) % 2, :, :], 
A_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T, lower=True, @@ -509,9 +562,12 @@ def _pobtaf_streaming( if factorize_last_block: # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, ndb} @ L_{ndb+1, ndb}^{T} A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :] - @ L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + L_lower_arrow_blocks_d[(n_diag_blocks - 1) % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # L_{ndb+1, ndb+1} = chol(A_{ndb+1, ndb+1}) @@ -654,13 +710,11 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_lower_events[i % 2]) L_lower_diagonal_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], - A_lower_diagonal_blocks_d[i % 2, :, :].conj().T, - lower=True, + A_lower_diagonal_blocks_d[i % 2, :, :], + trans='C',lower=True, side=1 ) - .conj() - .T ) cp_lower_events[i % 2].record(stream=compute_stream) @@ -681,7 +735,7 @@ def _pobtaf_permuted_streaming( with compute_stream: compute_stream.wait_event(h2d_arrow_events[i % 2]) L_lower_arrow_blocks_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], A_lower_arrow_blocks_d[i % 2, :, :].conj().T, lower=True, @@ -701,7 +755,7 @@ def _pobtaf_permuted_streaming( # L_{top, i} = A_{top, i} @ U{i, i}^{-1} with compute_stream: L_upper_nested_dissection_buffer_d[i % 2, :, :] = ( - cu_la.solve_triangular( + trsm( L_diagonal_blocks_d[i % 2, :, :], buffer_d[i % 2, :, :].conj().T, lower=True, @@ -732,48 +786,66 @@ def _pobtaf_permuted_streaming( compute_stream.wait_event(h2d_diagonal_events[(i + 1) % 2]) # A_{i+1, i+1} = A_{i+1, i+1} - L_{i+1, i} @ L_{i+1, i}.conj().T A_diagonal_blocks_d[(i + 1) % 2, :, :] = ( - A_diagonal_blocks_d[(i + 1) % 2, :, :] - - L_lower_diagonal_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, 
:, :].conj().T + gemm( + L_lower_diagonal_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_diagonal_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{ndb+1, i+1} = A_{ndb+1, i+1} - L_{ndb+1, i} @ L_{i+1, i}.conj().T A_lower_arrow_blocks_d[(i + 1) % 2, :, :] = ( - A_lower_arrow_blocks_d[(i + 1) % 2, :, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + A_lower_arrow_blocks_d[(i + 1) % 2, :, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # A_{top, i+1} = - L{top, i} @ L_{i+1, i}.conj().T buffer_d[(i + 1) % 2, :, :] = ( - -L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_lower_diagonal_blocks_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_lower_diagonal_blocks_d[i % 2, :, :], + trans_b='C', alpha=-1.0 + ) ) cp_lower_events_h2d_release[i % 2].record(stream=compute_stream) # Update the block at the tip of the arrowhead # A_{ndb+1, ndb+1} = A_{ndb+1, ndb+1} - L_{ndb+1, i} @ L_{ndb+1, i}.conj().T A_arrow_tip_block_d[:, :] = ( - A_arrow_tip_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_lower_arrow_blocks_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_lower_arrow_blocks_d[i % 2, :, :], + A_arrow_tip_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # Update the top (first blocks) of the arrowhead # A_{ndb+1, top} = A_{ndb+1, top} - L_{ndb+1, i} @ L_{top, i}.conj().T A_arrow_bottom_top_block_d[:, :] = ( - A_arrow_bottom_top_block_d[:, :] - - L_lower_arrow_blocks_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + gemm( + L_lower_arrow_blocks_d[i % 2, :, :], + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_arrow_bottom_top_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) cp_arrow_events_h2d_release[i % 2].record(stream=compute_stream) # Update top and next upper/lower 
blocks of 2-sided factorization pattern # A_{top, top} = A_{top, top} - L_{top, i} @ L_{top, i}.conj().T A_diagonal_top_block_d[:, :] = ( - A_diagonal_top_block_d[:, :] - - L_upper_nested_dissection_buffer_d[i % 2, :, :] - @ L_upper_nested_dissection_buffer_d[i % 2, :, :].conj().T + gemm( + L_upper_nested_dissection_buffer_d[i % 2, :, :], + L_upper_nested_dissection_buffer_d[i % 2, :, :], + A_diagonal_top_block_d[:, :], + trans_b='C', alpha=-1.0, beta=1.0 + ) ) # --- Device 2 Host transfers --- diff --git a/src/serinv/algs/pobtas.py b/src/serinv/algs/pobtas.py index bab2a911..dbc2d916 100644 --- a/src/serinv/algs/pobtas.py +++ b/src/serinv/algs/pobtas.py @@ -4,6 +4,7 @@ from serinv import ( ArrayLike, _get_module_from_array, + ) @@ -83,7 +84,7 @@ def _pobtas( B[i * diag_blocksize : (i + 1) * diag_blocksize] = la.solve_triangular( L_diagonal_blocks[i], B[i * diag_blocksize : (i + 1) * diag_blocksize], - lower=True, + lower=True ) B[(i + 1) * diag_blocksize : (i + 2) * diag_blocksize] -= ( diff --git a/src/serinv/block_primitive/__init__.py b/src/serinv/block_primitive/__init__.py new file mode 100644 index 00000000..941b6c92 --- /dev/null +++ b/src/serinv/block_primitive/__init__.py @@ -0,0 +1,11 @@ +from serinv.block_primitive.gemm import gemm +from serinv.block_primitive.trsm import trsm +from serinv.block_primitive.syherk import syherk +import cupyfix_backends + +__all__ = [ + "gemm", + "trsm", + "syherk", + cupyfix_backends +] \ No newline at end of file diff --git a/src/serinv/block_primitive/gemm.py b/src/serinv/block_primitive/gemm.py new file mode 100644 index 00000000..8924cb2b --- /dev/null +++ b/src/serinv/block_primitive/gemm.py @@ -0,0 +1,229 @@ +from serinv import _get_module_from_array + +import numpy as np +from numpy.linalg import matmul + +from scipy.linalg.blas import get_blas_funcs +from scipy.linalg._misc import _datacopied +from scipy.linalg._decomp import _asarray_validated + +try: + import cupy as cp + from cupy_backends.cuda.libs import 
# SciPy's f2py BLAS wrappers take integer transpose codes.
_BLAS_TRANS_CODES = {'N': 0, 'T': 1, 'C': 2}


def _trans_to_blas_code(trans):
    """Translate 'N'/'T'/'C' into 0/1/2; any other value passes through."""
    return _BLAS_TRANS_CODES.get(trans, trans)


def gemm(a, b, c=None, alpha=1.0, beta=0.0, trans_a='N', trans_b='N'):
    """Compute ``alpha * op(a) @ op(b) + beta * c`` on the host or the device.

    The array module reported by ``_get_module_from_array(a)`` selects the
    implementation: NumPy goes to :func:`matmul_gemm_host`, CuPy goes to
    :func:`matmul_gemm_device`.

    Parameters
    ----------
    a, b : 2-D arrays
        Operands of the product.
    c : 2-D array, optional
        Accumulator; required by the host path when ``beta != 0``.
    alpha, beta : scalar
        Scaling factors of the product and of ``c``.
    trans_a, trans_b : {'N', 'T', 'C'}
        Operation applied to ``a`` / ``b`` (none, transpose, conjugate
        transpose).

    Raises
    ------
    ModuleNotFoundError
        If the array module is neither NumPy nor CuPy.
    """
    xp, _ = _get_module_from_array(a)

    if xp is np:
        return matmul_gemm_host(a, b, alpha, beta, c, trans_a, trans_b)
    # Compare by name instead of against ``cp``: ``cp`` is unbound when the
    # optional CuPy import at the top of the module failed.
    # NOTE(review): assumes ``xp`` is the array module itself - confirm
    # against ``_get_module_from_array``.
    if getattr(xp, "__name__", None) == "cupy":
        return matmul_gemm_device(trans_a, trans_b, a, b, c, alpha, beta)
    raise ModuleNotFoundError("Unknown Module")


def matmul_gemm_host(a, b, alpha=1.0, beta=0.0, c=None, trans_a=0, trans_b=0,
                     overwrite_c=0, check_finite=False):
    """Compute ``alpha * op(a) @ op(b) + beta * c`` with SciPy's BLAS ``gemm``.

    op(x) = x if the trans flag is 0 / 'N', x.T if it is 1 / 'T',
    x.T.conj() if it is 2 / 'C'.

    Parameters
    ----------
    a, b : 2-D array_like
        Operands of the product.
    alpha, beta : scalar
        Scaling factors of the product and of ``c``.
    c : 2-D array_like, optional
        Accumulator; mandatory when ``beta != 0``.
    trans_a, trans_b : {0, 1, 2, 'N', 'T', 'C'}
        Operation applied to ``a`` / ``b``.
    overwrite_c : bool
        Allow BLAS to reuse the memory of ``c`` for the result.
    check_finite : bool
        Reject inputs containing infs or NaNs.

    Returns
    -------
    ndarray
        The ``(m, n)`` result.

    Raises
    ------
    ValueError
        If an operand is not 2-D, the inner dimensions of ``op(a)`` and
        ``op(b)`` differ, or ``beta != 0`` without ``c``.
    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    b1 = _asarray_validated(b, check_finite=check_finite)
    c1 = None if c is None else _asarray_validated(c, check_finite=check_finite)

    if a1.ndim != 2 or b1.ndim != 2:
        raise ValueError('expected 2-D operands')

    # Normalise the flags before looking at shapes: the strings 'N'/'T'/'C'
    # are all truthy, so they cannot be tested with ``not trans``.
    code_a = _trans_to_blas_code(trans_a)
    code_b = _trans_to_blas_code(trans_b)
    m, k = a1.shape if not code_a else a1.shape[::-1]
    k_b, n = b1.shape if not code_b else b1.shape[::-1]
    if k != k_b:
        raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible')

    if beta != 0 and c1 is None:
        raise ValueError('expected C matrix')

    # Empty products never reach BLAS (a zero leading dimension is rejected
    # there); the result is just the scaled accumulator, or zeros.
    if 0 in (m, n, k):
        blas_gemm, = get_blas_funcs(('gemm',), (a1, b1))
        result = np.zeros((m, n), dtype=blas_gemm.dtype, order='F')
        if beta != 0 and result.size:
            result[...] = beta * c1
        return result

    if c1 is not None:
        # A private copy made by the validation step may be overwritten freely.
        overwrite_c = overwrite_c or _datacopied(c1, c)

    return _matmul_gemm(a1, b1, alpha, beta, c1, code_a, code_b, overwrite_c)


# gemm without the input validation
def _matmul_gemm(a1, b1, alpha=1.0, beta=0.0, c1=None, trans_a=0, trans_b=0, overwrite_c=0):
    """Call the BLAS ``gemm`` matching the dtypes of ``a1`` and ``b1``.

    ``c1`` is only handed to BLAS when ``beta`` is non-zero.
    """
    code_a = _trans_to_blas_code(trans_a)
    code_b = _trans_to_blas_code(trans_b)
    # Local name differs from ``gemm`` so the module-level wrapper is not shadowed.
    blas_gemm, = get_blas_funcs(('gemm',), (a1, b1))

    if beta == 0:
        return blas_gemm(alpha, a1, b1, trans_a=code_a, trans_b=code_b)
    return blas_gemm(alpha, a1, b1, beta=beta, c=c1, trans_a=code_a,
                     trans_b=code_b, overwrite_c=overwrite_c)


# Util functions for cupy gemm
# NOTE(review): these helpers appear to be adapted from ``cupy.cublas``; the
# nesting below was reconstructed on that assumption - confirm.
def _trans_to_cublas_op(trans):
    """Convert 'N'/'T'/'C' (or an existing cuBLAS op code) to a cuBLAS op code."""
    if trans == 'N' or trans == cublas.CUBLAS_OP_N:
        trans = cublas.CUBLAS_OP_N
    elif trans == 'T' or trans == cublas.CUBLAS_OP_T:
        trans = cublas.CUBLAS_OP_T
    elif trans == 'C' or trans == cublas.CUBLAS_OP_C:
        trans = cublas.CUBLAS_OP_C
    else:
        raise TypeError('invalid trans (actual: {})'.format(trans))
    return trans


def _decide_ld_and_trans(a, trans):
    """Pick a leading dimension that lets cuBLAS read ``a`` without a copy.

    Returns ``(ld, trans)``. ``ld`` stays ``None`` when ``trans`` is the
    conjugate-transpose op or ``a`` is not contiguous. A C-contiguous ``a``
    is read as its transpose, so the op code is flipped between N and T.
    """
    ld = None
    if trans in (cublas.CUBLAS_OP_N, cublas.CUBLAS_OP_T):
        if a._f_contiguous:
            ld = a.shape[0]
        elif a._c_contiguous:
            ld = a.shape[1]
            trans = 1 - trans
    return ld, trans


def _change_order_if_necessary(a, lda):
    """Return ``(a, lda)``, copying ``a`` to Fortran order when ``lda`` is unset."""
    if lda is None:
        lda = a.shape[0]
        if not a._f_contiguous:
            a = a.copy(order='F')
    return a, lda


def _get_scalar_ptr(a, dtype):
    """Return ``(scalar, pointer)`` for ``a`` cast to ``dtype``.

    CuPy arrays yield a device pointer, anything else a host pointer.
    """
    if isinstance(a, cp.ndarray):
        if a.dtype != dtype:
            a = cp.array(a, dtype=dtype)
        a_ptr = a.data.ptr
    else:
        if not (isinstance(a, np.ndarray) and a.dtype == dtype):
            a = np.array(a, dtype=dtype)
        a_ptr = a.ctypes.data
    return a, a_ptr
# Util functions for cupy gemm end


def matmul_gemm_device(transa, transb, a, b, out=None, alpha=1.0, beta=0.0):
    """Computes out = alpha * op(a) @ op(b) + beta * out

    op(a) = a if transa is 'N', op(a) = a.T if transa is 'T',
    op(a) = a.T.conj() if transa is 'C'.
    op(b) = b if transb is 'N', op(b) = b.T if transb is 'T',
    op(b) = b.T.conj() if transb is 'C'.

    ``a`` and ``b`` must be 2-D CuPy arrays of the same float or complex
    dtype. When ``out`` is omitted a new Fortran-ordered array is allocated
    and ``beta`` is ignored; otherwise ``out`` is updated in place and
    returned.
    """
    assert a.ndim == b.ndim == 2
    assert a.dtype == b.dtype
    dtype = a.dtype.char
    if dtype == 'f':
        func = cublas.sgemm
    elif dtype == 'd':
        func = cublas.dgemm
    elif dtype == 'F':
        func = cublas.cgemm
    elif dtype == 'D':
        func = cublas.zgemm
    else:
        raise TypeError('invalid dtype')

    transa = _trans_to_cublas_op(transa)
    transb = _trans_to_cublas_op(transb)
    if transa == cublas.CUBLAS_OP_N:
        m, k = a.shape
    else:
        k, m = a.shape
    if transb == cublas.CUBLAS_OP_N:
        n = b.shape[1]
        assert b.shape[0] == k
    else:
        n = b.shape[0]
        assert b.shape[1] == k
    if out is None:
        out = cp.empty((m, n), dtype=dtype, order='F')
        beta = 0.0
    else:
        assert out.ndim == 2
        assert out.shape == (m, n)
        assert out.dtype == dtype

    alpha, alpha_ptr = _get_scalar_ptr(alpha, a.dtype)
    beta, beta_ptr = _get_scalar_ptr(beta, a.dtype)
    handle = device.get_cublas_handle()
    orig_mode = cublas.getPointerMode(handle)
    # Both scalars must live in the same memory space as the pointer mode.
    if isinstance(alpha, cp.ndarray) or isinstance(beta, cp.ndarray):
        if not isinstance(alpha, cp.ndarray):
            alpha = cp.array(alpha)
            alpha_ptr = alpha.data.ptr
        if not isinstance(beta, cp.ndarray):
            beta = cp.array(beta)
            beta_ptr = beta.data.ptr
        cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_DEVICE)
    else:
        cublas.setPointerMode(handle, cublas.CUBLAS_POINTER_MODE_HOST)

    lda, transa = _decide_ld_and_trans(a, transa)
    ldb, transb = _decide_ld_and_trans(b, transb)
    if not (lda is None or ldb is None):
        # Zero-copy paths: both operands are usable as they are.
        if out._f_contiguous:
            try:
                func(handle, transa, transb, m, n, k, alpha_ptr,
                     a.data.ptr, lda, b.data.ptr, ldb, beta_ptr, out.data.ptr,
                     m)
            finally:
                cublas.setPointerMode(handle, orig_mode)
            return out
        elif out._c_contiguous:
            # Computes out.T = alpha * b.T @ a.T + beta * out.T
            try:
                func(handle, 1 - transb, 1 - transa, n, m, k, alpha_ptr,
                     b.data.ptr, ldb, a.data.ptr, lda, beta_ptr, out.data.ptr,
                     n)
            finally:
                cublas.setPointerMode(handle, orig_mode)
            return out

    # General path: operate on Fortran-ordered buffers, copy back if needed.
    a, lda = _change_order_if_necessary(a, lda)
    b, ldb = _change_order_if_necessary(b, ldb)
    c = out
    if not out._f_contiguous:
        c = out.copy(order='F')
    try:
        func(handle, transa, transb, m, n, k, alpha_ptr, a.data.ptr, lda,
             b.data.ptr, ldb, beta_ptr, c.data.ptr, m)
    finally:
        cublas.setPointerMode(handle, orig_mode)
    if not out._f_contiguous:
        _core.elementwise_copy(c, out)
    return out
def syherk(a, c=None, alpha=1.0, beta=0.0, trans=0, lower=False):
    """Symmetric / Hermitian rank-k update on the host or the device.

    Computes ``alpha * op(a) @ op(a)^H + beta * c``, touching only one
    triangle of the result. The array module reported by
    ``_get_module_from_array(a)`` selects the implementation: NumPy goes to
    :func:`matmul_syherk_host`, CuPy goes to :func:`matmul_syherk_device`.

    Parameters
    ----------
    a : 2-D array
        Factor of the update.
    c : 2-D array, optional
        Accumulator.
    alpha, beta : scalar
        Scaling factors of the product and of ``c``.
    trans : {0, 1, 2, 'N', 'T', 'C'}
        Operation applied to ``a``.
    lower : bool
        Update the lower triangle if True, the upper one otherwise.

    Raises
    ------
    ModuleNotFoundError
        If the array module is neither NumPy nor CuPy.
    """
    xp, _ = _get_module_from_array(a)
    if xp is np:
        return matmul_syherk_host(a, c, alpha, beta, trans, lower)
    # Compare by name instead of against ``cp``: ``cp`` is unbound when the
    # optional CuPy import at the top of the module failed.
    # NOTE(review): assumes ``xp`` is the array module itself - confirm
    # against ``_get_module_from_array``.
    if getattr(xp, "__name__", None) == "cupy":
        return matmul_syherk_device(a, trans, c, alpha, beta, lower)
    raise ModuleNotFoundError("Unknown Module")


def matmul_syherk_host(a, c=None, alpha=1.0, beta=1.0, trans=0, lower=False, unit_diagonal=False,
                       overwrite_c=False, check_finite=True, side=0):
    """Compute ``alpha * op(a) @ op(a)^H + beta * c`` with BLAS ``syrk``/``herk``.

    Real input uses ``syrk`` and complex input uses ``herk``. Only the
    triangle selected by ``lower`` is written. ``unit_diagonal`` and ``side``
    are accepted but not used.

    Parameters
    ----------
    a : 2-D array_like
        Factor of the update.
    c : 2-D array_like, optional
        Accumulator; treated as absent when ``None``.
    alpha, beta : scalar
        Scaling factors of the product and of ``c``.
    trans : {0, 1, 2, 'N', 'T', 'C'}
        Operation applied to ``a``.
    lower : bool
        Update the lower triangle if True, the upper one otherwise.
    overwrite_c : bool
        Allow BLAS to reuse the memory of ``c`` for the result.
    check_finite : bool
        Reject inputs containing infs or NaNs.
    """
    a1 = _asarray_validated(a, check_finite=check_finite)
    c1 = None
    if c is not None:
        c1 = _asarray_validated(c, check_finite=check_finite)
        # A private copy made by the validation step may be overwritten freely.
        overwrite_c = overwrite_c or _datacopied(c1, c)

    return _syherk(a1, c1, alpha, beta, trans, lower, overwrite_c)


# syherk without the input validation
def _syherk(a1, c1=None, alpha=1.0, beta=0.0, trans=0, lower=False,
            overwrite_c=False):
    """Call BLAS ``herk`` (complex ``a1``) or ``syrk`` (real ``a1``)."""
    trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans)
    routine = 'herk' if np.iscomplexobj(a1) else 'syrk'

    # ``get_blas_funcs`` reads ``.dtype`` of every array it is given, so a
    # missing accumulator must be left out of the tuple.
    arrays = (a1,) if c1 is None else (a1, c1)
    rank_k = get_blas_funcs(routine, arrays)

    if c1 is None:
        return rank_k(alpha, a1, trans=trans, lower=lower)
    return rank_k(alpha, a1, beta=beta, c=c1, trans=trans, lower=lower,
                  overwrite_c=overwrite_c)


# Util functions for cupy gemm
# NOTE(review): these helpers appear to be adapted from ``cupy.cublas``; the
# nesting below was reconstructed on that assumption - confirm.
def _trans_to_cublas_op(trans):
    """Convert 'N'/'T'/'C' (or an existing cuBLAS op code) to a cuBLAS op code."""
    if trans == 'N' or trans == cublas.CUBLAS_OP_N:
        trans = cublas.CUBLAS_OP_N
    elif trans == 'T' or trans == cublas.CUBLAS_OP_T:
        trans = cublas.CUBLAS_OP_T
    elif trans == 'C' or trans == cublas.CUBLAS_OP_C:
        trans = cublas.CUBLAS_OP_C
    else:
        raise TypeError('invalid trans (actual: {})'.format(trans))
    return trans


def _decide_ld_and_trans(a, trans):
    """Pick a leading dimension that lets cuBLAS read ``a`` without a copy.

    Returns ``(ld, trans)``. ``ld`` stays ``None`` when ``trans`` is the
    conjugate-transpose op or ``a`` is not contiguous. A C-contiguous ``a``
    is read as its transpose, so the op code is flipped between N and T.
    """
    ld = None
    if trans in (cublas.CUBLAS_OP_N, cublas.CUBLAS_OP_T):
        if a._f_contiguous:
            ld = a.shape[0]
        elif a._c_contiguous:
            ld = a.shape[1]
            trans = 1 - trans
    return ld, trans


def _get_scalar_ptr(a, dtype):
    """Return ``(scalar, pointer)`` for ``a`` cast to ``dtype``.

    CuPy arrays yield a device pointer, anything else a host pointer.
    """
    if isinstance(a, cp.ndarray):
        if a.dtype != dtype:
            a = cp.array(a, dtype=dtype)
        a_ptr = a.data.ptr
    else:
        if not (isinstance(a, np.ndarray) and a.dtype == dtype):
            a = np.array(a, dtype=dtype)
        a_ptr = a.ctypes.data
    return a, a_ptr
# Util functions for cupy gemm end


def matmul_syherk_device(a, trans='N', out=None, alpha=1.0, beta=0.0, lower=False):
    """Computes out := alpha * op(a) @ op(a)^H + beta * out on the device.

    op(a) = a if trans is 'N'; otherwise op(a) is the transpose of ``a`` for
    real dtypes and the conjugate transpose for complex dtypes (herk has no
    plain-transpose mode, so 'T' and 'C' are treated alike there).
    ``lower`` specifies whether the lower or the upper triangle of ``out``
    is referenced and updated.

    When ``out`` is omitted a zero-filled Fortran-ordered array is allocated
    and ``beta`` is ignored; otherwise ``out`` is updated in place and
    returned.
    """
    assert a.ndim == 2
    dtype = a.dtype.char
    if dtype == 'f':
        func = cublas.ssyrk
    elif dtype == 'd':
        func = cublas.dsyrk
    elif dtype == 'F':
        func = cublas.cherk
    elif dtype == 'D':
        func = cublas.zherk
    else:
        raise TypeError('invalid dtype')
    hermitian = dtype in 'FD'

    # Track "transposed or not" as a flag; the op code is chosen at the end,
    # because ``1 - op`` arithmetic is meaningless for CUBLAS_OP_C.
    transposed = _trans_to_cublas_op(trans) != cublas.CUBLAS_OP_N
    if transposed:
        k, n = a.shape
    else:
        n, k = a.shape
    if out is None:
        out = cp.zeros((n, n), dtype=dtype, order='F')
        beta = 0.0
    else:
        assert out.ndim == 2
        assert out.shape == (n, n)
        assert out.dtype == dtype

    if lower:
        uplo = cublas.CUBLAS_FILL_MODE_LOWER
    else:
        uplo = cublas.CUBLAS_FILL_MODE_UPPER

    # herk takes real alpha/beta; syrk takes scalars of the matrix type.
    scalar_dtype = np.dtype(dtype.lower()) if hermitian else a.dtype
    alpha, alpha_ptr = _get_scalar_ptr(alpha, scalar_dtype)
    beta, beta_ptr = _get_scalar_ptr(beta, scalar_dtype)
    # Both scalars must live in the same memory space as the pointer mode.
    if isinstance(alpha, cp.ndarray) or isinstance(beta, cp.ndarray):
        if not isinstance(alpha, cp.ndarray):
            alpha = cp.array(alpha)
            alpha_ptr = alpha.data.ptr
        if not isinstance(beta, cp.ndarray):
            beta = cp.array(beta)
            beta_ptr = beta.data.ptr
        pointer_mode = cublas.CUBLAS_POINTER_MODE_DEVICE
    else:
        pointer_mode = cublas.CUBLAS_POINTER_MODE_HOST

    if out._c_contiguous:
        # Read column-major, a C-ordered buffer is the transpose of the
        # array. Work on out^T directly: keep ``a`` C-ordered too, flip the
        # operation and the referenced triangle. No copy of ``out`` needed.
        if not a._c_contiguous:
            a = a.copy(order='C')
        lda = a.shape[1]
        transposed = not transposed
        c = out
        fill = 1 - uplo
    else:
        if not a._f_contiguous:
            a = a.copy(order='F')
        lda = a.shape[0]
        c = out if out._f_contiguous else out.copy(order='F')
        fill = uplo

    if not transposed:
        op = cublas.CUBLAS_OP_N
    elif hermitian:
        op = cublas.CUBLAS_OP_C
    else:
        op = cublas.CUBLAS_OP_T

    handle = device.get_cublas_handle()
    orig_mode = cublas.getPointerMode(handle)
    try:
        cublas.setPointerMode(handle, pointer_mode)
        # ``c`` (not ``out``) is the buffer cuBLAS writes; it is contiguous
        # in the order assumed above, so its leading dimension is ``n``.
        func(handle, fill, op, n, k,
             alpha_ptr, a.data.ptr, lda,
             beta_ptr, c.data.ptr, n)
    finally:
        cublas.setPointerMode(handle, orig_mode)
    if c is not out:
        out[...] = c
    return out
def trsm(a, b, trans=0, lower=False, unit_diagonal=False,
         overwrite_b=False, check_finite=False, side=0):
    """Triangular solve on the host or the device.

    Accepts the parameters of the SciPy / CuPy ``solve_triangular``
    functions plus ``side``: 0 solves ``op(a) x = b`` (left-hand side),
    1 solves ``x op(a) = b`` (right-hand side).

    The array module reported by ``_get_module_from_array(a)`` selects the
    implementation: NumPy goes to ``solve_triangular_host``, CuPy goes to
    :func:`solve_triangular_device`.

    Raises
    ------
    ModuleNotFoundError
        If the array module is neither NumPy nor CuPy.
    """
    xp, _ = _get_module_from_array(a)
    if xp is np:
        return solve_triangular_host(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side)
    # Compare by name instead of against ``cp``: ``cp`` is unbound when the
    # optional CuPy import at the top of the module failed.
    # NOTE(review): assumes ``xp`` is the array module itself - confirm
    # against ``_get_module_from_array``.
    if getattr(xp, "__name__", None) == "cupy":
        return solve_triangular_device(a, b, trans, lower, unit_diagonal, overwrite_b, check_finite, side)
    raise ModuleNotFoundError("Unknown Module")


def solve_triangular_device(a, b, trans=0, lower=False, unit_diagonal=False,
                            overwrite_b=False, check_finite=False, side=0):
    """Solve ``op(a) x = b`` or ``x op(a) = b`` for x, with ``a`` triangular.

    Args:
        a (cupy.ndarray): The triangular matrix, or a batch of them in the
            leading dimensions. Its order is ``M`` for a left-side solve
            and ``N`` for a right-side solve.
        b (cupy.ndarray): The matrix with dimension ``(M,)`` or
            ``(M, N)``.
        lower (bool): Use only data contained in the lower triangle of ``a``.
            Default is to use upper triangle.
        trans (0, 1, 2, 'N', 'T' or 'C'): Type of system to solve:

            - *'0'* or *'N'* -- :math:`a x = b`
            - *'1'* or *'T'* -- :math:`a^T x = b`
            - *'2'* or *'C'* -- :math:`a^H x = b`

        unit_diagonal (bool): If ``True``, diagonal elements of ``a`` are
            assumed to be 1 and will not be referenced.
        overwrite_b (bool): Allow overwriting data in b (may enhance
            performance)
        check_finite (bool): Whether to check that the input matrices contain
            only finite numbers. Disabling may give a performance gain, but may
            result in problems (crashes, non-termination) if the inputs do
            contain infinities or NaNs.
        side (int): 0 solves with ``a`` on the left, any truthy value with
            ``a`` on the right.

    Returns:
        cupy.ndarray:
            The solution, with the shape of ``b``.

    Raises:
        ValueError: If ``a`` is not square (or a batch of square matrices),
            the shapes of ``a`` and ``b`` do not match, or ``check_finite``
            finds infs/NaNs.

    .. seealso:: :func:`scipy.linalg.solve_triangular`
    """
    # Local import: ``math`` is not among the names the module-level import
    # block of this file provides.
    import math

    _util._assert_cupy_array(a, b)
    right = bool(side)

    if a.ndim == 2:
        if a.shape[0] != a.shape[1]:
            raise ValueError('expected square matrix')
        # On the right, ``a`` multiplies the columns of ``b``; a 1-D ``b``
        # is handled as a single column below.
        if right:
            b_extent = 1 if b.ndim == 1 else b.shape[-1]
        else:
            b_extent = len(b)
        if len(a) != b_extent:
            raise ValueError('incompatible dimensions')
        batch_count = 0
    elif a.ndim > 2:
        if a.shape[-1] != a.shape[-2]:
            raise ValueError('expected a batch of square matrices')
        if a.shape[:-2] != b.shape[:a.ndim - 2]:
            raise ValueError('incompatible batch count')
        if b.ndim < a.ndim - 1:
            raise ValueError('incompatible dimensions')
        if right:
            b_extent = b.shape[-1] if b.ndim == a.ndim else 1
        else:
            b_extent = b.shape[a.ndim - 2]
        if a.shape[-2] != b_extent:
            raise ValueError('incompatible dimensions')
        batch_count = math.prod(a.shape[:-2])
    else:
        raise ValueError(
            'expected one square matrix or a batch of square matrices')

    # float32/float64 are kept; everything else is promoted with float32
    if a.dtype.char in 'fd':
        dtype = a.dtype
    else:
        dtype = np.promote_types(a.dtype.char, 'f')

    if check_finite:
        if a.dtype.kind == 'f' and not cp.isfinite(a).all():
            raise ValueError(
                'array must not contain infs or NaNs')
        if b.dtype.kind == 'f' and not cp.isfinite(b).all():
            raise ValueError(
                'array must not contain infs or NaNs')

    if batch_count:
        m, n = b.shape[-2:] if b.ndim == a.ndim else (b.shape[-1], 1)
        # Order (and leading dimension) of each triangular matrix.
        a_order = n if right else m

        a_new_shape = (batch_count, a_order, a_order)
        b_shape = b.shape
        b_data_ptr = b.data.ptr
        # trsm receives Fortran array, but we want zero copy
        if trans == 'N' or trans == cublas.CUBLAS_OP_N:
            # normal Fortran upper == transpose C lower
            trans = cublas.CUBLAS_OP_T
            lower = not lower
            a = cp.ascontiguousarray(a.reshape(*a_new_shape), dtype=dtype)
        elif trans == 'T' or trans == cublas.CUBLAS_OP_T:
            # transpose Fortran upper == normal C lower
            trans = cublas.CUBLAS_OP_N
            lower = not lower
            a = cp.ascontiguousarray(a.reshape(*a_new_shape), dtype=dtype)
        elif trans == 'C' or trans == cublas.CUBLAS_OP_C:
            if dtype == 'f' or dtype == 'd':
                # real numbers
                # Hermitian Fortran upper == transpose Fortran upper
                #                         == normal C lower
                trans = cublas.CUBLAS_OP_N
                lower = not lower
                a = cp.ascontiguousarray(a.reshape(*a_new_shape),
                                         dtype=dtype)
            else:
                # complex numbers
                trans = cublas.CUBLAS_OP_C
                a = cp.ascontiguousarray(
                    a.reshape(*a_new_shape).transpose(0, 2, 1), dtype=dtype)
        else:  # know nothing about `trans`, just convert C to Fortran
            a = cp.ascontiguousarray(
                a.reshape(*a_new_shape).transpose(0, 2, 1), dtype=dtype)
        b = cp.ascontiguousarray(
            b.reshape(batch_count, m, n).transpose(0, 2, 1), dtype=dtype)
        if b.data.ptr == b_data_ptr and not overwrite_b:
            b = b.copy()

        # Device arrays of per-matrix pointers for the batched routine.
        start = a.data.ptr
        step = a_order * a_order * a.itemsize
        stop = start + step * batch_count
        a_array = cp.arange(start, stop, step, dtype=cp.uintp)

        start = b.data.ptr
        step = m * n * b.itemsize
        stop = start + step * batch_count
        b_array = cp.arange(start, stop, step, dtype=cp.uintp)
    else:
        a = cp.array(a, dtype=dtype, order='F', copy=None)
        b = cp.array(b, dtype=dtype, order='F',
                     copy=(None if overwrite_b else True))

        m, n = (b.size, 1) if b.ndim == 1 else b.shape
        # Order (and leading dimension) of the triangular matrix.
        a_order = n if right else m

        if trans == 'N':
            trans = cublas.CUBLAS_OP_N
        elif trans == 'T':
            trans = cublas.CUBLAS_OP_T
        elif trans == 'C':
            trans = cublas.CUBLAS_OP_C

    cublas_handle = device.get_cublas_handle()
    one = np.array(1, dtype=dtype)

    if lower:
        uplo = cublas.CUBLAS_FILL_MODE_LOWER
    else:
        uplo = cublas.CUBLAS_FILL_MODE_UPPER

    if unit_diagonal:
        diag = cublas.CUBLAS_DIAG_UNIT
    else:
        diag = cublas.CUBLAS_DIAG_NON_UNIT

    if right:
        cublas_side = cublas.CUBLAS_SIDE_RIGHT
    else:
        cublas_side = cublas.CUBLAS_SIDE_LEFT

    # Local names differ from ``trsm`` so the module-level wrapper is not
    # shadowed.
    if batch_count:
        if dtype == 'f':
            trsm_batched = cublas.strsmBatched
        elif dtype == 'd':
            trsm_batched = cublas.dtrsmBatched
        elif dtype == 'F':
            trsm_batched = cublas.ctrsmBatched
        else:  # dtype == 'D'
            trsm_batched = cublas.ztrsmBatched
        trsm_batched(
            cublas_handle, cublas_side, uplo,
            trans, diag,
            m, n, one.ctypes.data, a_array.data.ptr, a_order,
            b_array.data.ptr, m, batch_count)
        return b.transpose(0, 2, 1).reshape(b_shape)
    else:
        if dtype == 'f':
            trsm_single = cublas.strsm
        elif dtype == 'd':
            trsm_single = cublas.dtrsm
        elif dtype == 'F':
            trsm_single = cublas.ctrsm
        else:  # dtype == 'D'
            trsm_single = cublas.ztrsm
        trsm_single(
            cublas_handle, cublas_side, uplo,
            trans, diag,
            m, n, one.ctypes.data, a.data.ptr, a_order, b.data.ptr, m)
        return b
+ Disabling may give a performance gain, but may result in problems + (crashes, non-termination) if the inputs do contain infinities or NaNs. + + Returns + ------- + x : (M,) or (M, N) ndarray + Solution to the system ``a x = b``. Shape of return matches `b`. + + Raises + ------ + LinAlgError + If `a` is singular + + Notes + ----- + .. versionadded:: 0.9.0 + + Examples + -------- + Solve the lower triangular system a x = b, where:: + + [3 0 0 0] [4] + a = [2 1 0 0] b = [2] + [1 0 1 0] [4] + [1 1 1 1] [2] + + >>> import numpy as np + >>> from scipy.linalg import solve_triangular + >>> a = np.array([[3, 0, 0, 0], [2, 1, 0, 0], [1, 0, 1, 0], [1, 1, 1, 1]]) + >>> b = np.array([4, 2, 4, 2]) + >>> x = solve_triangular(a, b, lower=True) + >>> x + array([ 1.33333333, -0.66666667, 2.66666667, -1.33333333]) + >>> a.dot(x) # Check the result + array([ 4., 2., 4., 2.]) + + """ + + a1 = _asarray_validated(a, check_finite=check_finite) + b1 = _asarray_validated(b, check_finite=check_finite) + + if len(a1.shape) != 2 or a1.shape[0] != a1.shape[1]: + raise ValueError('expected square matrix') + + if a1.shape[0] != b1.shape[0]: + raise ValueError(f'shapes of a {a1.shape} and b {b1.shape} are incompatible') + + # accommodate empty arrays + if b1.size == 0: + dt_nonempty = solve_triangular_host( + np.eye(2, dtype=a1.dtype), np.ones(2, dtype=b1.dtype) + ).dtype + return np.empty_like(b1, dtype=dt_nonempty) + + overwrite_b = overwrite_b or _datacopied(b1, b) + + x = _solve_triangular(a1, b1, trans, lower, unit_diagonal, overwrite_b, side) + return x + + +# solve_triangular without the input validation +def _solve_triangular(a1, b1, trans=0, lower=False, unit_diagonal=False, + overwrite_b=False, side=0): + + trans = {'N': 0, 'T': 1, 'C': 2}.get(trans, trans) + trsm, = get_blas_funcs(('trsm',), (a1, b1)) + + alpha = 1.0 + + if a1.flags.f_contiguous or trans == 2: + x = trsm(alpha, a1, b1, overwrite_b=overwrite_b, lower=lower, + trans_a=trans, diag=unit_diagonal, side=side) + else: + # 
transposed system is solved since trsm expects Fortran ordering + x = trsm(alpha, a1.T, b1, overwrite_b=overwrite_b, lower=not lower, + trans_a=not trans, diag=unit_diagonal, side=side) + + return x \ No newline at end of file diff --git a/src/serinv/block_primitive/trymod/__init__.py b/src/serinv/block_primitive/trymod/__init__.py new file mode 100644 index 00000000..d167fa53 --- /dev/null +++ b/src/serinv/block_primitive/trymod/__init__.py @@ -0,0 +1,6 @@ +def foo(): + return 0 + +__all__ = [ + "foo", +] \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/__init__.pxd b/src/serinv/cupyfix_backends/__init__.pxd new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/cupyfix_backends/__init__.py b/src/serinv/cupyfix_backends/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/serinv/cupyfix_backends/__init__.py @@ -0,0 +1 @@ + diff --git a/src/serinv/cupyfix_backends/cuda/__init__.pxd b/src/serinv/cupyfix_backends/cuda/__init__.pxd new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/cupyfix_backends/cuda/__init__.py b/src/serinv/cupyfix_backends/cuda/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/cupyfix_backends/cuda/cupy_cublas.h b/src/serinv/cupyfix_backends/cuda/cupy_cublas.h new file mode 100644 index 00000000..c3d4874b --- /dev/null +++ b/src/serinv/cupyfix_backends/cuda/cupy_cublas.h @@ -0,0 +1,24 @@ +#ifndef INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H +#define INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H + +#include +#include + +#if CUDA_VERSION >= 11000 + +#define cublasGemmEx_v11 cublasGemmEx +#define cublasGemmStridedBatchedEx_v11 cublasGemmStridedBatchedEx + +#else + +typedef enum{} cublasComputeType_t; +cublasStatus_t cublasGemmEx_v11(...) { + return CUBLAS_STATUS_NOT_SUPPORTED; +} +cublasStatus_t cublasGemmStridedBatchedEx_v11(...) 
{ + return CUBLAS_STATUS_NOT_SUPPORTED; +} + +#endif // if CUDA_VERSION >= 11000 + +#endif // #ifndef INCLUDE_GUARD_CUDA_CUPY_CUBLAS_H \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/cuda/libs/__init.pxd b/src/serinv/cupyfix_backends/cuda/libs/__init.pxd new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/cupyfix_backends/cuda/libs/__init__.py b/src/serinv/cupyfix_backends/cuda/libs/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/src/serinv/cupyfix_backends/cuda/libs/cublas.pxd b/src/serinv/cupyfix_backends/cuda/libs/cublas.pxd new file mode 100644 index 00000000..213630de --- /dev/null +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pxd @@ -0,0 +1,34 @@ +"""Thin wrapper of CUBLAS.""" +from libc.stdint cimport intptr_t + + +############################################################################### +# Types +############################################################################### + +cdef extern from *: + ctypedef void* cuComplexPtr 'cuComplex*' + ctypedef void* cuDoubleComplexPtr 'cuDoubleComplex*' + + +cdef extern from *: + ctypedef void* Handle 'cublasHandle_t' + + ctypedef int DiagType 'cublasDiagType_t' + ctypedef int FillMode 'cublasFillMode_t' + ctypedef int Operation 'cublasOperation_t' + ctypedef int PointerMode 'cublasPointerMode_t' + ctypedef int SideMode 'cublasSideMode_t' + ctypedef int GemmAlgo 'cublasGemmAlgo_t' + ctypedef int Math 'cublasMath_t' + ctypedef int ComputeType 'cublasComputeType_t' + +############################################################################### +# BLAS Level 3 +############################################################################### + + +cpdef cherk(intptr_t handle, int uplo, int trans, int n, int k, + size_t alpha, size_t A, int lda, size_t beta, size_t C, int ldc) +cpdef zherk(intptr_t handle, int uplo, int trans, int n, int k, + size_t alpha, size_t A, int lda, size_t beta, size_t C, int ldc) diff --git 
a/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx new file mode 100644 index 00000000..c2dcf3cb --- /dev/null +++ b/src/serinv/cupyfix_backends/cuda/libs/cublas.pyx @@ -0,0 +1,136 @@ +# distutils: language = c++ + +"""Thin wrapper of CUBLAS.""" + +cimport cython # NOQA + +from libc.stdint cimport intptr_t + +from cupy_backends.cuda.api import runtime +from cupy_backends.cuda import stream as stream_module + +############################################################################### +# Extern +############################################################################### + +cdef extern from '../../cupy_complex.h': + ctypedef struct cuComplex 'cuComplex': + float x, y + + ctypedef struct cuDoubleComplex 'cuDoubleComplex': + double x, y + +cdef extern from '../../cupy_blas.h' nogil: + ctypedef void* Stream 'cudaStream_t' + ctypedef int DataType 'cudaDataType' + + # Stream + int cublasSetStream(Handle handle, Stream streamId) + int cublasGetStream(Handle handle, Stream* streamId) + + # BLAS Level 3 + int cublasCherk( + Handle handle, FillMode uplo, Operation trans, int n, int k, + cuComplex* alpha, cuComplex* A, int lda, + cuComplex* beta, cuComplex* C, int ldc) + int cublasZherk( + Handle handle, FillMode uplo, Operation trans, int n, int k, + cuDoubleComplex* alpha, cuDoubleComplex* A, int lda, + cuDoubleComplex* beta, cuDoubleComplex* C, int ldc) + +############################################################################### +# Error handling +############################################################################### + +cdef dict STATUS = { + 0: 'CUBLAS_STATUS_SUCCESS', + 1: 'CUBLAS_STATUS_NOT_INITIALIZED', + 3: 'CUBLAS_STATUS_ALLOC_FAILED', + 7: 'CUBLAS_STATUS_INVALID_VALUE', + 8: 'CUBLAS_STATUS_ARCH_MISMATCH', + 11: 'CUBLAS_STATUS_MAPPING_ERROR', + 13: 'CUBLAS_STATUS_EXECUTION_FAILED', + 14: 'CUBLAS_STATUS_INTERNAL_ERROR', + 15: 'CUBLAS_STATUS_NOT_SUPPORTED', + 16: 'CUBLAS_STATUS_LICENSE_ERROR', +} 
+ + +cdef dict HIP_STATUS = { + 0: 'HIPBLAS_STATUS_SUCCESS', + 1: 'HIPBLAS_STATUS_NOT_INITIALIZED', + 2: 'HIPBLAS_STATUS_ALLOC_FAILED', + 3: 'HIPBLAS_STATUS_INVALID_VALUE', + 4: 'HIPBLAS_STATUS_MAPPING_ERROR', + 5: 'HIPBLAS_STATUS_EXECUTION_FAILED', + 6: 'HIPBLAS_STATUS_INTERNAL_ERROR', + 7: 'HIPBLAS_STATUS_NOT_SUPPORTED', + 8: 'HIPBLAS_STATUS_ARCH_MISMATCH', + 9: 'HIPBLAS_STATUS_HANDLE_IS_NULLPTR', +} + + +class CUBLASError(RuntimeError): + + def __init__(self, status): + self.status = status + cdef str err + if runtime._is_hip_environment: + err = HIP_STATUS[status] + else: + err = STATUS[status] + super(CUBLASError, self).__init__(err) + + def __reduce__(self): + return (type(self), (self.status,)) + + +@cython.profile(False) +cpdef inline check_status(int status): + if status != 0: + raise CUBLASError(status) + + +cpdef setStream(intptr_t handle, size_t stream): + # TODO(leofang): It seems most of cuBLAS APIs support stream capture (as of + # CUDA 11.5) under certain conditions, see + # https://docs.nvidia.com/cuda/cublas/index.html#CUDA-graphs + # Before we come up with a robust strategy to test the support conditions, + # we disable this functionality. 
+ if not runtime._is_hip_environment and runtime.streamIsCapturing(stream): + raise NotImplementedError( + 'calling cuBLAS API during stream capture is currently ' + 'unsupported') + + with nogil: + status = cublasSetStream(handle, stream) + check_status(status) + +cdef _setStream(intptr_t handle): + """Set current stream""" + setStream(handle, stream_module.get_current_stream_ptr()) + +############################################################################### +# BLAS Level 3 +############################################################################### + +cpdef cherk(intptr_t handle, int uplo, int trans, int n, int k, + size_t alpha, size_t A, int lda, size_t beta, size_t C, int ldc): + _setStream(handle) + with nogil: + status = cublasCherk( + handle, uplo, trans, n, k, + alpha, A, lda, + beta, C, ldc) + check_status(status) + + +cpdef zherk(intptr_t handle, int uplo, int trans, int n, int k, + size_t alpha, size_t A, int lda, size_t beta, size_t C, int ldc): + _setStream(handle) + with nogil: + status = cublasZherk( + handle, uplo, trans, n, k, + alpha, A, lda, + beta, C, ldc) + check_status(status) \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/cupy_blas.h b/src/serinv/cupyfix_backends/cupy_blas.h new file mode 100644 index 00000000..0f9006df --- /dev/null +++ b/src/serinv/cupyfix_backends/cupy_blas.h @@ -0,0 +1,7 @@ +#ifndef INCLUDE_GUARD_CUPY_CUBLAS_H +#define INCLUDE_GUARD_CUPY_CUBLAS_H + +#include "cuda/cupy_cublas.h" + + +#endif // #ifndef INCLUDE_GUARD_CUPY_CUBLAS_H \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/cupy_complex.h b/src/serinv/cupyfix_backends/cupy_complex.h new file mode 100644 index 00000000..09589cde --- /dev/null +++ b/src/serinv/cupyfix_backends/cupy_complex.h @@ -0,0 +1,9 @@ +#ifndef INCLUDE_GUARD_CUPY_COMPLEX_H +#define INCLUDE_GUARD_CUPY_COMPLEX_H + + + +#include + + +#endif // #ifndef INCLUDE_GUARD_CUPY_COMPLEX_H \ No newline at end of file diff --git 
a/src/serinv/cupyfix_backends/hip/cupy_cuComplex.h b/src/serinv/cupyfix_backends/hip/cupy_cuComplex.h new file mode 100644 index 00000000..dfb6006d --- /dev/null +++ b/src/serinv/cupyfix_backends/hip/cupy_cuComplex.h @@ -0,0 +1,20 @@ +#ifndef INCLUDE_GUARD_HIP_CUPY_COMPLEX_H +#define INCLUDE_GUARD_HIP_CUPY_COMPLEX_H + +extern "C" { + +/////////////////////////////////////////////////////////////////////////////// +// cuComplex.h +/////////////////////////////////////////////////////////////////////////////// + +struct cuComplex{ + float x, y; +}; + +struct cuDoubleComplex{ + double x, y; +}; + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_HIP_CUPY_COMPLEX_H \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/hip/cupy_hip_common.h b/src/serinv/cupyfix_backends/hip/cupy_hip_common.h new file mode 100644 index 00000000..d1732ddc --- /dev/null +++ b/src/serinv/cupyfix_backends/hip/cupy_hip_common.h @@ -0,0 +1,161 @@ +#ifndef INCLUDE_GUARD_HIP_CUPY_COMMON_H +#define INCLUDE_GUARD_HIP_CUPY_COMMON_H + +#include +#include +#include + +#define CUDA_VERSION 0 + +extern "C" { + +/////////////////////////////////////////////////////////////////////////////// +// cuda.h +/////////////////////////////////////////////////////////////////////////////// + +typedef int CUdevice; +typedef hipError_t CUresult; +// Conditionally define CUDA_SUCCESS only if it's not defined +#ifndef CUDA_SUCCESS +const CUresult CUDA_SUCCESS = static_cast(0); +#endif +enum CUjit_option {}; +enum CUjitInputType {}; +enum CUarray_format {}; +enum CUaddress_mode {}; +enum CUfilter_mode {}; + + +typedef hipDeviceptr_t CUdeviceptr; +struct CUlinkState_st; + + +typedef hipCtx_t CUcontext; +typedef hipEvent_t CUevent; +typedef hipEvent_t cudaEvent_t; +typedef hipFunction_t CUfunction; +typedef hipFunction_attribute CUfunction_attribute; +typedef hipModule_t CUmodule; +typedef hipStream_t CUstream; +typedef hipStream_t cudaStream_t; +#if HIP_VERSION >= 40300000 +typedef hipGraph_t 
cudaGraph_t; +typedef hipGraphNode_t cudaGraphNode_t; +typedef hipGraphExec_t cudaGraphExec_t; +#else +typedef void* cudaGraph_t; +typedef void* cudaGraphNode_t; +typedef void* cudaGraphExec_t; +#endif +typedef struct CUlinkState_st* CUlinkState; +typedef struct CUarray_st* CUarray; +struct CUDA_ARRAY_DESCRIPTOR { + CUarray_format Format; + size_t Height; + unsigned int NumChannels; + size_t Width; +}; + + +/////////////////////////////////////////////////////////////////////////////// +// cuda_runtime.h +/////////////////////////////////////////////////////////////////////////////// + +enum { + cudaDevAttrComputeCapabilityMajor + = hipDeviceAttributeComputeCapabilityMajor, + cudaDevAttrComputeCapabilityMinor + = hipDeviceAttributeComputeCapabilityMinor, +}; + +typedef hipError_t cudaError_t; +const CUresult cudaSuccess = static_cast(0); +const CUresult cudaErrorInvalidValue = hipErrorInvalidValue; +const CUresult cudaErrorMemoryAllocation = hipErrorMemoryAllocation; +const CUresult cudaErrorInvalidResourceHandle = hipErrorInvalidResourceHandle; +const CUresult cudaErrorContextIsDestroyed = hipErrorUnknown; // no counterpart in HIP +const CUresult cudaErrorPeerAccessAlreadyEnabled = hipErrorPeerAccessAlreadyEnabled; +typedef enum {} cudaDataType; +typedef hipDeviceAttribute_t cudaDeviceAttr; +typedef hipLimit_t cudaLimit; +typedef hipMemoryAdvise cudaMemoryAdvise; +typedef hipMemcpyKind cudaMemcpyKind; +typedef hipDeviceProp_t cudaDeviceProp; +typedef void* cudaMemPool_t; +enum cudaMemPoolAttr {}; + +typedef hipStreamCallback_t cudaStreamCallback_t; +typedef void (*cudaHostFn_t)(void* userData); +typedef hipPointerAttribute_t cudaPointerAttributes; + +typedef hipChannelFormatKind cudaChannelFormatKind; +typedef hipTextureObject_t cudaTextureObject_t; +typedef hipSurfaceObject_t cudaSurfaceObject_t; +typedef hipResourceType cudaResourceType; +typedef hipTextureAddressMode cudaTextureAddressMode; +typedef hipTextureFilterMode cudaTextureFilterMode; +typedef 
hipTextureReadMode cudaTextureReadMode; +typedef hipResourceViewDesc cudaResourceViewDesc; +typedef hipArray_t cudaArray_t; +typedef hipExtent cudaExtent; +typedef hipPos cudaPos; +typedef hipPitchedPtr cudaPitchedPtr; +typedef hipMipmappedArray_t cudaMipmappedArray_t; +typedef hipMemcpy3DParms cudaMemcpy3DParms; +typedef hipChannelFormatDesc cudaChannelFormatDesc; +typedef hipResourceDesc cudaResourceDesc; +typedef hipTextureDesc cudaTextureDesc; + +// IPC operations +typedef hipIpcMemHandle_st cudaIpcMemHandle_t; +typedef hipIpcEventHandle_st cudaIpcEventHandle_t; + + +/////////////////////////////////////////////////////////////////////////////// +// blas & lapack (hipBLAS/rocBLAS & rocSOLVER) +/////////////////////////////////////////////////////////////////////////////// + +/* As of ROCm 3.5.0 (this may have started earlier) many rocSOLVER helper functions + * are deprecated and using their counterparts from rocBLAS is recommended. In + * particular, rocSOLVER simply uses rocBLAS's handle for its API calls. This means + * they are much more integrated than cuBLAS and cuSOLVER do, so it is better to + * put all of the relevant function in one place. 
+ */ + +// TODO(leofang): investigate if we should just remove the hipBLAS layer and use +// rocBLAS directly, since we need to expose its handle anyway + + +typedef hipblasHandle_t cublasHandle_t; + +typedef hipblasDiagType_t cublasDiagType_t; +typedef hipblasFillMode_t cublasFillMode_t; +typedef hipblasOperation_t cublasOperation_t; +typedef hipblasPointerMode_t cublasPointerMode_t; +typedef hipblasSideMode_t cublasSideMode_t; +typedef enum {} cublasGemmAlgo_t; +typedef enum {} cublasMath_t; +typedef int cudaDataType_t; +typedef hipblasStatus_t cublasStatus_t; + +// TODO(leofang): as of ROCm 3.5.0 this does not exist yet +typedef enum {} cublasComputeType_t; + +typedef rocblas_status cusolverStatus_t; +typedef rocblas_handle cusolverDnHandle_t; + + +/////////////////////////////////////////////////////////////////////////////// +// library_types.h +// (needed for supporting cusolver) +/////////////////////////////////////////////////////////////////////////////// + +typedef enum libraryPropertyType_t { + MAJOR_VERSION, + MINOR_VERSION, + PATCH_LEVEL +} libraryPropertyType; + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_HIP_CUPY_COMMON_H \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/hip/cupy_hipblas.h b/src/serinv/cupyfix_backends/hip/cupy_hipblas.h new file mode 100644 index 00000000..f48366c2 --- /dev/null +++ b/src/serinv/cupyfix_backends/hip/cupy_hipblas.h @@ -0,0 +1,89 @@ +#ifndef INCLUDE_GUARD_HIP_CUPY_HIPBLAS_H +#define INCLUDE_GUARD_HIP_CUPY_HIPBLAS_H + +#include "cupy_hip_common.h" +#include +#include // for HIP_VERSION +#include // for gcc 10 + + +extern "C" { + +/////////////////////////////////////////////////////////////////////////////// +// blas & lapack (hipBLAS/rocBLAS & rocSOLVER) +/////////////////////////////////////////////////////////////////////////////// + +/* As of ROCm 3.5.0 (this may have started earlier) many rocSOLVER helper functions + * are deprecated and using their counterparts from rocBLAS is 
recommended. In + * particular, rocSOLVER simply uses rocBLAS's handle for its API calls. This means + * they are much more integrated than cuBLAS and cuSOLVER are, so it is better to + * put all of the relevant function in one place. + */ + +// TODO(leofang): investigate if we should just remove the hipBLAS layer and use +// rocBLAS directly, since we need to expose its handle anyway + + +/* ---------- helpers ---------- */ +static hipblasOperation_t convert_hipblasOperation_t(cublasOperation_t op) { + return static_cast(static_cast(op) + 111); +} + +static hipblasFillMode_t convert_hipblasFillMode_t(cublasFillMode_t mode) { + switch(static_cast(mode)) { + case 0 /* CUBLAS_FILL_MODE_LOWER */: return HIPBLAS_FILL_MODE_LOWER; + case 1 /* CUBLAS_FILL_MODE_UPPER */: return HIPBLAS_FILL_MODE_UPPER; + default: throw std::runtime_error("unrecognized mode"); + } +} + +static hipblasDiagType_t convert_hipblasDiagType_t(cublasDiagType_t type) { + return static_cast(static_cast(type) + 131); +} + +static hipblasSideMode_t convert_hipblasSideMode_t(cublasSideMode_t mode) { + return static_cast(static_cast(mode) + 141); +} + +static hipblasDatatype_t convert_hipblasDatatype_t(cudaDataType_t type) { + switch(static_cast(type)) { + case 0 /* CUDA_R_32F */: return HIPBLAS_R_32F; + case 1 /* CUDA_R_64F */: return HIPBLAS_R_64F; + case 2 /* CUDA_R_16F */: return HIPBLAS_R_16F; + case 3 /* CUDA_R_8I */ : return HIPBLAS_R_8I; + case 4 /* CUDA_C_32F */: return HIPBLAS_C_32F; + case 5 /* CUDA_C_64F */: return HIPBLAS_C_64F; + case 6 /* CUDA_C_16F */: return HIPBLAS_C_16F; + case 7 /* CUDA_C_8I */ : return HIPBLAS_C_8I; + case 8 /* CUDA_R_8U */ : return HIPBLAS_R_8U; + case 9 /* CUDA_C_8U */ : return HIPBLAS_C_8U; + default: throw std::runtime_error("unrecognized type"); + } +} + +// BLAS Level 3 +cublasStatus_t cublasCherk(cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, int n, int k, + const cuComplex* alpha, const cuComplex* A,int lda, + const cuComplex* beta, 
cuComplex* C, int ldc) +{ + return hipblasCherk(handle, convert_hipblasFillMode_t(uplo), convert_hipblasOperation_t(trans), n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(beta), + reinterpret_cast(C), ldc); +} + +cublasStatus_t cublasZherk(cublasHandle_t handle, cublasFillMode_t uplo, cublasOperation_t trans, int n, int k, + const cuDoubleComplex* alpha, const cuDoubleComplex* A, int lda, + const cuDoubleComplex* beta, cuDoubleComplex* C, int ldc) +{ + return hipblasZherk(handle, convert_hipblasFillMode_t(uplo), convert_hipblasOperation_t(trans), n, k, + reinterpret_cast(alpha), + reinterpret_cast(A), lda, + reinterpret_cast(beta), + reinterpret_cast(C), ldc); +} + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_HIP_CUPY_HIPBLAS_H \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/stub/cupy_cuComplex.h b/src/serinv/cupyfix_backends/stub/cupy_cuComplex.h new file mode 100644 index 00000000..4e67db36 --- /dev/null +++ b/src/serinv/cupyfix_backends/stub/cupy_cuComplex.h @@ -0,0 +1,22 @@ +// This file is a stub header file of cuda for Read the Docs. + +#ifndef INCLUDE_GUARD_STUB_CUPY_COMPLEX_H +#define INCLUDE_GUARD_STUB_CUPY_COMPLEX_H + +extern "C" { + +/////////////////////////////////////////////////////////////////////////////// +// cuComplex.h +/////////////////////////////////////////////////////////////////////////////// + +struct cuComplex{ + float x, y; +}; + +struct cuDoubleComplex{ + double x, y; +}; + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_STUB_CUPY_COMPLEX_H \ No newline at end of file diff --git a/src/serinv/cupyfix_backends/stub/cupy_cublas.h b/src/serinv/cupyfix_backends/stub/cupy_cublas.h new file mode 100644 index 00000000..d1831837 --- /dev/null +++ b/src/serinv/cupyfix_backends/stub/cupy_cublas.h @@ -0,0 +1,21 @@ +// This file is a stub header file of cuda for Read the Docs. 
+ +#ifndef INCLUDE_GUARD_STUB_CUPY_CUBLAS_H +#define INCLUDE_GUARD_STUB_CUPY_CUBLAS_H + +#include "cupy_cuda_common.h" + +extern "C" { + +cublasStatus_t cublasCsyrk(...) { + return CUBLAS_STATUS_SUCCESS; +} + +cublasStatus_t cublasZsyrk(...) { + return CUBLAS_STATUS_SUCCESS; +} + + +} // extern "C" + +#endif // #ifndef INCLUDE_GUARD_STUB_CUPY_CUBLAS_H \ No newline at end of file diff --git a/src/serinv/utils/__init__.py b/src/serinv/utils/__init__.py index 4079f81c..1c54d228 100644 --- a/src/serinv/utils/__init__.py +++ b/src/serinv/utils/__init__.py @@ -8,6 +8,8 @@ from serinv.utils.pobtx import allocate_pobtx_permutation_buffers from serinv.utils.pobtax import allocate_pobtax_permutation_buffers + + __all__ = [ "check_block_dd", "check_ddbta", diff --git a/src/serinv/wrappers/ddbtars.py b/src/serinv/wrappers/ddbtars.py index 46be5972..722582a6 100644 --- a/src/serinv/wrappers/ddbtars.py +++ b/src/serinv/wrappers/ddbtars.py @@ -15,9 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - def allocate_ddbtars( A_diagonal_blocks: ArrayLike, @@ -30,7 +27,14 @@ def allocate_ddbtars( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -80,7 +84,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -144,7 +148,7 @@ def allocate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. 
@@ -224,8 +228,15 @@ def map_ddbtasc_to_ddbtars( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str, + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +371,14 @@ def aggregate_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -455,7 +473,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. if comm_rank == 0: @@ -564,11 +582,12 @@ def aggregate_ddbtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -576,9 +595,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -586,9 +605,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, 
displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -596,9 +615,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -606,9 +625,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_arrow_blocks_comm.data.ptr, count=count, @@ -616,17 +635,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -662,11 +682,12 @@ def aggregate_ddbtars( ddbtars["A_upper_arrow_blocks"] = _A_upper_arrow_blocks[1:] if quadratic: - if _use_nccl(comm): + if 
_use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -674,9 +695,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -684,9 +705,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -694,9 +715,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_arrow_blocks_comm.data.ptr, count=count, @@ -704,9 +725,9 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_arrow_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_arrow_blocks_comm, comm=communicator, rank=comm_rank, 
op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_arrow_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_arrow_blocks_comm.data.ptr, count=count, @@ -714,17 +735,18 @@ def aggregate_ddbtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_B_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_arrow_tip_block_comm.data.ptr, recvbuf=_B_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -763,7 +785,7 @@ def aggregate_ddbtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -787,8 +809,18 @@ def scatter_ddbtars( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() + _A_diagonal_blocks: ArrayLike = ddbtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = ddbtars.get("A_lower_diagonal_blocks", None) _A_upper_diagonal_blocks: ArrayLike = ddbtars.get("A_upper_diagonal_blocks", None) @@ -857,8 +889,15 @@ def map_ddbtars_to_ddbtasci( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ddbtrs.py 
b/src/serinv/wrappers/ddbtrs.py index 15393ae9..937d8def 100644 --- a/src/serinv/wrappers/ddbtrs.py +++ b/src/serinv/wrappers/ddbtrs.py @@ -24,7 +24,14 @@ def allocate_ddbtrs( comm: MPI.Comm, strategy: str = "allgather", quadratic: bool = False, + nccl_comm: object = None, ) -> dict: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -57,7 +64,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -98,7 +105,7 @@ def allocate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # In this case we also need to allocate a pinned-memory # reduced system on the host side. @@ -150,8 +157,15 @@ def map_ddbtsc_to_ddbtrs( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -241,7 +255,14 @@ def aggregate_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -304,7 +325,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. 
if comm_rank == 0: @@ -379,11 +400,12 @@ def aggregate_ddbtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -391,9 +413,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -401,9 +423,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -411,6 +433,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -433,11 +456,12 @@ def aggregate_ddbtrs( ddbtrs["A_upper_diagonal_blocks"] = _A_upper_diagonal_blocks[1:-2] if quadratic: - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) 
- comm.allGather( + communicator.allGather( sendbuf=_B_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_diagonal_blocks_comm.data.ptr, count=count, @@ -445,9 +469,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -455,9 +479,9 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_upper_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_B_upper_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_upper_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_B_upper_diagonal_blocks_comm.data.ptr, count=count, @@ -465,6 +489,7 @@ def aggregate_ddbtrs( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_diagonal_blocks_comm, @@ -492,7 +517,7 @@ def aggregate_ddbtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -510,8 +535,15 @@ def scatter_ddbtrs( comm: MPI.Comm, quadratic: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -565,8 +597,15 @@ def map_ddbtrs_to_ddbtsci( _A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + 
communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pddbtasc.py b/src/serinv/wrappers/pddbtasc.py index da2fcff4..27335195 100644 --- a/src/serinv/wrappers/pddbtasc.py +++ b/src/serinv/wrappers/pddbtasc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -22,6 +24,7 @@ def pddbtasc( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -115,9 +118,10 @@ def pddbtasc( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) - buffers: dict = kwargs.get("buffers", None) ddbtars: dict = kwargs.get("ddbtars", None) strategy: str = kwargs.get("strategy", "allgather") @@ -200,14 +204,23 @@ def pddbtasc( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_ddbtars( ddbtars=ddbtars, quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic ddbtars["A_arrow_tip_block"][:] += A_arrow_tip_initial if quadratic: @@ -226,3 +239,5 @@ def pddbtasc( ) comm.Barrier() + + return elapsed diff --git a/src/serinv/wrappers/pddbtasci.py b/src/serinv/wrappers/pddbtasci.py index f86b6a9c..f235d76b 100644 --- a/src/serinv/wrappers/pddbtasci.py +++ b/src/serinv/wrappers/pddbtasci.py @@ -22,6 +22,7 @@ def pddbtasci( A_upper_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: 
MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. @@ -159,6 +160,7 @@ def pddbtasci( ddbtars=ddbtars, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtars_to_ddbtasci( @@ -179,6 +181,7 @@ def pddbtasci( quadratic=quadratic, buffers=buffers, _rhs=ddbtars.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pddbtsc.py b/src/serinv/wrappers/pddbtsc.py index fc0a3765..e9f1eb9e 100644 --- a/src/serinv/wrappers/pddbtsc.py +++ b/src/serinv/wrappers/pddbtsc.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -19,6 +21,7 @@ def pddbtsc( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel Schur-complement of a block tridiagonal matrix. @@ -82,6 +85,7 @@ def pddbtsc( - _B_upper_diagonal_blocks : ArrayLike The upper diagonal blocks of the reduced system. 
""" + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -161,14 +165,23 @@ def pddbtsc( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_ddbtrs( ddbtrs=ddbtrs, quadratic=quadratic, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Perform Schur complement on the reduced system ddbtsc( @@ -180,3 +193,5 @@ def pddbtsc( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/pddbtsci.py b/src/serinv/wrappers/pddbtsci.py index 144054f8..63d94c84 100644 --- a/src/serinv/wrappers/pddbtsci.py +++ b/src/serinv/wrappers/pddbtsci.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import ddbtsci @@ -19,6 +18,7 @@ def pddbtsci( A_lower_diagonal_blocks: ArrayLike, A_upper_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel selected-inversion of the Schur-complement of a block tridiagonal matrix. 
@@ -89,8 +89,6 @@ def pddbtsci( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") - xp, _ = _get_module_from_array(arr=A_diagonal_blocks) - rhs: dict = kwargs.get("rhs", None) quadratic: bool = kwargs.get("quadratic", False) buffers: dict = kwargs.get("buffers", None) @@ -125,6 +123,7 @@ def pddbtsci( ddbtrs=ddbtrs, comm=comm, quadratic=quadratic, + nccl_comm=nccl_comm, ) map_ddbtrs_to_ddbtsci( @@ -139,6 +138,7 @@ def pddbtsci( quadratic=quadratic, buffers=buffers, _rhs=ddbtrs.get("_rhs", None), + nccl_comm=nccl_comm, ) # Perform distributed SCI diff --git a/src/serinv/wrappers/pobtars.py b/src/serinv/wrappers/pobtars.py index 98cfdf46..ee6f26bb 100644 --- a/src/serinv/wrappers/pobtars.py +++ b/src/serinv/wrappers/pobtars.py @@ -15,9 +15,6 @@ import cupyx as cpx import cupy as cp - if backend_flags["nccl_avail"]: - from cupy.cuda import nccl - def allocate_pobtars( A_diagonal_blocks: ArrayLike, @@ -29,6 +26,7 @@ def allocate_pobtars( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PPOBTARX algorithms. @@ -56,6 +54,12 @@ def allocate_pobtars( pobtars : dict Dictionary containing the reduced system arrays. 
""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -106,7 +110,7 @@ def allocate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -152,6 +156,7 @@ def map_ppobtax_to_pobtars( buffer: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the the boundary blocks of the PPOBTAX algorithm to the reduced system. @@ -178,6 +183,12 @@ def map_ppobtax_to_pobtars( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -228,9 +239,16 @@ def map_ppobtas_to_pobtarss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -252,6 +270,7 @@ def aggregate_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -269,7 +288,14 @@ def aggregate_pobtars( strategy : str, optional Communication strategy to use. 
(default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -306,7 +332,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. _A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -317,11 +343,13 @@ def aggregate_pobtars( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -329,9 +357,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, @@ -339,9 +367,9 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_arrow_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_arrow_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + 
communicator.allGather( sendbuf=_A_lower_arrow_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_arrow_blocks_comm.data.ptr, count=count, @@ -349,17 +377,20 @@ def aggregate_pobtars( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_arrow_tip_block_comm, comm=comm, op="allreduce" + arr=_A_arrow_tip_block_comm, comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_A_arrow_tip_block_comm.data.ptr, recvbuf=_A_arrow_tip_block_comm.data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -374,7 +405,7 @@ def aggregate_pobtars( ) comm.Allreduce(MPI.IN_PLACE, _A_arrow_tip_block_comm, op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." ) @@ -432,7 +463,7 @@ def aggregate_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -449,8 +480,15 @@ def aggregate_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -474,7 +512,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. 
@@ -484,10 +522,11 @@ def aggregate_pobtarss( if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[:-a], comm=comm, op="allgather" + arr=_B_comm[:-a], comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm[:-a].data.ptr + displacement, recvbuf=_B_comm[:-a].data.ptr, count=count, @@ -495,24 +534,25 @@ def aggregate_pobtarss( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm[-a:], comm=comm, op="allreduce" + arr=_B_comm[-a:], comm=communicator, rank=comm_rank, op="allreduce" ) - comm.allReduce( + communicator.allReduce( sendbuf=_B_comm[-a:].data.ptr, recvbuf=_B_comm[-a:].data.ptr, count=count, datatype=datatype, - op=nccl.NCCL_SUM, + op=cp.cuda.nccl.NCCL_SUM, stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm[:-a], ) comm.Allreduce(MPI.IN_PLACE, _B_comm[-a:], op=MPI.SUM) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." 
) @@ -547,7 +587,7 @@ def aggregate_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -559,10 +599,18 @@ def scatter_pobtars( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() _A_diagonal_blocks: ArrayLike = pobtars.get("A_diagonal_blocks", None) _A_lower_diagonal_blocks: ArrayLike = pobtars.get("A_lower_diagonal_blocks", None) @@ -598,6 +646,11 @@ def scatter_pobtars( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -607,7 +660,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -652,7 +705,7 @@ def scatter_pobtars( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -671,9 +724,17 @@ def scatter_pobtarss( pobtars: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() + comm_size = comm.Get_size() b = A_diagonal_blocks[0].shape[0] a = A_arrow_tip_block.shape[0] @@ -695,6 +756,11 @@ def 
scatter_pobtarss( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -703,7 +769,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -727,7 +793,7 @@ def scatter_pobtarss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -747,6 +813,7 @@ def map_pobtars_to_ppobtax( _A_arrow_tip_block: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -772,6 +839,12 @@ def map_pobtars_to_ppobtax( strategy : str, optional Communication strategy to use. 
(default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -824,9 +897,16 @@ def map_pobtarss_to_ppobtas( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTAS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/pobtrs.py b/src/serinv/wrappers/pobtrs.py index c716317a..391953d4 100644 --- a/src/serinv/wrappers/pobtrs.py +++ b/src/serinv/wrappers/pobtrs.py @@ -24,6 +24,7 @@ def allocate_pobtrs( B: ArrayLike = None, device_streaming: bool = False, strategy: str = "allgather", + nccl_comm: object = None, ): """Allocate the buffers necessary for the reduced system of the PpobtRX algorithms. @@ -47,6 +48,12 @@ def allocate_pobtrs( pobtrs : dict Dictionary containing the reduced system arrays. """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -92,7 +99,7 @@ def allocate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): _A_diagonal_blocks_comm = cpx.empty_like_pinned(_A_diagonal_blocks) _A_lower_diagonal_blocks_comm = cpx.empty_like_pinned(_A_lower_diagonal_blocks) @@ -126,6 +133,7 @@ def map_ppobtx_to_pobtrs( comm: MPI.Comm, buffer: ArrayLike, strategy: str = "allgather", + nccl_comm: object = None, ) -> None: """Map the the boundary blocks of the PpobtX algorithm to the reduced system. @@ -144,6 +152,12 @@ def map_ppobtx_to_pobtrs( strategy : str, optional Communication strategy to use. 
(default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -183,9 +197,16 @@ def map_ppobts_to_pobtrss( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -204,6 +225,7 @@ def aggregate_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Aggregate the reduced system. @@ -215,6 +237,12 @@ def aggregate_pobtrs( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -241,7 +269,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the HOST pinned arrays. 
_A_diagonal_blocks.get(out=_A_diagonal_blocks_comm) @@ -250,11 +278,13 @@ def aggregate_pobtrs( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- + cp.cuda.runtime.deviceSynchronize() count, displacement, datatype = _get_nccl_parameters( - arr=_A_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_diagonal_blocks_comm.data.ptr, count=count, @@ -262,16 +292,19 @@ def aggregate_pobtrs( stream=cp.cuda.Stream.null.ptr, ) count, displacement, datatype = _get_nccl_parameters( - arr=_A_lower_diagonal_blocks_comm, comm=comm, op="allgather" + arr=_A_lower_diagonal_blocks_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_A_lower_diagonal_blocks_comm.data.ptr + displacement, recvbuf=_A_lower_diagonal_blocks_comm.data.ptr, count=count, datatype=datatype, stream=cp.cuda.Stream.null.ptr, ) + cp.cuda.runtime.deviceSynchronize() + comm.Barrier() else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _A_diagonal_blocks_comm, @@ -281,7 +314,7 @@ def aggregate_pobtrs( _A_lower_diagonal_blocks_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." 
) @@ -322,7 +355,7 @@ def aggregate_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -336,8 +369,15 @@ def aggregate_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -360,7 +400,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # We need to move the data of the reduced system from the GPU to the # HOST pinned arrays. @@ -369,11 +409,12 @@ def aggregate_pobtrss( cp.cuda.runtime.deviceSynchronize() if strategy == "allgather": - if _use_nccl(comm): + if _use_nccl(communicator): + # --- Use NCCL --- count, displacement, datatype = _get_nccl_parameters( - arr=_B_comm, comm=comm, op="allgather" + arr=_B_comm, comm=communicator, rank=comm_rank, op="allgather" ) - comm.allGather( + communicator.allGather( sendbuf=_B_comm.data.ptr + displacement, recvbuf=_B_comm.data.ptr, count=count, @@ -381,12 +422,13 @@ def aggregate_pobtrss( stream=cp.cuda.Stream.null.ptr, ) else: + # --- Use MPI --- comm.Allgather( MPI.IN_PLACE, _B_comm, ) elif strategy == "gather-scatter": - if _use_nccl(comm): + if _use_nccl(communicator): raise ValueError( "NCCL is not supported for gather-scatter communication strategy." 
) @@ -415,7 +457,7 @@ def aggregate_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system RHS on the GPU _B.set(arr=_B_comm) @@ -427,9 +469,16 @@ def scatter_pobtrs( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Scatter the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() _A_diagonal_blocks: ArrayLike = pobtrs.get("A_diagonal_blocks", None) @@ -456,6 +505,11 @@ def scatter_pobtrs( if strategy == "allgather": ... elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -465,7 +519,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -496,7 +550,7 @@ def scatter_pobtrs( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _A_diagonal_blocks.set(arr=_A_diagonal_blocks_comm) @@ -513,8 +567,15 @@ def scatter_pobtrss( pobtrs: dict, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() b = A_diagonal_blocks[0].shape[0] @@ -536,6 +597,11 @@ def scatter_pobtrss( if strategy == "allgather": ... 
elif strategy == "gather-scatter": + if _use_nccl(communicator): + raise ValueError( + "NCCL is not supported for gather-scatter communication strategy." + ) + root = kwargs.get("root", None) if root is None: raise ValueError( @@ -544,7 +610,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): if comm_rank == root: # If cupy array, need to move the data to host before initiating the communications @@ -564,7 +630,7 @@ def scatter_pobtrss( if ( xp.__name__ == "cupy" and not backend_flags["mpi_cuda_aware"] - and not _use_nccl(comm) + and not _use_nccl(communicator) ): # Need to put back the reduced system on the GPU _B.set(arr=_B_comm) @@ -580,6 +646,7 @@ def map_pobtrs_to_ppobtx( _A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, **kwargs, ): """Map the reduced system back to the original system. @@ -597,6 +664,12 @@ def map_pobtrs_to_ppobtx( strategy : str, optional Communication strategy to use. (default: "allgather") """ + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() @@ -638,9 +711,16 @@ def map_pobtrss_to_ppobts( _B: ArrayLike, comm: MPI.Comm, strategy: str = "allgather", + nccl_comm: object = None, ): """Map the right-hand side of the PPOBTS algorithm to the right-hand-side of the reduced system.""" + communicator = None + if nccl_comm is not None: + communicator = nccl_comm + else: + communicator = comm + comm_rank = comm.Get_rank() comm_size = comm.Get_size() diff --git a/src/serinv/wrappers/ppobtaf.py b/src/serinv/wrappers/ppobtaf.py index 64946a3d..2122b88a 100644 --- a/src/serinv/wrappers/ppobtaf.py +++ b/src/serinv/wrappers/ppobtaf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. 
+import time + from mpi4py import MPI from serinv import ( @@ -20,6 +22,7 @@ def ppobtaf( A_lower_arrow_blocks: ArrayLike, A_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -62,6 +65,8 @@ def ppobtaf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -133,14 +138,23 @@ def ppobtaf( buffer=buffer, strategy=strategy, comm=comm, + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_pobtars( pobtars=pobtars, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- pobtars["A_arrow_tip_block"][:] += A_arrow_tip_initial @@ -165,3 +179,5 @@ def ppobtaf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtas.py b/src/serinv/wrappers/ppobtas.py index 132a9d08..52896530 100644 --- a/src/serinv/wrappers/ppobtas.py +++ b/src/serinv/wrappers/ppobtas.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -23,6 +25,7 @@ def ppobtas( L_arrow_tip_block: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). 
@@ -138,16 +141,25 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Agregate reduced RHS + comm.Barrier() + tic = time.perf_counter() aggregate_pobtarss( A_diagonal_blocks=L_diagonal_blocks, A_arrow_tip_block=L_arrow_tip_block, pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Add the tip block of the RHS to the aggregated update _B[-a:] += B_tip_initial @@ -180,6 +192,7 @@ def ppobtas( pobtars=pobtars, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -190,6 +203,7 @@ def ppobtas( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -213,3 +227,5 @@ def ppobtas( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtasi.py b/src/serinv/wrappers/ppobtasi.py index 12e23e14..96b67c20 100644 --- a/src/serinv/wrappers/ppobtasi.py +++ b/src/serinv/wrappers/ppobtasi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtasi @@ -20,6 +19,7 @@ def ppobtasi( L_lower_arrow_blocks: ArrayLike, L_arrow_tip_block: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). 
@@ -114,6 +114,7 @@ def ppobtasi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -129,6 +130,7 @@ def ppobtasi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system diff --git a/src/serinv/wrappers/ppobtf.py b/src/serinv/wrappers/ppobtf.py index 2d680c96..f1b44956 100644 --- a/src/serinv/wrappers/ppobtf.py +++ b/src/serinv/wrappers/ppobtf.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -18,6 +20,7 @@ def ppobtf( A_diagonal_blocks: ArrayLike, A_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ) -> ArrayLike: """Perform the parallel factorization of a block tridiagonal with arrowhead matrix @@ -56,6 +59,8 @@ def ppobtf( if comm_size == 1: raise ValueError("The number of MPI processes must be greater than 1.") + xp, _ = _get_module_from_array(arr=A_diagonal_blocks) + # Check for optional parameters device_streaming: bool = kwargs.get("device_streaming", False) strategy: str = kwargs.get("strategy", "allgather") @@ -106,14 +111,23 @@ def ppobtf( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) + comm.Barrier() + tic = time.perf_counter() aggregate_pobtrs( pobtrs=pobtrs, comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # --- Factorize the reduced system --- if strategy == "gather-scatter": @@ -134,3 +148,5 @@ def ppobtf( ) comm.Barrier() + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobts.py b/src/serinv/wrappers/ppobts.py index 396577f1..86906883 100644 --- a/src/serinv/wrappers/ppobts.py +++ b/src/serinv/wrappers/ppobts.py @@ -1,5 +1,7 @@ # Copyright 2023-2025 ETH Zurich. 
All rights reserved. +import time + from mpi4py import MPI from serinv import ( @@ -21,6 +23,7 @@ def ppobts( L_lower_diagonal_blocks: ArrayLike, B: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). @@ -119,14 +122,24 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + # Agregate reduced RHS + comm.Barrier() + tic = time.perf_counter() aggregate_pobtrss( A_diagonal_blocks=L_diagonal_blocks, pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) + if xp.__name__ == "cupy": + xp.cuda.runtime.deviceSynchronize() + comm.Barrier() + toc = time.perf_counter() + elapsed = toc - tic # Solve RHS FWD/BWD if strategy == "allgather": @@ -151,6 +164,7 @@ def ppobts( pobtrs=pobtrs, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Map solution of reduced RHS to RHS @@ -160,6 +174,7 @@ def ppobts( _B=_B, comm=comm, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel backward solve @@ -179,3 +194,5 @@ def ppobts( buffer=buffer, trans="C", ) + + return elapsed \ No newline at end of file diff --git a/src/serinv/wrappers/ppobtsi.py b/src/serinv/wrappers/ppobtsi.py index fde320fc..72c7354d 100644 --- a/src/serinv/wrappers/ppobtsi.py +++ b/src/serinv/wrappers/ppobtsi.py @@ -4,7 +4,6 @@ from serinv import ( ArrayLike, - _get_module_from_array, ) from serinv.algs import pobtsi @@ -18,6 +17,7 @@ def ppobtsi( L_diagonal_blocks: ArrayLike, L_lower_diagonal_blocks: ArrayLike, comm: MPI.Comm = MPI.COMM_WORLD, + nccl_comm: object = None, **kwargs, ): """Perform a selected inversion of a block tridiagonal with arrowhead matrix (pointing downward by convention). 
@@ -100,6 +100,7 @@ def ppobtsi( comm=comm, strategy=strategy, root=root, + nccl_comm=nccl_comm, ) # Map result of the reduced system back to the original system @@ -111,6 +112,7 @@ def ppobtsi( comm=comm, buffer=buffer, strategy=strategy, + nccl_comm=nccl_comm, ) # Parallel selected inversion of the original system